InvalidArgumentError when using SupervisedContrastiveLoss in Keras

zqdjd7g9 posted on 2021-09-08 in Java

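I am trying to adapt the Keras code at https://keras.io/examples/vision/supervised-contrastive-learning/ to a custom dataset. However, the Keras example uses the CIFAR-10 dataset, with training and test inputs of the sizes and shapes below, and trains the model to optimize the sparse categorical cross-entropy loss: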

x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)

In my case, the labels are one-hot encoded, so the sizes and shapes are as follows, and I want to optimize the categorical cross-entropy loss:

x_train shape: (1919, 256, 256, 3) - y_train shape: (1919, 2)
x_test shape: (476, 256, 256, 3) - y_test shape: (476, 2)
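One way to line the labels up with the Keras example is to convert the one-hot labels back into integer class ids with the (N, 1) shape shown above for CIFAR-10. A minimal NumPy sketch, where y_onehot is a small hypothetical stand-in for Y_train:

```python
import numpy as np

# Hypothetical stand-in for Y_train: 5 samples, 2 classes, one-hot encoded
y_onehot = np.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1]], dtype=np.float32)

# The index of the "hot" column is the integer class id; keep a trailing axis
# so the result has shape (N, 1), like y_train in the CIFAR-10 example
y_sparse = np.argmax(y_onehot, axis=1)[:, np.newaxis]

print(y_sparse.shape)    # (5, 1)
print(y_sparse.ravel())  # [0 1 1 0 1]
```

Applied to the real data, np.argmax(Y_train, axis=1)[:, np.newaxis] would have shape (1919, 1), while the original one-hot Y_train could still be kept for the categorical cross-entropy classifier.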

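I trained a basic VGG-16-based classifier on this data, with two output nodes and softmax activation, to minimize the categorical cross-entropy loss, and obtained some baseline classification results.
Then I followed the Keras example code to perform supervised contrastive learning. In the first stage, I truncate the VGG16 model at the deepest convolutional layer and add a projection head to optimize the supervised contrastive loss. The full code is as follows: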


# Imports assumed by this snippet (TensorFlow 2.4 with TensorFlow Addons)
import time

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

# Supervised contrastive learning loss function
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits: pairwise similarities divided by the temperature
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)

#%%
# Add projection head and build the model
# (model_input, projection_units, temperature, datagen, batch_size, epochs,
#  X_train, Y_train, X_test and Y_test are defined elsewhere; not shown here)
vgg16 = keras.applications.VGG16(include_top=False, weights='imagenet',
                                 input_tensor=model_input)
# Truncate VGG16 at its deepest convolutional layer
base_model_vgg16 = Model(inputs=vgg16.input, outputs=vgg16.get_layer('block5_conv3').output)
x = base_model_vgg16.output
x = GlobalAveragePooling2D()(x)
outputs = layers.Dense(projection_units, activation="relu")(x)
encoder_with_projection_head = keras.Model(inputs=vgg16.input, outputs=outputs,
                                           name="vgg16-encoder_with_projection-head")
encoder_with_projection_head.summary()

#%%
# Train the model
sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
encoder_with_projection_head.compile(optimizer=sgd, loss=SupervisedContrastiveLoss(temperature))
filepath = 'weights1/' + encoder_with_projection_head.name + '.{epoch:02d}-{val_loss:.4f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
                             save_weights_only=False, save_best_only=True,
                             mode='min', save_freq='epoch')
callbacks_list = [checkpoint]
t = time.time()
encoder_with_projection_head_history = encoder_with_projection_head.fit(
    datagen.flow(X_train, Y_train, batch_size=batch_size),
    steps_per_epoch=X_train.shape[0] // batch_size,
    callbacks=callbacks_list,
    epochs=epochs,
    shuffle=True,
    verbose=1,
    validation_data=(X_test, Y_test))

print('Training time: %s' % (time.time() - t))

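The only difference is that my code uses one-hot encoded labels, whereas the example does not. When I run the code, I get the following error: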

Epoch 1/32
Traceback (most recent call last):

  File "C:\Users\xxx\code1.py", line 433, in <module>
    validation_data=(X_test, Y_test))

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args,**kwds)

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\def_function.py", line 888, in _call
    return self._stateless_fn(*args,**kwds)

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\function.py", line 560, in call
    ctx=ctx)

  File "c:\users\xxx\appdata\local\continuum\anaconda3\envs\tf_2.4\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)

InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[16,16] labels_size=[32,16]
     [[{{node PartitionedCall/softmax_cross_entropy_with_logits_1}}]] [Op:__inference_train_function_3780]

Function call stack:
train_function
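The reported sizes can be reproduced with plain NumPy shape arithmetic. This sketch assumes that npairs_loss in TensorFlow Addons expands the labels with expand_dims(labels, -1) and compares them for equality against their full transpose, and that softmax_cross_entropy_with_logits flattens all but the last axis of its inputs before checking them; the batch size of 16 is inferred from logits_size=[16,16]. Nothing here calls TensorFlow.

```python
import numpy as np

batch_size = 16   # inferred from logits_size=[16,16] in the error message
num_classes = 2

# Deterministic stand-in for one batch: class ids (16,) and their one-hot form (16, 2)
class_ids = np.arange(batch_size) % num_classes
labels_onehot = np.eye(num_classes)[class_ids]

def pairwise_same_label(y):
    """Assumed npairs_loss label step: expand_dims(y, -1), then compare
    element-wise with the full transpose (which reverses every axis).
    The row normalisation that follows keeps the shape, so it is omitted."""
    y = np.expand_dims(y, -1)
    return (y == y.T).astype(np.float32)

def flatten_outer_dims(t):
    """Collapse every axis except the last one into a 2-D matrix."""
    return t.reshape(-1, t.shape[-1])

# Shape of feature_vectors @ transpose(feature_vectors) for a batch of 16
logits_shape = (batch_size, batch_size)

# One-hot labels: (16, 2, 1) vs (1, 2, 16) broadcast to (16, 2, 16)
onehot_pairs = pairwise_same_label(labels_onehot)
print(onehot_pairs.shape, flatten_outer_dims(onehot_pairs).shape)  # (16, 2, 16) (32, 16)

# Integer class ids: (16, 1) vs (1, 16) broadcast to (16, 16), matching the logits
sparse_pairs = pairwise_same_label(class_ids)
print(sparse_pairs.shape)  # (16, 16)
```

Under those assumptions, tf.squeeze leaves a (16, 2) one-hot batch unchanged, so the pairwise comparison yields a (16, 2, 16) tensor that flattens to the labels_size=[32,16] in the error, while a vector of 16 class ids yields a (16, 16) matrix that matches the logits.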

I suspect that the SupervisedContrastiveLoss class needs to be modified to support training with one-hot encoded labels.
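The labels in this loss are compared pairwise, so they most likely need to arrive as a 1-D vector of class ids. One hedged option, not tested on this dataset, is to replace tf.squeeze(labels) in __call__ with tf.argmax(labels, axis=1), which turns a (batch, 2) one-hot batch into (batch,) class ids. The NumPy sketch below mirrors the assumed npairs_loss computation with that conversion applied: L2-normalised features, cosine-similarity logits divided by the temperature, row-normalised same-class targets, and softmax cross-entropy averaged over the batch. It is an illustration only, not the TensorFlow Addons implementation; the function name supervised_contrastive_loss is hypothetical.

```python
import numpy as np

def supervised_contrastive_loss(onehot_labels, feature_vectors, temperature=1.0):
    """Hypothetical NumPy sketch of the question's loss, accepting one-hot labels."""
    # One-hot (batch, num_classes) -> class ids (batch,), like tf.argmax(labels, axis=1)
    class_ids = np.argmax(onehot_labels, axis=1)

    # L2-normalise each row, guarding against all-zero (e.g. ReLU) outputs
    sq_norm = (feature_vectors ** 2).sum(axis=1, keepdims=True)
    feats = feature_vectors / np.sqrt(np.maximum(sq_norm, 1e-12))

    # Pairwise cosine similarities divided by the temperature: (batch, batch)
    logits = feats @ feats.T / temperature

    # Targets: 1 where two samples share a class, each row normalised to sum to 1
    targets = (class_ids[:, None] == class_ids[None, :]).astype(np.float64)
    targets /= targets.sum(axis=1, keepdims=True)

    # Row-wise softmax cross-entropy via a numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(np.mean(-(targets * log_softmax).sum(axis=1)))

# Usage: a batch of 4 one-hot labels and random 8-dimensional projections
rng = np.random.default_rng(0)
labels = np.eye(2)[[0, 1, 1, 0]]
features = rng.normal(size=(4, 8))
loss = supervised_contrastive_loss(labels, features, temperature=0.05)
print(loss)
```

Converting inside the loss keeps the one-hot Y_train usable for both stages; converting the labels once before calling fit is an equally valid alternative.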
