How should I reshape the concatenated last 4 hidden layers to fit into my model?

jyztefdp · asked 2021-08-20 · tagged Java

I am trying to concatenate the last four hidden layers of BERT. However, I get the following error:

Input 0 of layer meddra_pt is incompatible with the layer: expected min_ndim=2, found ndim=1. Full shape received: (3072,)


# Imports needed to make this snippet runnable

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import TruncatedNormal
from transformers import BertConfig, BertTokenizerFast, TFBertModel

# Name of the BERT model to use

model_name = 'bert-base-uncased'

# Max length of tokens

max_length = 76

# Load transformers config

config = BertConfig.from_pretrained(model_name)
config.output_hidden_states = True

# Load BERT tokenizer

tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path = model_name, config = config)

# Load the Transformers BERT model

transformer_model = TFBertModel.from_pretrained(model_name, config = config)

# Input

input_ids = Input(shape=(max_length,), name='input_ids', dtype='int32')
attention_mask = Input(shape=(max_length,), name='attention_mask', dtype='int32')

# Load the Transformers BERT model as a layer in a Keras model

transformer = transformer_model([input_ids, attention_mask])
hidden_states = transformer[1] # get output_hidden_states
selected_hidden_states = tf.keras.layers.Concatenate(axis=-1)([hidden_states[i] for i in [-1, -2, -3, -4]])

# One Dense classification head per MedDRA level (df is the training dataframe, not shown)

high_level = Dense(units=len(df.MEDDRA_PT_label.value_counts()), kernel_initializer=TruncatedNormal(stddev=config.initializer_range), name='MEDDRA_PT')(selected_hidden_states)
low_level = Dense(units=len(df.MEDDRA_LLT_label.value_counts()), kernel_initializer=TruncatedNormal(stddev=config.initializer_range), name='MEDDRA_LLT')(selected_hidden_states)

outputs = {'MEDDRA_PT': high_level, 'MEDDRA_LLT': low_level}

# And combine it all in a model object

model = Model(inputs=[input_ids, attention_mask], outputs=outputs, name='BERT_MultiLabel_MultiClass')

# Take a look at the model

model.summary()
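
For reference, here is a minimal sketch of how inputs matching the two (max_length,) Input layers would be produced; the sample sentence is a made-up placeholder:

# Sketch: tokenize one (hypothetical) sentence and check the input shapes
enc = tokenizer('patient reported a severe headache',
                padding='max_length', truncation=True,
                max_length=max_length, return_tensors='tf')
print(enc['input_ids'].shape)       # (1, 76)
print(enc['attention_mask'].shape)  # (1, 76)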

When I initially set up the model with only the TFBertMainLayer output, the layer shape was (None, 768). I don't really understand the difference between a 2-dimensional layer with a None (batch) dimension and a 1-dimensional layer. How can I shape my layers so that they are compatible?
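
A shape-inspection sketch may show where the (3072,) comes from. It assumes a transformers 4.x output layout, where the model call returns (last_hidden_state, pooler_output, hidden_states) when output_hidden_states=True: under that assumption transformer[1] is the pooler_output of shape (None, 768), so indexing it with [-1] through [-4] slices away the batch dimension and leaves (768,) tensors, whose concatenation is exactly the ndim=1 shape (3072,) in the error. The cls_vector pooling at the end is just one illustrative way to obtain an ndim=2 tensor for the Dense heads, not a confirmed fix:

# Sketch (assumes the transformers 4.x output tuple layout; not a confirmed fix)
bert_out = transformer_model([input_ids, attention_mask])
print(bert_out[1].shape)      # pooler_output: (None, 768)

all_states = bert_out[2]      # tuple of 13 tensors: embeddings + 12 encoder layers
print(all_states[-1].shape)   # (None, 76, 768)

concat = tf.keras.layers.Concatenate(axis=-1)([all_states[i] for i in [-1, -2, -3, -4]])  # (None, 76, 3072)
cls_vector = concat[:, 0, :]  # keep only the [CLS] position -> (None, 3072), ndim=2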
