python 如何打印www.example.com中的数据集示例tf.data?

li9yvcax  于 2022-12-17  发布在  Python
关注(0)|答案(3)|浏览(115)

我在tf.data中有一个数据集,如何轻松地打印(或抓取)数据集中的一个元素?
类似于:

print(dataset[0])
nnsrf1az

nnsrf1az1#

在TF 1.x中,你可以使用以下的迭代器。这里提供了不同的迭代器(有些在未来的版本中可能会被弃用)。

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
diter = d.make_one_shot_iterator()
e1 = diter.get_next()

with tf.Session() as sess:
  print(sess.run(e1))

或在TF 2.x中

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
print(next(iter(d)).numpy())

## You can also use loops as follows to traverse the full set one item at a time
for elem in d:
    print(elem)
siv3szwd

siv3szwd2#

如果您的TensorFlow数据集命名为dataset,您可以按如下方式访问第一个元素:

list(dataset.as_numpy_iterator())[0]

参见文件。

pwuypxnk

pwuypxnk3#

也可以将J.V.提出的as_numpy_iterator方法与take方法结合使用,take方法允许您指定要从tf数据集中提取多少个元素,例如:

import tensorflow as tf

>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.take(1) # take one element (the first)
>>> list(dataset.as_numpy_iterator())
0

更改take方法中的数量将允许您提取不同数量的元素(按照它们插入数据集中的顺序)。

相关问题