Using multiprocessing inside a Paddle Dataset raises an error

unhi4e5o · asked 5 months ago

Describe the Bug

Using multiprocessing inside a Paddle dataset raises an error; the same code works fine with PyTorch. This makes the CI for Mu Li's Dive into Deep Learning (d2l-zh) time out.
d2l-ai/d2l-zh#1178 — for details, see chapter_natural-language-processing-applications/natural-language-inference-bert.md in that PR.

#@tab paddle
import multiprocessing

import paddle
# `d2l` is the book's helper library, imported in the chapter's setup

class SNLIBERTDataset(paddle.io.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        self.labels = paddle.to_tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        # Using a pool here triggers the TypeError reported below:
        # pool = multiprocessing.Pool(4)  # use 4 worker processes
        # out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        out = []
        for tokens_pair in all_premise_hypothesis_tokens:
            out.append(self._mp_worker(tokens_pair))
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (paddle.to_tensor(all_token_ids, dtype='int64'),
                paddle.to_tensor(all_segments, dtype='int64'),
                paddle.to_tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # Reserve slots for the '<CLS>', '<SEP>', and '<SEP>' tokens
        # in the BERT input
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)
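The failure presumably comes from pickling: pool.map(self._mp_worker, ...) pickles the bound method, which drags the whole dataset object, including the paddle.Tensor stored in self.labels, into each task. A minimal sketch of one workaround, assuming that tensor attribute is the culprit, is to create self.labels only after the pool has finished, so that self holds nothing but plain Python data while the workers run:

    # Sketch of a workaround (same class as above, only __init__ reordered):
    # keep `self` free of paddle.Tensor attributes until the pool returns,
    # so pickling the bound method only serializes plain Python objects.
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]
        self.vocab = vocab
        self.max_len = max_len
        # At this point `self` holds no tensors, so multiprocessing can
        # pickle self._mp_worker inside _preprocess without the TypeError
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        # Convert the labels to a tensor only after the workers are done
        self.labels = paddle.to_tensor(dataset[2])
        print('read ' + str(len(self.all_token_ids)) + ' examples')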

Additional Supplementary Information

No response

ocebsuys 1#

When I use the PASSL repo, I hit the same TypeError traceback as yours.
I solved this problem as follows:

don't use Tensor in dataset.transforms; keep the dimensions consistent with paddle.Tensor by using img.transpose((2, 0, 1)) instead

you can see the difference here --ToCHWImage vs ToTensor
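In other words, keep the transform in NumPy and only change the memory layout; a minimal sketch of that idea (to_chw is a hypothetical helper in the spirit of ToCHWImage, not a PASSL API):

import numpy as np

def to_chw(img):
    # Hypothetical transform: stay in NumPy and rearrange HWC -> CHW,
    # instead of returning a Tensor (as ToTensor does), so the sample
    # remains picklable and the DataLoader converts it to a tensor later.
    return np.ascontiguousarray(img.transpose((2, 0, 1)))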

ehxuflar 2#

Hi, what error are you getting? Please paste the full error message.

lrl1mhuk 3#

As soon as these two lines are used:

def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4)  # use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)

the following error appears:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [6], in <cell line: 5>()
      3 batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
      4 data_dir = d2l.download_extract('SNLI')
----> 5 train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
      6 test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
      7 train_iter = paddle.io.DataLoader(train_set, batch_size=batch_size, shuffle=True, return_list=True)

Input In [5], in SNLIBERTDataset.__init__(self, dataset, max_len, vocab)
     10 self.vocab = vocab
     11 self.max_len = max_len
     12 (self.all_token_ids, self.all_segments,
---> 13  self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
     14 print('read ' + str(len(self.all_token_ids)) + ' examples')

Input In [5], in SNLIBERTDataset._preprocess(self, all_premise_hypothesis_tokens)
     16     def _preprocess(self, all_premise_hypothesis_tokens):
     17         pool = multiprocessing.Pool(4)  # use 4 worker processes
---> 18         out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
     19 #         out = []
     20 #         for i in all_premise_hypothesis_tokens:
     21 #             tempOut = self._mp_worker(i)
     22 #             out.append(tempOut)
     24         all_token_ids = [
     25             token_ids for token_ids, segments, valid_len in out]

File ~/anaconda3/envs/d2l/lib/python3.8/multiprocessing/pool.py:364, in Pool.map(self, func, iterable, chunksize)
    359 def map(self, func, iterable, chunksize=None):
    360     '''
    361     Apply `func` to each element in `iterable`, collecting the results
    362     in a list that is returned.
    363     '''
--> 364     return self._map_async(func, iterable, mapstar, chunksize).get()

File ~/anaconda3/envs/d2l/lib/python3.8/multiprocessing/pool.py:771, in ApplyResult.get(self, timeout)
    769     return self._value
    770 else:
--> 771     raise self._value

File ~/anaconda3/envs/d2l/lib/python3.8/multiprocessing/pool.py:537, in Pool._handle_tasks(taskqueue, put, outqueue, pool, cache)
    535     break
    536 try:
--> 537     put(task)
    538 except Exception as e:
    539     job, idx = task[:2]

File ~/anaconda3/envs/d2l/lib/python3.8/multiprocessing/connection.py:206, in _ConnectionBase.send(self, obj)
    204 self._check_closed()
    205 self._check_writable()
--> 206 self._send_bytes(_ForkingPickler.dumps(obj))

File ~/anaconda3/envs/d2l/lib/python3.8/multiprocessing/reduction.py:51, in ForkingPickler.dumps(cls, obj, protocol)
     48 @classmethod
     49 def dumps(cls, obj, protocol=None):
     50     buf = io.BytesIO()
---> 51     cls(buf, protocol).dump(obj)
     52     return buf.getbuffer()

TypeError: cannot pickle 'Tensor' object
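
The final frames point at the cause: ForkingPickler.dumps has to serialize each task, and the task contains the bound method self._mp_worker, hence the whole dataset object with the paddle.Tensor in self.labels. An alternative sketch (bert_worker and the partial wiring are illustrative, not part of the book's code; d2l is the book's helper library from the question above) is to map over a module-level function that receives only picklable arguments:

import multiprocessing
from functools import partial

def bert_worker(premise_hypothesis_tokens, vocab, max_len):
    # Module-level stand-in for self._mp_worker: pickling it does not
    # drag a dataset instance (and its tensors) into the child processes.
    p_tokens, h_tokens = premise_hypothesis_tokens
    # Reserve slots for the '<CLS>', '<SEP>', and '<SEP>' tokens
    while len(p_tokens) + len(h_tokens) > max_len - 3:
        if len(p_tokens) > len(h_tokens):
            p_tokens.pop()
        else:
            h_tokens.pop()
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = vocab[tokens] + [vocab['<pad>']] * (max_len - len(tokens))
    segments = segments + [0] * (max_len - len(segments))
    return token_ids, segments, len(tokens)

# inside _preprocess, instead of mapping over the bound method:
#     pool = multiprocessing.Pool(4)
#     fn = partial(bert_worker, vocab=self.vocab, max_len=self.max_len)
#     out = pool.map(fn, all_premise_hypothesis_tokens)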
