我正在用CTC损失函数训练一个简单的RNN模型(GRU)。下面是代码和模型摘要。我一直得到这个错误如下。似乎在模型中的某个地方,数据维度,可能是输入数据长度(即,[batch_size,length,mfcc_feature]中的长度)减少2。我哪里错了?
def data_generator(batch_size, wav_files, trn_files, numcep, pinyin_dict):
for i in range(len(wav_files)//batch_size):
print("\n##Start Batch: ", i)
mfcc_datasets = []
mfcc_form_orig_len_datasets = []
pinyin_datasets = []
pinyin_code_orig_len_datasets = []
begin = i * batch_size
end = begin + batch_size
print("begin: ", begin, "end: ", end)
dataset_indices = list(range(begin, end))
print("dataset_indices: ", dataset_indices)
wav_files_subset = [wav_files[index] for index in dataset_indices]
trn_files_subset = [trn_files[index] for index in dataset_indices]
train_wav_max_len_batch = get_wav_max_len(wav_files_subset, numcep)
train_pinyin_max_len_batch = get_pinyin_max_len(trn_files_subset, pinyin_dict)
for index in dataset_indices:
# transform wav to mfcc
mfcc_form = wav_to_mfcc(wav_files[index], numcep)
mfcc_form_expanded_padded, mfcc_form_orig_len = expand_pad_mfcc(mfcc_form, train_wav_max_len_batch)
mfcc_datasets.append(mfcc_form_expanded_padded)
mfcc_form_orig_len_datasets.append(mfcc_form_orig_len)
# transform trn to pinyin code
code = trn_pinyin_to_code(trn_files[index], pinyin_dict)
pinyin_code_expanded, pinyin_code_orig_len = expand_trn(code, train_pinyin_max_len_batch)
pinyin_datasets.append(pinyin_code_expanded)
pinyin_code_orig_len_datasets.append(pinyin_code_orig_len)
mfcc_datasets = np.array(mfcc_datasets)
mfcc_form_orig_len_datasets = np.array(mfcc_form_orig_len_datasets)
pinyin_datasets = np.array(pinyin_datasets)
pinyin_code_orig_len_datasets = np.array(pinyin_code_orig_len_datasets)
inputs = {'Inputs': mfcc_datasets, # size = (batch_size, length, num of features, channel)
'CTC_labels': pinyin_datasets, # size = (batch_size, length)
'CTC_input_length': mfcc_form_orig_len_datasets,
'CTC_label_length': pinyin_code_orig_len_datasets,
}
outputs = {'ctc': np.zeros(mfcc_datasets.shape[0],)}
print("mfcc_datasets.shape: ", mfcc_datasets.shape)
print("mfcc_form_orig_len_datasets: ", mfcc_form_orig_len_datasets)
print("pinyin_datasets.shape: ", pinyin_datasets.shape)
print("pinyin_code_orig_len_datasets: ", pinyin_code_orig_len_datasets)
print("outputs.shape: ", np.zeros(mfcc_datasets.shape[0],).shape)
print("##End Batch: ", i)
yield inputs, outputs
def ctc_lambda_func(args):
y_pred, labels, input_length, label_length = args
y_pred = y_pred[:, 2:, :]
return tf.keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)
def ctc_model(inputs, y_pred):
labels = tf.keras.Input(name='CTC_labels', shape=[None], dtype='float32')
input_length = tf.keras.Input(name='CTC_input_length', shape=[1], dtype='int64')
label_length = tf.keras.Input(name='CTC_label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
ctc_model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
print(ctc_model.summary())
return ctc_model
def simple_rnn_model(input_feat_dim, output_feat_dim):
inputs = tf.keras.Input(name='Inputs', shape=(None, input_feat_dim))
x = GRU(name='GRU_1', units=output_feat_dim, return_sequences=True, kernel_initializer='he_normal')(inputs)
y_pred = Activation('softmax', name='Softmax')(x)
model = Model(inputs=inputs, outputs=y_pred)
print(model.summary())
ctc_model_0 = ctc_model(inputs, y_pred)
return ctc_model_0
model_0 = simple_rnn_model(input_feat_dim=MFCC_FEATURES, output_feat_dim=pinyin_dict.shape[0])
以下是模型摘要:型号:“model_8”
图层类型输出形状参数#
Inputs(InputLayer)[(None,None,13)] 0
GRU_1(GRU)(无,无,29)3828
Softmax(激活)(无,无,29)0
总参数:3,828个可训练参数:3,828个不可训练的参数:0
产品型号:“model_9”
图层(类型)输出形状参数#连接到
Inputs(InputLayer)[(None,None,13)] 0
GRU_1(GRU)(无,无,29)3828输入[0][0]
Softmax(激活)(无,无,29)0 GRU_1[0][0]
CTC_labels(InputLayer)[(None,None)] 0
CTC_input_length(InputLayer)[(None,1)] 0
CTC_label_length(InputLayer)[(None,1)] 0
ctc(Lambda)(无,1)0 Softmax[0][0] CTC_labels[0][0] CTC_input_length[0][0] CTC_label_length[0][0]
总参数:3,828个可训练参数:3,828个不可训练的参数:0
没有一
下面是错误消息,以及用于分析的中间文件:
#Training Epoch:... 0
##Start Batch: 0
begin: 0 end: 2
dataset_indices: [0, 1]
mfcc_datasets.shape: (2, 883, 13)
mfcc_form_orig_len_datasets: [779 883]
pinyin_datasets.shape: (2, 34)
pinyin_code_orig_len_datasets: [31 34]
outputs.shape: (2,)
##End Batch: 0
##Start Batch: 1
begin: 2 end: 4
dataset_indices: [2, 3]
mfcc_datasets.shape: (2, 819, 13)
mfcc_form_orig_len_datasets: [819 794]
pinyin_datasets.shape: (2, 33)
pinyin_code_orig_len_datasets: [33 32]
outputs.shape: (2,)
##End Batch: 1
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-56-f01d15ca6481> in <module>
----> 1 hist = train_model(model_0)
<ipython-input-55-72d8aa1713c5> in train_model(model)
20 print('#Training Epoch:... ', epoch)
21 batch = data_generator(BATCH_SIZE, train_wav_files[0:8], train_trn_files[0:8], NUMCEP, pinyin_dict)
---> 22 hist = current_model.fit(batch, steps_per_epoch=BATCH_NUM, epochs=1, verbose=1)
23
24 return hist
~\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
~\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
~\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
886 # Lifting succeeded, so variables are initialized and we can run the
887 # stateless function.
--> 888 return self._stateless_fn(*args, **kwds)
889 else:
890 _, _, _, filtered_flat_args = \
~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2940 (graph_function,
2941 filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 2942 return graph_function._call_flat(
2943 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
2944
~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1916 and executing_eagerly):
1917 # No tape is watching; skip to running the function.
-> 1918 return self._build_call_outputs(self._inference_function.call(
1919 ctx, args, cancellation_manager=cancellation_manager))
1920 forward_backward = self._select_forward_and_backward_functions(
~\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
553 with _InterpolateFunctionError(self):
554 if cancellation_manager is None:
--> 555 outputs = execute.execute(
556 str(self.signature.name),
557 num_outputs=self._num_outputs,
~\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: sequence_length(1) <= 881
[[node model_9/ctc/CTCLoss (defined at <ipython-input-34-5693b53d741a>:8) ]] [Op:__inference_train_function_11579]
Function call stack:
train_function
1条答案
按热度按时间mznpcxlj1#
好吧...我知道为什么了...因为我有这个
y_pred = y_pred[:,2:,:]
在ctc_lambda_loss函数中