tensorflow 将tf tf.data.Dataset元组拆分为多个数据集

blmhpbnm  于 2023-11-21  发布在  其他
关注(0)|答案(1)|浏览(127)

我有一个tf.data.Dataset,其形状如下:

<ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float64, tf.float64)>

字符串
我可以拆分这个数据集以获得两个数据集,如下所示:

<Dataset shapes: (None, None, 12), types: tf.float64>
<Dataset shapes: (None, 5), types: tf.float64>

rjee0c15

rjee0c151#

您可以使用map函数来拆分它们。
演示:

import tensorflow as tf

# Create a random tensorflow dataset.
dataset1 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([40, 10, 12]), tf.random.uniform([40, 5]))).batch(16)
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([40, 12, 12]), tf.random.uniform([40, 5]))).batch(16)

dataset = dataset1.concatenate(dataset2)
dataset
>> <ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float32, tf.float32)>

字符串
为了分裂:

data = dataset.map(lambda x, y: x)
labels = dataset.map(lambda x, y: y)

相关问题