在`pytorch_lightning`中处理多个数据集/数据加载器

new9mtju  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(254)

我正在使用pytorch_lightning处理多个数据集的训练。数据集有不同的长度---〉在相应的DataLoader s中有不同的批数。现在我试图通过使用字典来分开,因为我的最终目标是根据特定的数据集加权损失函数:

def train_dataloader(self): #returns a dict of dataloaders
    train_loaders = {}
    for key, value in self.train_dict.items():
        train_loaders[key] = DataLoader(value,
                                        batch_size = self.batch_size,
                                        collate_fn = collate)
    return train_loaders

然后,在training_step()中执行以下操作:

def training_step(self, batch, batch_idx):        
    total_batch_loss = 0

    for key, value in batch.items():
        anc, pos, neg  = value
        emb_anc = F.normalize(self.forward(anc.x,
                                           anc.edge_index,
                                           anc.weights,
                                           anc.batch,
                                           training=True
                                           ), 2, dim=1)
    
        emb_pos = F.normalize(self.forward(pos.x,
                                           pos.edge_index,
                                           pos.weights,
                                           pos.batch,
                                           training=True
                                           ), 2, dim=1)
    
        emb_neg = F.normalize(self.forward(neg.x,
                                           neg.edge_index,
                                           neg.weights,
                                           neg.batch,
                                           training=True
                                           ), 2, dim=1)
                                
        loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
        total_batch_loss += loss_dataset
        
    self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)        
    return total_batch_loss

问题是,当最小的数据集被耗尽时,Lightning会抛出一个StopIteration,所以我不会完成对其他数据集的剩余批次的训练。我已经考虑过将所有内容连接到一个单一的训练DataLoader中,就像docs中建议的那样,但我不知道如何根据这种方法来实现不同的权重损失功能。

ftf50wuq

ftf50wuq1#

您可以使用CombinedLoader类并指定max_size模式,以根据可用的最长数据加载器进行迭代。

相关问题