有没有办法连接3个或更多的tf.data.dataset

pieyvz9o  于 2021-09-29  发布在  Java
关注(0)|答案(2)|浏览(329)

我想在tensorflow中连接3个或更多数据集。要连接2个数据集,

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset1.concatenate(dataset2)

但是,通过这种方式,3个或更多数据集无法连接。所以我想做喜欢的事

dataset1 = tf.data.Dataset.range(1, 4)
dataset2 = tf.data.Dataset.range(4, 8)
dataset3 = tf.data.Dataset.range(8, 12)
concatenate(dataset1,dataset2,dataset3)

有什么办法吗?

t98cgbkg

t98cgbkg1#

在这个特定的示例中,您可以

concat_dataset = dataset1.concatenate(dataset2).concatenate(dataset3)

请注意,您必须指定 concatenate 到一个新的变量!它没有在适当的地方运行。
当然,如果您有许多数据集,那么这种方法不能很好地扩展,但应该可以:

datasets = [dataset1, dataset2, dataset3]  # can be more than 3 of course

concat_dataset = datasets[0]
for dset in datasets[1:]:
    concat_dataset = concat_dataset.concatenate(dset)
pexxcrt2

pexxcrt22#

import tensorflow as tf

dataset1 = tf.data.Dataset.range(1, 4)

dataset2 = tf.data.Dataset.range(4, 8)

dataset3 = tf.data.Dataset.range(8, 12)

def func(*datasets):

    out = {}

    for dataset in datasets:

        for key in dataset:

            if key in out:

                _value = out[key]

                out[key] = tf.concat([_value, dataset[key]], axis=-1)

            else:

                out[key] = dataset[key]

    return out

tf.data.Dataset.zip((dataset1, dataset2, dataset3)).map(func)

相关问题