GAN generator model won't save

cclgggtu posted on 2021-09-08 in Java

I can train this model without any problem, but when it comes to saving it, I can't. For some reason, at some point my model thinks a Tensor is NoneType.

    import tensorflow as tf
    from tensorflow import keras
    from keras.layers import *
    from keras.models import *
    import keras.backend as K
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    import os
    import random
    import pickle
    from google.colab import drive

    drive.mount('/content/drive')

    BATCH_SIZE = 16
    LR = 1e-4
    EPOCHS = 100
    DATASET_DIR = "drive/MyDrive/Imagine/images/"
    SUBDATASET_SIZE = 64
    filenames = os.listdir(DATASET_DIR)
    MAX_DATASET = len( filenames) // SUBDATASET_SIZE

    def AdaIN(x):
        mean = K.mean(x[0], axis = [0,1],keepdims=True)
        std = K.std(x[0], axis = [0,1], keepdims = True )
        y = (x[0] - mean) / std
        pool_size = [-1,1,1,y.shape[-1]]
        scale = K.reshape(x[1],pool_size)
        bias = K.reshape(x[2],pool_size)
        #print(x[0].shape)
        #print(scale.shape)
        #print(bias.shape)
        return y * scale + bias

    def fit(x):
        height = x[1].shape[1]
        width = x[1].shape[2]
        #print("input_noise:",x[0].shape)
        #print("input:",x[1].shape)
        return x[0][:,height*2,width*2,:]

    def g_block(x,latent,input_noise,filters,kernel_size,stride):
        #print(i)
        #print(x.shape)
        gamma = Dense(filters)(latent)
        beta = Dense(filters)(latent)
        noise = Lambda(fit)([input_noise,x,i])
        noise = Dense(filters)(noise)
        #print(x.shape)
        out = UpSampling2D()(x)
        out = Conv2DTranspose(filters,kernel_size,stride)(out)
        out = add([out,noise])
        out = Lambda(AdaIN)([out,gamma,beta])
        out = LeakyReLU()(out)
        return out

    latent_input = Input([256])
    noise_vec = Input([541,961,1])
    latent = Dense(256,activation="relu")(latent_input)
    latent = Dense(256,activation="relu")(latent)
    latent = Dense(256,activation="relu")(latent)
    tensor = Dense(1)(latent_input)
    tensor = Lambda(lambda x: x * 0 + 1)(tensor)
    tensor = Dense(2*6*256,activation="relu")(tensor)
    tensor = Reshape((2,6,256))(tensor)
    print(tensor.shape)
    tensor = g_block(tensor,latent,noise_vec,256,(2,3),(2,1))
    tensor = g_block(tensor,latent,noise_vec,128,3,2)
    tensor = g_block(tensor,latent,noise_vec,64,(2,3),1)
    tensor = g_block(tensor,latent,noise_vec,32,(2,5),1)
    tensor = g_block(tensor,latent,noise_vec,16,(1,5),1)
    tensor = g_block(tensor,latent,noise_vec,8,(1,9),1)
    # tensor = UpSampling2D()(tensor)
    tensor = Conv2D(3,1)(tensor)
    output = Activation('sigmoid')(tensor)
    generator = Model(inputs=[noise_vec,latent_input], outputs=output)
    generator.summary()
    generator.save('drive/MyDrive/generator_1')

Error:

    TypeError                                 Traceback (most recent call last)
    <ipython-input-219-42b9e6c20aef> in <module>()
         29 generator = Model(inputs=[noise_vec,latent_input], outputs=output)
         30 generator.summary()
    ---> 31 generator.save('drive/MyDrive/generator_1')
         32
    75 frames
    <ipython-input-215-e7978059de2a> in fit(x)
          return x[0][:,height*2,width*2,:]
    TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
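
A possible explanation, offered as a hedged sketch rather than a confirmed answer: the tensor itself is not None. `x[1].shape[1]` is the *static* height, and it appears to be unknown (`None`) when Keras re-traces the `Lambda` layer while `generator.save` writes a SavedModel (the default format in TF 2.x when the path has no `.h5` suffix), so `None * 2` raises the TypeError above. The sketch below keeps the original indexing unchanged but reads height and width with `tf.shape`, which returns the runtime shape as a tensor. It uses a `tf.function` with an all-`None` input signature as an assumed stand-in for the tracing done on save; the name `traced_fit` and the test tensors are illustrative only.

```python
import tensorflow as tf


def fit(x):
    """Pick the noise value in x[0] at (2 * height, 2 * width) of feature map x[1].

    x[1].shape is the static shape, whose height/width can be None while the
    layer is traced. tf.shape(x[1]) is the dynamic shape: an int32 tensor that
    always holds the real sizes at run time.
    """
    dynamic_shape = tf.shape(x[1])  # [batch, height, width, channels]
    height = dynamic_shape[1]
    width = dynamic_shape[2]
    # Same indexing as the original code: selects one pixel, shape (batch, 1).
    return x[0][:, height * 2, width * 2, :]


# Trace with unknown batch/height/width, imitating (by assumption) the save-time tracing.
traced_fit = tf.function(
    fit,
    input_signature=[[
        tf.TensorSpec([None, None, None, 1], tf.float32),
        tf.TensorSpec([None, None, None, 256], tf.float32),
    ]],
)

# Noise value at (row, col) is row * 961 + col, so the picked value is easy to check.
noise = tf.reshape(tf.range(541 * 961, dtype=tf.float32), [1, 541, 961, 1])
features = tf.zeros([1, 2, 6, 256])
picked = traced_fit([noise, features])
print(picked.numpy())  # [[3856.]]
```

In `g_block` the call would then be `noise = Lambda(fit)([input_noise, x])`, because `fit` never reads the third element `i`, which the snippet does not define. If a crop was intended rather than a single-pixel pick, the slice form `x[0][:, :height * 2, :width * 2, :]` also accepts these tensor bounds, but its spatial size must then match `out` for `add([out, noise])` to succeed.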
