I am trying to train a 3D segmentation network from GitHub. My model is implemented in Keras (Python) and is a typical U-Net. The model is summarized below:
Model: "functional_3"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 128, 128, 4) 0
__________________________________________________________________________________________________
gaussian_noise (GaussianNoise) (None, 128, 128, 4) 0 input_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 128, 128, 64) 1088 gaussian_noise[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 64) 256 conv2d[0][0]
__________________________________________________________________________________________________
p_re_lu (PReLU) (None, 128, 128, 64) 64 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 128, 128, 64) 36928 p_re_lu[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 64) 256 conv2d_1[0][0]
__________________________________________________________________________________________________
p_re_lu_1 (PReLU) (None, 128, 128, 64) 64 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 128, 128, 64) 36928 p_re_lu_1[0][0]
__________________________________________________________________________________________________
add (Add) (None, 128, 128, 64) 0 conv2d[0][0]
conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 64, 64, 128) 32896 add[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 64, 64, 128) 512 conv2d_3[0][0]
__________________________________________________________________________________________________
p_re_lu_2 (PReLU) (None, 64, 64, 128) 128 batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 64, 64, 128) 147584 p_re_lu_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 64, 64, 128) 512 conv2d_4[0][0]
__________________________________________________________________________________________________
p_re_lu_3 (PReLU) (None, 64, 64, 128) 128 batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 128) 147584 p_re_lu_3[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 64, 64, 128) 0 conv2d_3[0][0]
conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 32, 32, 256) 131328 add_1[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 32, 32, 256) 1024 conv2d_6[0][0]
__________________________________________________________________________________________________
p_re_lu_4 (PReLU) (None, 32, 32, 256) 256 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 32, 32, 256) 590080 p_re_lu_4[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 256) 1024 conv2d_7[0][0]
__________________________________________________________________________________________________
p_re_lu_5 (PReLU) (None, 32, 32, 256) 256 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 256) 590080 p_re_lu_5[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 32, 32, 256) 0 conv2d_6[0][0]
conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 512) 524800 add_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 512) 2048 conv2d_9[0][0]
__________________________________________________________________________________________________
p_re_lu_6 (PReLU) (None, 16, 16, 512) 512 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 16, 16, 512) 2359808 p_re_lu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 16, 16, 512) 2048 conv2d_10[0][0]
__________________________________________________________________________________________________
p_re_lu_7 (PReLU) (None, 16, 16, 512) 512 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 16, 16, 512) 2359808 p_re_lu_7[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 16, 16, 512) 0 conv2d_9[0][0]
conv2d_11[0][0]
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D) (None, 32, 32, 512) 0 add_3[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 256) 524544 up_sampling2d[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 32, 32, 512) 0 add_2[0][0]
conv2d_12[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 512) 2048 concatenate[0][0]
__________________________________________________________________________________________________
p_re_lu_8 (PReLU) (None, 32, 32, 512) 512 batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 32, 32, 256) 1179904 p_re_lu_8[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32, 32, 256) 1024 conv2d_13[0][0]
__________________________________________________________________________________________________
p_re_lu_9 (PReLU) (None, 32, 32, 256) 256 batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 32, 32, 256) 131072 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 32, 32, 256) 590080 p_re_lu_9[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 32, 32, 256) 0 conv2d_15[0][0]
conv2d_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 64, 64, 256) 0 add_4[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 64, 64, 128) 131200 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 64, 64, 256) 0 add_1[0][0]
conv2d_16[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 256) 1024 concatenate_1[0][0]
__________________________________________________________________________________________________
p_re_lu_10 (PReLU) (None, 64, 64, 256) 256 batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 64, 64, 128) 295040 p_re_lu_10[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 128) 512 conv2d_17[0][0]
__________________________________________________________________________________________________
p_re_lu_11 (PReLU) (None, 64, 64, 128) 128 batch_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 64, 64, 128) 32768 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 64, 64, 128) 147584 p_re_lu_11[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 64, 64, 128) 0 conv2d_19[0][0]
conv2d_18[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 128, 128, 128 0 add_5[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 128, 128, 64) 32832 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 128, 128, 128 0 add[0][0]
conv2d_20[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 128, 128, 128 512 concatenate_2[0][0]
__________________________________________________________________________________________________
p_re_lu_12 (PReLU) (None, 128, 128, 128 128 batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 128, 128, 64) 73792 p_re_lu_12[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 128, 128, 64) 256 conv2d_21[0][0]
__________________________________________________________________________________________________
p_re_lu_13 (PReLU) (None, 128, 128, 64) 64 batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 128, 128, 64) 8192 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 128, 128, 64) 36928 p_re_lu_13[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 128, 128, 64) 0 conv2d_23[0][0]
conv2d_22[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 128, 128, 64) 256 add_6[0][0]
__________________________________________________________________________________________________
p_re_lu_14 (PReLU) (None, 128, 128, 64) 64 batch_normalization_14[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 128, 128, 4) 260 p_re_lu_14[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 128, 128, 4) 0 conv2d_24[0][0]
==================================================================================================
Total params: 10,159,748
Trainable params: 10,153,092
Non-trainable params: 6,656
__________________________________________________________________________________________________
My training files use the input shape (batch, Height, Width, Channel). I saved the training images and labels in two NumPy files (.npy): x_training.npy contains the images (shape: (20, 128, 128, 4)) and y_training.npy contains the corresponding labels (shape: (20, 128, 128, 4)). I then read the data with a custom data generator:
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # or: from keras.preprocessing.image import ImageDataGenerator

def img_msk_gen(X33_train, Y_train, seed):
    '''
    A custom generator that performs data augmentation on both patches
    and their corresponding targets (masks).
    '''
    # Two separate ImageDataGenerators; the shared seed keeps the random
    # horizontal flips of images and masks in sync.
    datagen = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
    datagen_msk = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
    image_generator = datagen.flow(X33_train, batch_size=4, seed=seed)
    y_generator = datagen_msk.flow(Y_train, batch_size=4, seed=seed)
    while True:
        yield (image_generator.next(), y_generator.next())
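For a quick check of the generator (a sketch, assuming the two .npy files described above are on disk), one can pull a single batch and confirm that images and masks come out with matching shapes; the shared seed is what keeps the random flips of the two internal generators paired:

import numpy as np

# sanity check of the generator defined above
X33_train = np.load("./x_training.npy").astype(np.float32)
Y_train = np.load("./y_training.npy").astype(np.float32)
gen = img_msk_gen(X33_train, Y_train, seed=9999)
x_batch, y_batch = next(gen)
print(x_batch.shape, y_batch.shape)  # expected: (4, 128, 128, 4) (4, 128, 128, 4)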
Finally, I try to train my model:
import numpy as np

# load data from disk
X_patches = np.load("./x_training.npy").astype(np.float32)
Y_labels = np.load("./y_training.npy").astype(np.float32)
X33_train = X_patches
Y_train = Y_labels

batch_size = 4  # same batch size as used inside the generator
train_generator = img_msk_gen(X33_train=X_patches, Y_train=Y_labels, seed=9999)
model.fit_generator(train_generator,
                    steps_per_epoch=len(X33_train) // batch_size,
                    verbose=1)
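(The same call can also be written with model.fit, which accepts generators directly in recent TensorFlow/Keras versions where fit_generator is deprecated; a sketch using the same generator and batch size:)

model.fit(train_generator,
          steps_per_epoch=len(X33_train) // batch_size,
          verbose=1)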
However, training throws an error like this:
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got [1, 3]
If you have any suggestions or ideas, they would help me. ***My complete model implementation in Colab is here, and the data on Google Drive is here.*** Although similar questions exist, I have not been able to solve my problem. Any kind of help would be greatly appreciated. Thanks in advance.
2 Answers
lhcgjxsq1#
The error says it directly: you passed [1, 3], which is a list, while it expects an integer or a slice.
Perhaps you meant [1:3]?
You seem to have written [1, 3], so perhaps change
[1, 3]
to
[1:3]
That is at least valid syntax; I am not sure it does what you want.
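To illustrate the difference (a standalone sketch, not taken from the asker's code): indexing a TensorFlow tensor with the list [1, 3] raises exactly this TypeError, while the slice 1:3 is accepted:

import tensorflow as tf

t = tf.zeros((4, 128, 128, 4))
# t[[1, 3]] -> TypeError: Only integers, slices (`:`), ellipsis (`...`), ... are valid indices
x = t[1:3]   # slice: valid, shape (2, 128, 128, 4)
print(x.shape)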
am46iovg2#
You can use a LabelEncoder to fit and transform Y_train (the target variable) to fix this error.
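A minimal sketch of that suggestion, assuming Y_train holds class labels that should be mapped to consecutive integers (scikit-learn's LabelEncoder only accepts 1-D arrays, so the mask is flattened and reshaped back):

import numpy as np
from sklearn.preprocessing import LabelEncoder

Y_train = np.load("./y_training.npy")
encoder = LabelEncoder()
flat = Y_train.reshape(-1)                                      # LabelEncoder expects a 1-D array
Y_encoded = encoder.fit_transform(flat).reshape(Y_train.shape)  # back to (20, 128, 128, 4)
print(encoder.classes_, Y_encoded.shape)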