bug描述 Describe the Bug
在 Paddle 的 Dataset 中使用 multiprocessing 会报错,而 PyTorch 中没有这个问题。这导致李沐《动手学深度学习》的 CI 超时。
详情请看 d2l-ai/d2l-zh#1178 这个 PR 中的 chapter_natural-language-processing-applications/natural-language-inference-bert.md
#@tab paddle
class SNLIBERTDataset(paddle.io.Dataset):
    """Map-style dataset of premise/hypothesis pairs encoded as BERT inputs.

    Every example is tokenized, truncated, and padded to ``max_len`` once in
    the constructor; ``__getitem__`` only indexes the precomputed tensors.
    """

    def __init__(self, dataset, max_len, vocab=None):
        """Tokenize and encode the whole dataset up front.

        Args:
            dataset: Indexable whose first two entries are sequences of
                sentence strings and whose third entry holds the labels.
            max_len: Length every encoded sequence is padded to.
            vocab: Token-to-id mapping indexed with a list of tokens and with
                ``'<pad>'``. NOTE(review): the default ``None`` cannot be
                indexed, so a real vocabulary appears to be required.
        """
        # Lower-case and tokenize each of the two sentence columns.
        tokenized_columns = [
            d2l.tokenize([sentence.lower() for sentence in column])
            for column in dataset[:2]
        ]
        # Re-pair the columns row by row: one [first, second] entry per example.
        token_pairs = [
            [premise, hypothesis]
            for premise, hypothesis in zip(*tokenized_columns)
        ]
        self.labels = paddle.to_tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        encoded = self._preprocess(token_pairs)
        self.all_token_ids, self.all_segments, self.valid_lens = encoded
        example_count = len(self.all_token_ids)
        print('read ' + str(example_count) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        """Encode every token pair and stack the results into three tensors.

        Returns:
            Tuple ``(token_ids, segments, valid_lens)`` of Paddle tensors; the
            first two are created with dtype ``'int64'``.
        """
        # Pairs are encoded one by one in this process. The earlier
        # multiprocessing.Pool(4).map(self._mp_worker, ...) variant was
        # disabled because it is reported to raise an error under Paddle.
        encoded = [
            self._mp_worker(pair) for pair in all_premise_hypothesis_tokens
        ]
        all_token_ids = [example[0] for example in encoded]
        all_segments = [example[1] for example in encoded]
        valid_lens = [example[2] for example in encoded]
        return (
            paddle.to_tensor(all_token_ids, dtype='int64'),
            paddle.to_tensor(all_segments, dtype='int64'),
            paddle.to_tensor(valid_lens),
        )

    def _mp_worker(self, premise_hypothesis_tokens):
        """Encode one ``(premise_tokens, hypothesis_tokens)`` pair.

        The two token lists are truncated in place before encoding.

        Returns:
            Tuple ``(token_ids, segments, valid_len)`` where the two lists are
            padded to ``self.max_len`` and ``valid_len`` is the unpadded token
            count.
        """
        premise, hypothesis = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(premise, hypothesis)
        tokens, segments = d2l.get_tokens_and_segments(premise, hypothesis)
        valid_len = len(tokens)
        token_ids = self.vocab[tokens]
        id_padding = [self.vocab['<pad>']] * (self.max_len - valid_len)
        token_ids = token_ids + id_padding
        segment_padding = [0] * (self.max_len - len(segments))
        segments = segments + segment_padding
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        """Trim the two token lists in place until they fit the length budget.

        One token is popped from the end of the longer list at a time; on a
        tie the token is taken from ``h_tokens``.
        """
        # Reserve positions for the '<CLS>', '<SEP>' and '<SEP>' tokens of the
        # BERT input.
        budget = self.max_len - 3
        while len(p_tokens) + len(h_tokens) > budget:
            longer = p_tokens if len(p_tokens) > len(h_tokens) else h_tokens
            longer.pop()

    def __getitem__(self, idx):
        """Return ``((token_ids, segments, valid_len), label)`` at ``idx``."""
        features = (
            self.all_token_ids[idx],
            self.all_segments[idx],
            self.valid_lens[idx],
        )
        return features, self.labels[idx]

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.all_token_ids)
其他补充信息 Additional Supplementary Information
No response
3条答案
按热度按时间ocebsuys1#
When I used the PASSL repo, I ran into the same TypeError traceback as yours.
I solved the problem as follows:
do not use Tensor in dataset.transforms, and keep the dimension order the same as paddle.Tensor by using
img.transpose((2, 0, 1))
You can see the difference here: ToCHWImage vs. ToTensor.
ehxuflar2#
你好,报啥错呢,报错信息麻烦贴一下
lrl1mhuk3#
只要使用了这两行:
就会有: