我正在使用pytorch_lightning
处理多个数据集的训练。数据集有不同的长度---〉在相应的DataLoader
s中有不同的批数。现在我试图通过使用字典来分开,因为我的最终目标是根据特定的数据集加权损失函数:
def train_dataloader(self): #returns a dict of dataloaders
train_loaders = {}
for key, value in self.train_dict.items():
train_loaders[key] = DataLoader(value,
batch_size = self.batch_size,
collate_fn = collate)
return train_loaders
然后,在training_step()
中执行以下操作:
def training_step(self, batch, batch_idx):
total_batch_loss = 0
for key, value in batch.items():
anc, pos, neg = value
emb_anc = F.normalize(self.forward(anc.x,
anc.edge_index,
anc.weights,
anc.batch,
training=True
), 2, dim=1)
emb_pos = F.normalize(self.forward(pos.x,
pos.edge_index,
pos.weights,
pos.batch,
training=True
), 2, dim=1)
emb_neg = F.normalize(self.forward(neg.x,
neg.edge_index,
neg.weights,
neg.batch,
training=True
), 2, dim=1)
loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
total_batch_loss += loss_dataset
self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)
return total_batch_loss
问题是,当最小的数据集被耗尽时,Lightning会抛出一个StopIteration
,所以我不会完成对其他数据集的剩余批次的训练。我已经考虑过将所有内容连接到一个单一的训练DataLoader
中,就像docs中建议的那样,但我不知道如何根据这种方法来实现不同的权重损失功能。
1条答案
按热度按时间ftf50wuq1#
您可以使用CombinedLoader类并指定
max_size
模式,以根据可用的最长数据加载器进行迭代。