我正在使用tensorflow和tensorflow_插件库为chatbot开发seq2seq模型。在解码器中,我使用 tfa.seq2seq.dynamic_decode
生成序列。我发现了,因为 maximum_iteration = 15
,有时解码器生成一批序列,所有这些序列的长度为8(启用填充),有时一批序列的长度为7,有时为4,等等。下面是解码器的代码部分:
infer_decoder = tfa.seq2seq.BeamSearchDecoder(
cell=self.decoder_cell,
beam_width=args['beam_width'],
output_layer=self.output_dense)
infer_output, _, _ = tfa.seq2seq.dynamic_decode(
decoder=infer_decoder,
swap_memory=True,
maximum_iterations=args['max_len'],
decoder_init_input=self.embedding,
decoder_init_kwargs={
'start_tokens': tf.tile(tf.constant([args['SOS_ID']], dtype=tf.int32), [tf.shape(context_with_latent)[1]]),
'end_token': args['EOS_ID'],
'initial_state': tfa.seq2seq.tile_batch(init_state_tuple, args['beam_width'])
})
infer_predicted_ids = infer_output.predicted_ids[:, :, 0]
我还注意到,随着训练的进行,生成的序列的长度趋向于接近最大值_迭代,但长度略微徘徊在15左右,而不是每个批次都保持在15左右。
那么这种行为是预期的吗?
以下是一些其他信息:
tensorflow版本:2.3.1
tensorflow_插件版本:0.11.2
暂无答案!
目前还没有任何答案,快来回答吧!