keras 如何找出不规则数据集中维度的最大长度

fdbelqdn  于 2023-06-30  发布在  其他
关注(0)|答案(2)|浏览(153)

如果我有一个从一个不规则的Tensor构建的数据集,我如何获得所有元素的最大长度(在这个例子中为4)?

ds = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))
ffx8fchx

ffx8fchx1#

我们可以使用函数reduce

ds.reduce(0, lambda state, value: tf.math.maximum(state, len(value))).numpy()
uubf1zoe

uubf1zoe2#

您可以通过以下步骤获取TensorFlow数据集中不规则Tensor的所有元素的最大长度:
1.在数据集上Map一个函数,用于计算每个不规则Tensor元素的长度。
1.减少长度以找到最大长度。
以下是如何实现这些步骤:

import tensorflow as tf

# your dataset
ds = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))

# map a function over your dataset that calculates the size of each element
ds_lengths = ds.map(lambda x: tf.size(x))

# get the maximum length
max_length = tf.reduce_max(list(ds_lengths.as_numpy_iterator()))

print(max_length)  # prints: 4

此脚本使用tf.data.Dataset.map计算数据集中每个元素的大小。tf.size函数计算每个Tensor的大小。然后将大小列表传递给计算最大值的tf.reduce_max
请注意,TensorFlow的即时执行需要启用**as_numpy_iterator()才能工作,这是TensorFlow 2.0及更高版本中的默认行为。如果您使用的是旧版本的TensorFlow,则可能需要使用tf.enable_eager_execution()**手动启用eager执行。

相关问题