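I have been training my model by passing training and test generators, which read data stored in an HDF5 file, to the fit() method (about 25,000 images and labels). I recently processed the negative cases into a new HDF5 file with a similar number of images. However, after updating the generators to read both files, take half a batch of images from each set, and concatenate them, training crashes after a single epoch with:
Invalid argument: required broadcastable shapes at loc(unknown)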
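I have made sure that the model output, the generator output and the data types are all correct (model: UNet, sigmoid, classes=1, output shape = (..., 1), output type = bool), as other answers to the same question suggested, but I still get the same error.
train.py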
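Checking only the first batch can miss a mismatch that appears later in the epoch. A minimal sketch of a check over many batches, assuming the generator yields (images, labels) NumPy arrays shaped (N, H, W, C); `check_generator_shapes` is a hypothetical helper, not part of the original code:

```python
import numpy as np


def check_generator_shapes(gen, n_batches):
    """Pull n_batches from an (images, labels) generator and raise if any
    batch does not pair each image with exactly one same-sized mask."""
    for step in range(n_batches):
        images, labels = next(gen)
        # Same number of samples on axis 0
        if images.shape[0] != labels.shape[0]:
            raise ValueError(
                f"batch {step}: {images.shape[0]} images but {labels.shape[0]} labels"
            )
        # Same spatial size (height, width) on axes 1 and 2
        if images.shape[1:3] != labels.shape[1:3]:
            raise ValueError(
                f"batch {step}: image shape {images.shape} vs label shape {labels.shape}"
            )
    return True
```

For example, `check_generator_shapes(create_hdf5_generator(...), validation_steps)` on a freshly created generator, since the check consumes the batches it inspects.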
import h5py

# kfold, fold_no, CLASSES, model, callbacks, db_output_path and db_negatives_path
# are defined elsewhere in train.py (not shown)
db = h5py.File(db_output_path, 'r')
a = db['data'][200]
b = db['labels'][200]
db_neg = h5py.File(db_negatives_path, 'r')
# Precompute the k-fold splits of the negatives file, one (train, test) pair per fold
train_neg_gen = kfold.split(db_neg['data'])
neg_idx = []
for t in train_neg_gen:
    neg_idx.append(t)
batch_size = 16
for train, test in kfold.split(db['data'], db['labels']):
    # Use the negatives split belonging to the current fold
    train_neg_idx, test_neg_idx = neg_idx[fold_no-1]
    gen_train = create_hdf5_generator(db_output_path, train, batch_size, CLASSES, db_negatives_path, train_neg_idx)
    gen_val = create_hdf5_generator(db_output_path, test, batch_size, CLASSES, db_negatives_path, test_neg_idx)
    model.load_weights('weights/weights_2022-11-20.h5')
    # Generate a print
    print('------------------------------------------------------------------------')
    print(f'Training for fold {fold_no} ...')
    # Each batch is half positives and half negatives, so one pass covers about twice as many samples
    steps_per_epoch = (2*len(train))//batch_size
    validation_steps = (2*len(test))//batch_size
    results = model.fit(gen_train,
                        epochs=10, validation_data=gen_val,
                        steps_per_epoch=steps_per_epoch,
                        validation_steps=validation_steps,
                        callbacks=callbacks)
    # Increase fold number
    fold_no = fold_no + 1
Generator
import h5py
import numpy as np


def create_hdf5_generator(db_path, indices, batch_size, classes, neg_db_path=None, neg_indices=None):
    db = h5py.File(db_path, 'r')
    neg_db = h5py.File(neg_db_path, 'r')
    while True:
        if neg_indices is not None:
            # Take half of each batch from the positives and half from the negatives
            skip = batch_size//2
            restart = 0
            for i in np.arange(0, len(indices), skip):
                j = i
                # j tracks neg_db indices, which are fewer than the positive indices tracked by i
                if i >= len(neg_indices):
                    j = restart
                    restart += skip
                images = db['data'][indices[i:i+skip]]
                labels = db['labels'][indices[i:i+skip]]
                neg_images = neg_db['data'][neg_indices[j:j+skip]]
                # All-zero masks for the negatives, shaped like the positive labels
                neg_labels = np.zeros(labels.shape).astype(np.float32)
                images_concat = np.concatenate((images, neg_images), axis=0)
                labels_concat = np.concatenate((labels, neg_labels), axis=0)
                # Reseeding gives both arrays the same permutation only if they have the same length
                np.random.seed(123)
                np.random.shuffle(images_concat)
                np.random.seed(123)
                np.random.shuffle(labels_concat)
                yield images_concat, labels_concat.astype(bool)
Console output
------------------------------------------------------------------------
Training for fold 1 ...
Epoch 1/10
2773/2774 [============================>.] - ETA: 0s - loss: 0.1157 - mean_io_u_2: 0.4766 Traceback (most recent call last):
File "C:\Users\Noam\github\proj\train.py", line 181, in <module>
results = model.fit(gen_train,
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1214, in fit
val_logs = self.evaluate(
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1489, in evaluate
tmp_logs = self.test_function(iterator)
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
result = self._call(*args, **kwds)
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 924, in _call
results = self._stateful_fn(*args, **kwds)
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 3023, in __call__
return graph_function._call_flat(
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1960, in _call_flat
return self._build_call_outputs(self._inference_function.call(
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 591, in call
outputs = execute.execute(
File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: required broadcastable shapes at loc(unknown)
[[node binary_crossentropy/logistic_loss/mul (defined at C:\Users\Noam\github\proj\train.py:181) ]]
[[confusion_matrix/assert_non_negative_1/assert_less_equal/Assert/AssertGuard/pivot_f/_12/_33]]
(1) Invalid argument: required broadcastable shapes at loc(unknown)
[[node binary_crossentropy/logistic_loss/mul (defined at C:\Users\Noam\github\proj\train.py:181) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_79850]
Function call stack:
test_function -> test_function
2022-11-27 19:22:08.581553: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudart64_110.dll
2022-11-27 19:22:18.055899: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library nvcuda.dll
2022-11-27 19:22:18.073779: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties:
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.8GHz coreCount: 82 deviceMemorySize: 24.00GiB deviceMemoryBandwidth: 871.81GiB/s
2022-11-27 19:22:18.073819: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudart64_110.dll
2022-11-27 19:22:18.093917: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublas64_11.dll
2022-11-27 19:22:18.093939: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublasLt64_11.dll
2022-11-27 19:22:18.100311: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cufft64_10.dll
2022-11-27 19:22:18.102617: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library curand64_10.dll
2022-11-27 19:22:18.105904: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cusolver64_11.dll
2022-11-27 19:22:18.111640: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cusparse64_11.dll
2022-11-27 19:22:18.112034: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudnn64_8.dll
2022-11-27 19:22:18.112100: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2022-11-27 19:22:18.112463: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-27 19:22:18.113094: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties:
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.8GHz coreCount: 82 deviceMemorySize: 24.00GiB deviceMemoryBandwidth: 871.81GiB/s
2022-11-27 19:22:18.113127: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2022-11-27 19:22:18.495306: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-11-27 19:22:18.495334: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264] 0
2022-11-27 19:22:18.495341: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0: N
2022-11-27 19:22:18.495486: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 21670 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6)
2022-11-27 19:22:21.753068: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2022-11-27 19:22:23.357640: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudnn64_8.dll
2022-11-27 19:22:23.868767: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8201
2022-11-27 19:22:24.730172: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublas64_11.dll
2022-11-27 19:22:25.324257: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublasLt64_11.dll
2022-11-27 19:23:30.675901: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:29:53.026090: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:46:47.257803: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:50:09.871857: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:51:28.339643: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:22:00.445508: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:30:20.786297: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:45:59.779202: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 21:06:14.203518: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
UNet (sigmoid output, binary_crossentropy loss)
Model: "model_3"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 128, 128, 3) 0
__________________________________________________________________________________________________
conv2d_57 (Conv2D) (None, 128, 128, 32) 896 input_4[0][0]
__________________________________________________________________________________________________
dropout_27 (Dropout) (None, 128, 128, 32) 0 conv2d_57[0][0]
__________________________________________________________________________________________________
conv2d_58 (Conv2D) (None, 128, 128, 32) 9248 dropout_27[0][0]
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 32) 0 conv2d_58[0][0]
__________________________________________________________________________________________________
conv2d_59 (Conv2D) (None, 64, 64, 64) 18496 max_pooling2d_12[0][0]
__________________________________________________________________________________________________
dropout_28 (Dropout) (None, 64, 64, 64) 0 conv2d_59[0][0]
__________________________________________________________________________________________________
conv2d_60 (Conv2D) (None, 64, 64, 64) 36928 dropout_28[0][0]
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 64) 0 conv2d_60[0][0]
__________________________________________________________________________________________________
conv2d_61 (Conv2D) (None, 32, 32, 128) 73856 max_pooling2d_13[0][0]
__________________________________________________________________________________________________
dropout_29 (Dropout) (None, 32, 32, 128) 0 conv2d_61[0][0]
__________________________________________________________________________________________________
conv2d_62 (Conv2D) (None, 32, 32, 128) 147584 dropout_29[0][0]
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 128) 0 conv2d_62[0][0]
__________________________________________________________________________________________________
conv2d_63 (Conv2D) (None, 16, 16, 256) 295168 max_pooling2d_14[0][0]
__________________________________________________________________________________________________
dropout_30 (Dropout) (None, 16, 16, 256) 0 conv2d_63[0][0]
__________________________________________________________________________________________________
conv2d_64 (Conv2D) (None, 16, 16, 256) 590080 dropout_30[0][0]
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 256) 0 conv2d_64[0][0]
__________________________________________________________________________________________________
conv2d_65 (Conv2D) (None, 8, 8, 512) 1180160 max_pooling2d_15[0][0]
__________________________________________________________________________________________________
dropout_31 (Dropout) (None, 8, 8, 512) 0 conv2d_65[0][0]
__________________________________________________________________________________________________
conv2d_66 (Conv2D) (None, 8, 8, 512) 2359808 dropout_31[0][0]
__________________________________________________________________________________________________
conv2d_transpose_12 (Conv2DTran (None, 16, 16, 256) 524544 conv2d_66[0][0]
__________________________________________________________________________________________________
concatenate_12 (Concatenate) (None, 16, 16, 512) 0 conv2d_transpose_12[0][0]
conv2d_64[0][0]
__________________________________________________________________________________________________
conv2d_67 (Conv2D) (None, 16, 16, 256) 1179904 concatenate_12[0][0]
__________________________________________________________________________________________________
dropout_32 (Dropout) (None, 16, 16, 256) 0 conv2d_67[0][0]
__________________________________________________________________________________________________
conv2d_68 (Conv2D) (None, 16, 16, 256) 590080 dropout_32[0][0]
__________________________________________________________________________________________________
conv2d_transpose_13 (Conv2DTran (None, 32, 32, 128) 131200 conv2d_68[0][0]
__________________________________________________________________________________________________
concatenate_13 (Concatenate) (None, 32, 32, 256) 0 conv2d_transpose_13[0][0]
conv2d_62[0][0]
__________________________________________________________________________________________________
conv2d_69 (Conv2D) (None, 32, 32, 128) 295040 concatenate_13[0][0]
__________________________________________________________________________________________________
dropout_33 (Dropout) (None, 32, 32, 128) 0 conv2d_69[0][0]
__________________________________________________________________________________________________
conv2d_70 (Conv2D) (None, 32, 32, 128) 147584 dropout_33[0][0]
__________________________________________________________________________________________________
conv2d_transpose_14 (Conv2DTran (None, 64, 64, 64) 32832 conv2d_70[0][0]
__________________________________________________________________________________________________
concatenate_14 (Concatenate) (None, 64, 64, 128) 0 conv2d_transpose_14[0][0]
conv2d_60[0][0]
__________________________________________________________________________________________________
conv2d_71 (Conv2D) (None, 64, 64, 64) 73792 concatenate_14[0][0]
__________________________________________________________________________________________________
dropout_34 (Dropout) (None, 64, 64, 64) 0 conv2d_71[0][0]
__________________________________________________________________________________________________
conv2d_72 (Conv2D) (None, 64, 64, 64) 36928 dropout_34[0][0]
__________________________________________________________________________________________________
conv2d_transpose_15 (Conv2DTran (None, 128, 128, 32) 8224 conv2d_72[0][0]
__________________________________________________________________________________________________
concatenate_15 (Concatenate) (None, 128, 128, 64) 0 conv2d_transpose_15[0][0]
conv2d_58[0][0]
__________________________________________________________________________________________________
conv2d_73 (Conv2D) (None, 128, 128, 32) 18464 concatenate_15[0][0]
__________________________________________________________________________________________________
dropout_35 (Dropout) (None, 128, 128, 32) 0 conv2d_73[0][0]
__________________________________________________________________________________________________
conv2d_74 (Conv2D) (None, 128, 128, 32) 9248 dropout_35[0][0]
__________________________________________________________________________________________________
conv2d_75 (Conv2D) (None, 128, 128, 1) 33 conv2d_74[0][0]
==================================================================================================
Total params: 7,760,097
Trainable params: 7,760,097
Non-trainable params: 0
1 Answer
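After some debugging, the error turned out to be in one of the generator's output shapes. I always gave neg_labels the same shape as labels, even though neg_images might not have the same size along the zeroth axis. The fix is to give neg_labels the shape of neg_images on the first three axes and the shape of labels on the last axis: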
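The answer's own code is not included in this copy. When the negative slice is shorter than the positive one (or vice versa), a batch ends up with, for example, 12 images but 16 masks, so the model's (12, 128, 128, 1) output cannot broadcast against (16, 128, 128, 1) targets in binary_crossentropy. A minimal sketch of the fix as described, assuming NumPy arrays shaped (N, H, W, C); the helper name `mix_batch` and the single shared `rng.permutation` (in place of the two reseeded `np.random.shuffle` calls) are my own choices, not the answer's code:

```python
import numpy as np


def mix_batch(images, labels, neg_images, rng):
    """Combine a positive half-batch with a negative half-batch.

    neg_labels takes its first three axes (count, height, width) from
    neg_images and its last axis (channels) from labels, so every image
    keeps exactly one mask even when the negative slice is shorter.
    """
    neg_labels = np.zeros(neg_images.shape[:3] + labels.shape[3:], dtype=np.float32)
    images_concat = np.concatenate((images, neg_images), axis=0)
    labels_concat = np.concatenate((labels, neg_labels), axis=0)
    # One shared permutation keeps each image paired with its mask
    perm = rng.permutation(images_concat.shape[0])
    return images_concat[perm], labels_concat[perm].astype(bool)
```

Inside the generator, the lines from neg_labels through the yield could then become `yield mix_batch(images, labels, neg_images, rng)`, with `rng = np.random.default_rng(123)` created once before the loop.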