陈明威:
if is_classify:
    pyreader = fluid.layers.py_reader(
        capacity=50,
        shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                [-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                [-1, args.max_seq_len, 1],
                [-1, args.paragram_conut, args.paragram_max_len, 1],
                [-1, 1], [-1, 1]],
        dtypes=['int64', 'int64', 'int64', 'int64', 'float32', 'int64',
                'int64', 'int64'],
        lod_levels=[0, 0, 0, 0, 0, 1, 0, 0],
        name=task_name + "_" + pyreader_name,
        use_double_buffer=True)
陈明威:
This is the officially provided code. The only thing I added is [-1,args.paragram_conut,args.paragram_max_len, 1], the third-from-last entry in shapes.
陈明威:
(src_ids, sent_ids, pos_ids, task_ids, input_mask, paragraph, labels,
qids) = fluid.layers.read_file(pyreader)
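Since read_file hands back the slots in the order they were declared in py_reader, each batch the generator yields has to match the declared dtypes and shapes position by position. A minimal sketch of such a check, using only NumPy, might look like the following. Note that check_batch is a hypothetical helper and not part of Paddle, and the sizes in the usage lines are made up for illustration.

```python
import numpy as np


def check_batch(batch, shapes, dtypes):
    """Return a list of mismatch descriptions; an empty list means the batch fits.

    Hypothetical helper: `shapes` and `dtypes` mirror the lists passed to
    py_reader, where -1 in a shape stands for "any size".
    """
    problems = []
    if len(batch) != len(shapes) or len(shapes) != len(dtypes):
        problems.append("slot count: batch=%d shapes=%d dtypes=%d"
                        % (len(batch), len(shapes), len(dtypes)))
        return problems
    for i, (arr, shape, dtype) in enumerate(zip(batch, shapes, dtypes)):
        arr = np.asarray(arr)
        if arr.dtype != np.dtype(dtype):
            problems.append("slot %d: dtype %s, declared %s"
                            % (i, arr.dtype, dtype))
        if arr.ndim != len(shape) or any(
                want != -1 and want != got
                for want, got in zip(shape, arr.shape)):
            problems.append("slot %d: shape %s, declared %s"
                            % (i, arr.shape, shape))
    return problems


# Made-up sizes, for illustration only.
shapes = [[-1, 8, 1], [-1, 8, 1], [-1, 1]]
dtypes = ["int64", "float32", "int64"]
good = [np.zeros((2, 8, 1), "int64"),
        np.ones((2, 8, 1), "float32"),
        np.zeros((2, 1), "int64")]
swapped = [good[1], good[0], good[2]]  # first two slots in the wrong order

print(check_batch(good, shapes, dtypes))
print(check_batch(swapped, shapes, dtypes))
```

Running a check like this on the first few batches from the generator would show whether any slot arrives with a dtype or rank other than the one py_reader was told to expect.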
陈明威:
The data-processing part is:
def _pad_batch_records(self, batch_records, paragram_conut):
    """change data type to model"""
    batch_token_ids = [record.token_ids for record in batch_records]
    batch_text_type_ids = [record.text_type_ids for record in batch_records]
    batch_position_ids = [record.position_ids for record in batch_records]
    batch_labels = [record.label_id for record in batch_records]
    batch_contents_ids = [record.contens_ids for record in batch_records]
    if self.is_classify:
        batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
    elif self.is_regression:
        batch_labels = np.array(batch_labels).astype("float32").reshape([-1, 1])
    if batch_records[0].qid or batch_records[0].qid == 0:
        batch_qids = [record.qid for record in batch_records]
        batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
    else:
        batch_qids = np.array([]).astype("int64").reshape([-1, 1])
    # padding
    padded_token_ids, input_mask = pad_batch_data(
        batch_token_ids, pad_idx=self.pad_id, return_input_mask=True)
    padded_text_type_ids = pad_batch_data(
        batch_text_type_ids, pad_idx=self.pad_id)
    padded_position_ids = pad_batch_data(
        batch_position_ids, pad_idx=self.pad_id)
    padded_contents_ids = pad_batch_content_data(
        batch_contents_ids, paragram_conut=paragram_conut, pad_idx=self.pad_id)
    padded_task_ids = np.ones_like(padded_token_ids, dtype="int64") * self.task_id
    # print( np.max(padded_contents_ids), np.min(padded_contents_ids),'111111111111111111111111')
    if padded_contents_ids.shape[0] < 16:
        print(padded_contents_ids.shape, '111111111111111111111111111111111')
    # print(padded_token_ids.shape)
    # if np.max(padded_contents_ids)>500 or np.min(padded_contents_ids)<0:
    #
    #     print('11111111111111111111111111111111111111111111111')
    label_tensor = fluid.LoDTensor()
    label_tensor.set(padded_contents_ids, fluid.CPUPlace())
    return_list = [ padded_token_ids, padded_text_type_ids, pa
陈明威:
In the code above, the call padded_contents_ids = pad_batch_content_data(
    batch_contents_ids, paragram_conut=paragram_conut, pad_idx=self.pad_id)
陈明威:
is the part I added. At this point the maximum value in padded_contents_ids is only 476, and every value is int64.
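The observation that the array holds only int64 values up to 476 at this point can be turned into an explicit guard that runs on every batch, rather than a one-off print. The sketch below is only an illustration: check_ids is a hypothetical helper, and the table size of 500 is an assumption that echoes the threshold in the commented-out check in the code above.

```python
import numpy as np


def check_ids(ids, num_embeddings):
    """Raise ValueError unless ids is int64 with every value in [0, num_embeddings)."""
    ids = np.asarray(ids)
    if ids.dtype != np.int64:
        raise ValueError("expected int64, got %s" % ids.dtype)
    if ids.size and (ids.min() < 0 or ids.max() >= num_embeddings):
        raise ValueError("id range [%d, %d] outside [0, %d)"
                         % (ids.min(), ids.max(), num_embeddings))
    return ids


# 476 is the maximum reported in the thread; 500 is an assumed table size.
sample = np.array([[0, 12, 476]], dtype="int64")
check_ids(sample, num_embeddings=500)
```

If a guard like this passes inside _pad_batch_records but the lookup still receives an out-of-range id, the values are presumably being altered somewhere between the generator and the embedding op, which narrows down where to look.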
陈明威:
The data is read with train_pyreader.decorate_tensor_provider(train_data_generator)
陈明威:
Yet in the data that finally comes out of the reader, the ids have somehow become enormous
陈明威:
Exception: /paddle/paddle/fluid/operators/lookup_table_op.cu:36 Assertion id < N
failed (received id: 4140473109978529412).
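One way to get a hint about where such a value comes from is to look at its raw bytes. If the two 4-byte halves look like plausible float32 numbers or small integers, that would suggest a buffer being read with the wrong dtype or layout; if they look random, uninitialized memory is more likely. This is only a diagnostic sketch that assumes a little-endian machine, not a confirmed explanation of the error in this thread.

```python
import struct

bad_id = 4140473109978529412  # value copied from the assertion message

raw = struct.pack("<q", bad_id)          # the 8 bytes of the int64, little-endian
as_int32 = struct.unpack("<2i", raw)     # same bytes read as two int32 values
as_float32 = struct.unpack("<2f", raw)   # same bytes read as two float32 values

print("hex:        0x%016x" % bad_id)
print("as int32:  ", as_int32)
print("as float32:", as_float32)
```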
17 answers
deyfvvtc16#
@chenmingwei00 May I ask why you need to change this shape?
qcuzuvrc17#
Nobody has answered.