def collate(self, batch, mlm_collator):
batch_size = len(batch)
keys = set([key for b in batch for key in b.keys()])
dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
img_keys = [k for k in list(dict_batch.keys()) if "image" in k]
for img_key in img_keys:
new_imgs = [tmp_img[0] for tmp_img in dict_batch[img_key]]
batch_new_imgs = torch.stack(new_imgs, dim=0)
dict_batch[img_key] = [batch_new_imgs]
txt_keys = [k for k in list(dict_batch.keys()) if "text" in k]
if len(txt_keys) != 0:
texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys]
encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys]
draw_text_len = len(encodings)
flatten_encodings = [e for encoding in encodings for e in encoding]
flatten_mlms = mlm_collator(flatten_encodings)
for i, txt_key in enumerate(txt_keys):
texts, encodings = (
[d[0] for d in dict_batch[txt_key]],
[d[1] for d in dict_batch[txt_key]],
)
mlm_ids, mlm_labels = (
flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)],
flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)],
)
input_ids = torch.zeros_like(mlm_ids)
attention_mask = torch.zeros_like(mlm_ids)
for _i, encoding in enumerate(encodings):
_input_ids, _attention_mask = (
torch.tensor(encoding["input_ids"]),
torch.tensor(encoding["attention_mask"]),
)
input_ids[_i, : len(_input_ids)] = _input_ids
attention_mask[_i, : len(_attention_mask)] = _attention_mask
dict_batch[txt_key] = texts
dict_batch[f"{txt_key}_ids"] = input_ids
dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100)
dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids
dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels
dict_batch[f"{txt_key}_masks"] = attention_mask
return dict_batch
3条答案
按热度按时间cotxawn71#
在我看来,数据集是通过以下函数创建的:
但是它还没有被实现,因此输入的数据集与用于推理的数据集不一致。
j1dl9f462#
在我看来,数据集是由以下函数创建的:
但是它还没有被实现,所以输入的数据集与用于推理的数据集不一致。
它们是由位于
/vlmo/datasets/BaseDataset.py
的名为/vlmo/datamodules/multitask_datamodule.py
的函数生成的下面是如何找到它的方法。
VLMo 的数据模块在
self.dms
中定义,它将一系列 "unitask_datamodule" 在/vlmo/datamodules/nlvr2_datamodule.py
中连接起来。在你的情况下,"unitask_datamodule" 来自BaseDataModule
的子类/vlmo/datamodules/BaseDataModule.py
中的/vlmo/datamodules/BaseDataModule.py
。显然,数据模块从数据集中加载并预处理数据。在
set_train_dataset
中,方法self.dataset_cls()
设置了应该通过NLVR2DataModule
加载哪个数据集。实际上,它是在子类vlmo.datasets
中定义的,名为collate_fn
,它从/vlmo/datamodules/BaseDataModule.py
导入。你还可以获取原始数据。数据加载器中的 id、标签、掩码和其他信息是通过参数
train_dataloader
在 dataloader 中创建的。回到/vlmo/datamodules/BaseDataModule/train_dataloader
,在collate_fn
中,你可以找到这样一个参数,它在数据加载器将数据提供给模型之前对数据进行批处理。self.train_dataset.collate
被设置为self.train_dataset
,而NLVR2Dataset
是NLVR2Dataset
的一个示例。同样,x20n20 也是 x20n22 的子类。别忘了我们的目标是找到位于 x20n23 n24 中的名为 "collate" 的方法。这个函数读取数据的 batch_size,将其放入列表中,并执行其他操作,如标记化、单词掩码和将列表中的每种数据转换为大型Tensor,然后将其输入到模型中。
0sgqnhkj3#
在我看来,数据集是由以下函数创建的:
但是它还没有被实现,所以输入的数据集与用于推理的数据集不一致。
它们是由位于
/vlmo/datasets/BaseDataset.py
的名为/vlmo/datamodules/multitask_datamodule.py
的函数生成的下面是如何找到它的方法。
VLMo 的数据模块在
self.dms
中定义,它将一系列 "unitask_datamodule" 在/vlmo/datamodules/nlvr2_datamodule.py
中连接起来。在你的情况下,"unitask_datamodule" 来自BaseDataModule
的子类/vlmo/datamodules/BaseDataModule.py
。显然,数据模块从数据集中加载并预处理数据。在
/vlmo/datamodules/BaseDataModule.py
中,方法set_train_dataset
通过self.dataset_cls()
设置了应该加载哪个数据集。实际上,它是在子类NLVR2DataModule
(从vlmo.datasets
导入)中定义的,名为NLVR2Dataset
。你还可以获取原始数据。数据加载器中的 id、标签、掩码和其他信息是通过参数
collate_fn
创建的。回到/vlmo/datamodules/BaseDataModule.py
,在train_dataloader
中,你可以找到这样一个参数,它在数据加载器将数据提供给模型之前对数据进行批处理。/vlmo/datamodules/BaseDataModule/train_dataloader
和collate_fn
被设置为self.train_dataset.collate
,而self.train_dataset
是NLVR2Dataset
的一个示例。NLVR2Dataset
也是BaseDataset
的子类。别忘了我们的目标是找到位于BaseDataset
末尾的名为 "collate" 的方法,它位于/vlmo/datasets/BaseDataset.py
中。这样的函数读取数据的 batch_size,将其放入列表中,并执行其他操作,如标记化、单词掩码和将列表中的每种数据转换为大型Tensor,然后将其输入到模型中。
明白了。谢谢你的详细回复。