import tensorflow as tf
# your dataset
ds = tf.data.Dataset.from_tensor_slices(
tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))
# map a function over your dataset that calculates the size of each element
ds_lengths = ds.map(lambda x: tf.size(x))
# get the maximum length
max_length = tf.reduce_max(list(ds_lengths.as_numpy_iterator()))
print(max_length) # prints: 4
2条答案
按热度按时间ffx8fchx1#
我们可以使用函数
reduce
:uubf1zoe2#
您可以通过以下步骤获取TensorFlow数据集中不规则Tensor的所有元素的最大长度:
1.在数据集上Map一个函数,用于计算每个不规则Tensor元素的长度。
1.减少长度以找到最大长度。
以下是如何实现这些步骤:
此脚本使用tf.data.Dataset.map计算数据集中每个元素的大小。tf.size函数计算每个Tensor的大小。然后将大小列表传递给计算最大值的tf.reduce_max。
请注意,TensorFlow的即时执行需要启用**as_numpy_iterator()才能工作,这是TensorFlow 2.0及更高版本中的默认行为。如果您使用的是旧版本的TensorFlow,则可能需要使用tf.enable_eager_execution()**手动启用eager执行。