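I am trying to adapt the Keras code at https://keras.io/examples/vision/supervised-contrastive-learning/ to a custom dataset. However, the Keras example uses the CIFAR-10 dataset, with training and test inputs of the sizes and shapes listed below, and trains the model to optimize the sparse categorical cross-entropy loss: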
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
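In my case, the labels are one-hot encoded, so the sizes and shapes are as follows, and I intend to optimize the categorical cross-entropy loss: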
x_train shape: (1919, 256, 256, 3) - y_train shape: (1919, 2)
x_test shape: (476, 256, 256, 3) - y_test shape: (476, 2)
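The two label formats carry the same information: a one-hot row can be turned into the integer class index used by the CIFAR-10 example with `np.argmax`, and the two cross-entropy variants then give the same value. A minimal NumPy sketch, using a hypothetical 4-sample batch and made-up classifier probabilities (not the author's data):

```python
import numpy as np

# Hypothetical one-hot labels shaped like Y_train, i.e. (N, 2); N = 4 for brevity
y_one_hot = np.array([[1, 0], [0, 1], [0, 1], [1, 0]], dtype=np.float32)

# Integer labels shaped like the CIFAR-10 y_train in the Keras example: (N, 1)
y_sparse = np.argmax(y_one_hot, axis=1).reshape(-1, 1)

# Hypothetical softmax outputs of a 2-class classifier
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]], dtype=np.float32)

# Categorical cross-entropy: weights log-probabilities by the full one-hot target
cce = -np.sum(y_one_hot * np.log(probs), axis=1).mean()

# Sparse categorical cross-entropy: picks the log-probability at the integer label
scce = -np.log(probs[np.arange(len(probs)), y_sparse[:, 0]]).mean()

print(y_sparse.ravel())       # [0 1 1 0]
print(np.isclose(cce, scce))  # True
```

So the difference between the two setups is only the label encoding, not the information the labels carry.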
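I trained a basic VGG-16-based classifier on this data, with two output nodes and a softmax activation, to minimize the categorical cross-entropy loss, and obtained a baseline classification performance.
I then followed the Keras example code to perform supervised contrastive learning. In the first stage, I truncate the VGG-16 model at its deepest convolutional layer (`block5_conv3`) and add a projection head to optimize the supervised contrastive loss. The complete code is shown below: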
```python
import time

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

# model_input, projection_units, temperature, datagen, batch_size, epochs,
# X_train, Y_train, X_test and Y_test are defined elsewhere in the script (not shown).

# Supervised contrastive learning loss function
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits (pairwise cosine similarities scaled by the temperature)
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)

#%%
# Add a projection head on top of the truncated VGG-16 encoder
vgg16 = keras.applications.VGG16(include_top=False, weights='imagenet',
                                 input_tensor=model_input)
base_model_vgg16 = Model(inputs=vgg16.input,
                         outputs=vgg16.get_layer('block5_conv3').output)
x = base_model_vgg16.output
x = GlobalAveragePooling2D()(x)
outputs = layers.Dense(projection_units, activation="relu")(x)
encoder_with_projection_head = keras.Model(inputs=vgg16.input, outputs=outputs,
                                           name="vgg16-encoder_with_projection-head")
encoder_with_projection_head.summary()

#%%
# Train the model
sgd = SGD(learning_rate=0.001, decay=1e-6, momentum=0.9, nesterov=True)
encoder_with_projection_head.compile(optimizer=sgd, loss=SupervisedContrastiveLoss(temperature))
filepath = 'weights1/' + encoder_with_projection_head.name + '.{epoch:02d}-{val_loss:.4f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
                             save_weights_only=False, save_best_only=True,
                             mode='min', save_freq='epoch')
callbacks_list = [checkpoint]
t = time.time()
encoder_with_projection_head_history = encoder_with_projection_head.fit(
    datagen.flow(X_train, Y_train, batch_size=batch_size),
    steps_per_epoch=X_train.shape[0] // batch_size,
    callbacks=callbacks_list,
    epochs=epochs,
    shuffle=True,
    verbose=1,
    validation_data=(X_test, Y_test))
print('Training time: %s' % (time.time() - t))
```
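For reference, assuming `tfa.losses.npairs_loss` follows its open-source implementation (a softmax cross-entropy between the similarity logits and row-normalized same-label targets), the loss class above computes the following for a batch of $N$ projection-head outputs $f_i$ with class labels $y_i$ and temperature $\tau$:

$$
z_i = \frac{f_i}{\lVert f_i \rVert_2}, \qquad
s_{ij} = \frac{z_i^\top z_j}{\tau}, \qquad
p_{ij} = \frac{\mathbb{1}[y_i = y_j]}{\sum_{k=1}^{N} \mathbb{1}[y_i = y_k]}
$$

$$
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N} p_{ij} \log \frac{\exp(s_{ij})}{\sum_{k=1}^{N} \exp(s_{ik})}
$$

The indicator $\mathbb{1}[y_i = y_j]$ compares scalar class indices, so the labels are expected to have shape $(N,)$, or $(N, 1)$, which `tf.squeeze` reduces to $(N,)$. In this form the self-pair $j = i$ is kept both in the targets and in the softmax denominator.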
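The only difference is that my code uses one-hot encoded labels, whereas the example uses integer class labels. When I run the code, I get the following error: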
```
Epoch 1/32
Traceback (most recent call last):
  File "C:\Users\xxx\code1.py", line 433, in <module>
    validation_data=(X_test, Y_test))
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args,**kwds)
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\def_function.py", line 888, in _call
    return self._stateless_fn(*args,**kwds)
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 560, in call
    ctx=ctx)
  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
InvalidArgumentError: logits and labels must be broadcastable: logits_size=[16,16] labels_size=[32,16]
	 [[{{node PartitionedCall/softmax_cross_entropy_with_logits_1}}]] [Op:__inference_train_function_3780]

Function call stack:
train_function
```
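The shapes in this message can be reproduced without TensorFlow. The sketch below mirrors in NumPy the target construction that `tfa.losses.npairs_loss` performs (expand the labels to a column, compare them with their transpose, row-normalize), assuming that function follows its open-source implementation; it does not call the library. The batch size of 16 is inferred from `logits_size=[16,16]`. With one-hot labels of shape (16, 2), `tf.squeeze` is a no-op, the transpose of a rank-3 tensor reverses all three axes, and broadcasting produces a (16, 2, 16) target tensor; `softmax_cross_entropy_with_logits` then flattens every axis except the last, which gives exactly `labels_size=[32,16]`.

```python
import numpy as np

batch_size = 16  # inferred from logits_size=[16,16] in the error message
num_classes = 2

rng = np.random.default_rng(0)
class_ids = rng.integers(0, num_classes, size=batch_size)
one_hot = np.eye(num_classes)[class_ids]  # (16, 2), like one batch of Y_train


def npairs_targets(y_true):
    """NumPy mirror of the target construction inside tfa.losses.npairs_loss."""
    y = np.expand_dims(y_true.astype(np.float32), -1)  # tf.expand_dims(y_true, -1)
    y_t = np.transpose(y)  # like tf.transpose: reverses ALL axes by default
    mask = (y == y_t).astype(np.float32)  # tf.equal broadcasts the two shapes
    with np.errstate(invalid="ignore"):  # one-hot input may hit 0/0 (NaN); silenced here
        return mask / mask.sum(axis=1, keepdims=True)


# tf.squeeze is a no-op on (16, 2), so the one-hot labels pass through unchanged
bad = npairs_targets(np.squeeze(one_hot))
print(bad.shape)  # (16, 2, 16)
# softmax_cross_entropy_with_logits flattens all but the last axis before the shape check
print(bad.reshape(-1, bad.shape[-1]).shape)  # (32, 16), the labels_size in the error

good = npairs_targets(np.argmax(one_hot, axis=1))
print(good.shape)  # (16, 16), matching the (16, 16) logits
```

With integer class ids from `np.argmax`, the targets come out as (16, 16) and line up with the logits, which points to the label format, not the model, as the source of the error.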
I suspect the `SupervisedContrastiveLoss` class needs to be modified to support training with one-hot encoded labels.
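One possible modification, sketched under the assumption that the labels reaching the loss are always one-hot: convert them to integer class ids with `tf.argmax` before building the same-label mask. To keep the sketch independent of TensorFlow Addons, the npairs steps are written out with core TensorFlow ops; this mirrors what `tfa.losses.npairs_loss` does rather than calling it, and it is not taken from the Keras example.

```python
import tensorflow as tf
from tensorflow import keras


class SupervisedContrastiveLoss(keras.losses.Loss):
    """Sketch: the Keras example's loss, adapted to one-hot labels.

    Assumes labels of shape (batch, num_classes). The npairs-style
    cross-entropy is written inline instead of calling tfa.losses.npairs_loss.
    """

    def __init__(self, temperature=1, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # One-hot (batch, num_classes) -> integer class ids (batch,)
        class_ids = tf.argmax(labels, axis=-1)
        # L2-normalize the projections and build scaled cosine-similarity logits
        z = tf.math.l2_normalize(feature_vectors, axis=1)
        logits = tf.matmul(z, z, transpose_b=True) / self.temperature
        # Same-class mask (batch, batch), row-normalized into target distributions
        class_ids = tf.expand_dims(class_ids, -1)
        targets = tf.cast(tf.equal(class_ids, tf.transpose(class_ids)), logits.dtype)
        targets /= tf.reduce_sum(targets, axis=1, keepdims=True)
        # Softmax cross-entropy of each row of logits against its targets, averaged
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=logits)
        return tf.reduce_mean(loss)
```

Smaller equivalent changes would be to add `labels = tf.argmax(labels, axis=1)` before the `tfa.losses.npairs_loss` call in the original class, or to pass `np.argmax(Y_train, axis=1)` and `np.argmax(Y_test, axis=1)` to `fit` instead of the one-hot arrays. Note that `tf.argmax` on the example's (N, 1) integer labels would return all zeros, so this version should not be used with the original CIFAR-10 labels.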