keras 如何在tf.Dataset上调整TextVectorization图层

hlswsv35  于 2023-01-21  发布在  其他
关注(0)|答案(1)|浏览(130)

我像这样加载我的数据集:

self.train_ds = tf.data.experimental.make_csv_dataset(
            self.config["input_paths"]["data"]["train"],
            batch_size=self.params["batch_size"],
            shuffle=False,
            label_name="tags",
            num_epochs=1,
        )

我的TextVectorization图层如下所示:

vectorizer = tf.keras.layers.TextVectorization(
            standardize=code_standaridization,
            split="whitespace",
            output_mode="int",
            output_sequence_length=params["input_dim"],
            max_tokens=100_000,
        )

我想这就足够了:

vectorizer.adapt(data_provider.train_ds)

但它不是,我有这个错误:

TypeError: Expected string, but got Tensor("IteratorGetNext:0", shape=(None, None), dtype=string) of type 'Tensor'.

我可以在TensorFlow数据集上调整矢量器吗?

zujrkrfu

zujrkrfu1#

最有可能的问题是,当您尝试适应时,您在没有.unbatch()train_ds中使用batch_size
您必须做到:

vectorizer.adapt(train_ds.unbatch().map(lambda x, y: x).batch(BATCH_SIZE))

.unbatch()可解决您当前看到的错误,而.map()是必需的,因为TextVectorization图层对字符串批处理,因此您需要从数据集中获取这些字符串

相关问题