unilm `infer()`函数的`text_ids`、`text_labels`和`text_masks`参数分别来自哪里?

lp0sw83n  于 2个月前  发布在  其他
关注(0)|答案(3)|浏览(51)

我想使用VLMo在NLVR任务上进行推理,但是我对这些内容感到困惑:
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
text_masks = batch[f"text_masks"]
我仔细阅读了代码,但是我没有找到它们来自哪里。

cotxawn7

cotxawn71#

在我看来,数据集是通过以下函数创建的:

def dataset_cls(self):
    raise NotImplementedError("return tuple of dataset class")

但是它还没有被实现,因此输入的数据集与用于推理的数据集不一致。

j1dl9f46

j1dl9f462#

在我看来,数据集是由以下函数创建的:

def dataset_cls(self):
    raise NotImplementedError("return tuple of dataset class")

但是它还没有被实现,所以输入的数据集与用于推理的数据集不一致。

它们是由位于 /vlmo/datasets/BaseDataset.py 的名为 /vlmo/datamodules/multitask_datamodule.py 的函数生成的

下面是如何找到它的方法。
VLMo 的数据模块在 self.dms 中定义,它将一系列 "unitask_datamodule" 在 /vlmo/datamodules/nlvr2_datamodule.py 中连接起来。在你的情况下,"unitask_datamodule" 来自 BaseDataModule 的子类 /vlmo/datamodules/BaseDataModule.py 中的 /vlmo/datamodules/BaseDataModule.py
显然,数据模块从数据集中加载并预处理数据。在 set_train_dataset 中,方法 self.dataset_cls() 设置了应该通过 NLVR2DataModule 加载哪个数据集。实际上,它是在子类 vlmo.datasets 中定义的,名为 collate_fn,它从 /vlmo/datamodules/BaseDataModule.py 导入。你还可以获取原始数据。

return {
        "image_0": image_tensor_0,
        "image_1": image_tensor_1,
        "text": text,
        "answers": answers,
        "table_name": self.table_names[index],
    }

数据加载器中的 id、标签、掩码和其他信息是通过参数 train_dataloader 在 dataloader 中创建的。回到 /vlmo/datamodules/BaseDataModule/train_dataloader,在 collate_fn 中,你可以找到这样一个参数,它在数据加载器将数据提供给模型之前对数据进行批处理。
self.train_dataset.collate 被设置为 self.train_dataset,而 NLVR2DatasetNLVR2Dataset 的一个示例。同样,x20n20 也是 x20n22 的子类。别忘了我们的目标是找到位于 x20n23 n24 中的名为 "collate" 的方法。
这个函数读取数据的 batch_size,将其放入列表中,并执行其他操作,如标记化、单词掩码和将列表中的每种数据转换为大型Tensor,然后将其输入到模型中。

0sgqnhkj

0sgqnhkj3#

在我看来,数据集是由以下函数创建的:

def dataset_cls(self):
    raise NotImplementedError("return tuple of dataset class")

但是它还没有被实现,所以输入的数据集与用于推理的数据集不一致。

它们是由位于 /vlmo/datasets/BaseDataset.py 的名为 /vlmo/datamodules/multitask_datamodule.py 的函数生成的

下面是如何找到它的方法。
VLMo 的数据模块在 self.dms 中定义,它将一系列 "unitask_datamodule" 在 /vlmo/datamodules/nlvr2_datamodule.py 中连接起来。在你的情况下,"unitask_datamodule" 来自 BaseDataModule 的子类 /vlmo/datamodules/BaseDataModule.py
显然,数据模块从数据集中加载并预处理数据。在 /vlmo/datamodules/BaseDataModule.py 中,方法 set_train_dataset 通过 self.dataset_cls() 设置了应该加载哪个数据集。实际上,它是在子类 NLVR2DataModule(从 vlmo.datasets 导入)中定义的,名为 NLVR2Dataset。你还可以获取原始数据。

return {
        "image_0": image_tensor_0,
        "image_1": image_tensor_1,
        "text": text,
        "answers": answers,
        "table_name": self.table_names[index],
    }

数据加载器中的 id、标签、掩码和其他信息是通过参数 collate_fn 创建的。回到 /vlmo/datamodules/BaseDataModule.py,在 train_dataloader 中,你可以找到这样一个参数,它在数据加载器将数据提供给模型之前对数据进行批处理。
/vlmo/datamodules/BaseDataModule/train_dataloadercollate_fn 被设置为 self.train_dataset.collate,而 self.train_datasetNLVR2Dataset 的一个示例。NLVR2Dataset 也是 BaseDataset 的子类。别忘了我们的目标是找到位于 BaseDataset 末尾的名为 "collate" 的方法,它位于 /vlmo/datasets/BaseDataset.py 中。
这样的函数读取数据的 batch_size,将其放入列表中,并执行其他操作,如标记化、单词掩码和将列表中的每种数据转换为大型Tensor,然后将其输入到模型中。

def collate(self, batch, mlm_collator):
    batch_size = len(batch)
    keys = set([key for b in batch for key in b.keys()])
    dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
    img_keys = [k for k in list(dict_batch.keys()) if "image" in k]
    for img_key in img_keys:
        new_imgs = [tmp_img[0] for tmp_img in dict_batch[img_key]]
        batch_new_imgs = torch.stack(new_imgs, dim=0)
        dict_batch[img_key] = [batch_new_imgs]

    txt_keys = [k for k in list(dict_batch.keys()) if "text" in k]
    if len(txt_keys) != 0:
        texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys]
        encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys]
        draw_text_len = len(encodings)
        flatten_encodings = [e for encoding in encodings for e in encoding]
        flatten_mlms = mlm_collator(flatten_encodings)

        for i, txt_key in enumerate(txt_keys):
            texts, encodings = (
                [d[0] for d in dict_batch[txt_key]],
                [d[1] for d in dict_batch[txt_key]],
            )

            mlm_ids, mlm_labels = (
                flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)],
                flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)],
            )

            input_ids = torch.zeros_like(mlm_ids)
            attention_mask = torch.zeros_like(mlm_ids)
            for _i, encoding in enumerate(encodings):
                _input_ids, _attention_mask = (
                    torch.tensor(encoding["input_ids"]),
                    torch.tensor(encoding["attention_mask"]),
                )
                input_ids[_i, : len(_input_ids)] = _input_ids
                attention_mask[_i, : len(_attention_mask)] = _attention_mask

            dict_batch[txt_key] = texts
            dict_batch[f"{txt_key}_ids"] = input_ids
            dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100)
            dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids
            dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels
            dict_batch[f"{txt_key}_masks"] = attention_mask

    return dict_batch

明白了。谢谢你的详细回复。

相关问题