我目前正在实现this tensorflow项目的pytorch版本。他们发表了一篇论文,可以找到here。
- 数据集格式和在哪里可以找到它们可用的数据集是. tfrecord文件,可以通过他们提供的脚本下载,该脚本可以从storage. google-apis站点获得它们。我能够下载文件(如果你不想运行他们的脚本,你可以在this驱动器链接中找到一个数据集文件的例子。有一些数据集here的. npz版本,我想让所有这些都以这种格式用于工作流目的)。数据集的. npz版本可以作为字典加载,关键字为“simulation_trajectory_0”,“simulation_trajectory_1”...“simulation_trajectory_999”。这些键中的每一个包含形状[320,x,2]的阵列,其中x从大约100到2000变化。这是有意义的,因为它们是x粒子在2维中320个时间步的轨迹的模拟。* * 我的目标我的目标是读取这些tfrecords,并理想地将它们转换为. npz文件,然后将其加载到numpy中并转换为torchTensor。我甚至满足于用tensorflow读取它们并直接将它们转换为torchTensor,但由于我正在使用pytorch进行团队项目,我想自己处理数据预处理,而不是强迫我的同事也使用tensorflow。
- 问题,以及我尝试了什么**我对tensorflow非常陌生,所以我无法正确读取文件。据我所知,tfrecord文件是大型数据集的序列化,通过流式方法读取,通常不会一次处理所有数据集。我读了很多关于stackoverflow的问题,也读了TF文档,看起来我需要学习我的. tf文件的特性,创建一个解析函数,并将文件+解析函数交给
tf.data.TFRecordDataset
。我能够使用以下代码从. tfrecord文件中提取特征:
- 问题,以及我尝试了什么**我对tensorflow非常陌生,所以我无法正确读取文件。据我所知,tfrecord文件是大型数据集的序列化,通过流式方法读取,通常不会一次处理所有数据集。我读了很多关于stackoverflow的问题,也读了TF文档,看起来我需要学习我的. tf文件的特性,创建一个解析函数,并将文件+解析函数交给
import tensorflow as tf
def list_record_features(tfrecords_path):
# Dict of extracted feature information
features = {}
# Iterate records
for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
# Get record bytes
example_bytes = rec.numpy()
# Parse example protobuf message
example = tf.train.Example()
example.ParseFromString(example_bytes)
# Iterate example features
for key, value in example.features.feature.items():
# Kind of data in the feature
kind = value.WhichOneof('kind')
# Size of data in the feature
size = len(getattr(value, kind).value)
# Check if feature was seen before
if key in features:
# Check if values match, use None otherwise
kind2, size2 = features[key]
if kind != kind2:
kind = None
if size != size2:
size = None
# Save feature data
features[key] = (kind, size)
return features
在this question中找到
features = list_record_features("valid.tfrecord")
features
给我这个输出:'key':('int64_list',1),'particle_type':('bytes_list',1)}
现在如果我编写一个decode函数并将其提供给TFRecordDataset
def decode_fn(record_bytes):
return tf.io.parse_single_example(
# Data
record_bytes,
# Schema
{"x": tf.io.FixedLenFeature([], dtype=tf.train.Int64List),
"y": tf.io.FixedLenFeature([], dtype=tf.train.BytesList)} #these two types give error
)
for batch in tf.data.TFRecordDataset("valid.tfrecord").map(decode_fn):
print("x = {x:.4f}, y = {y:.4f}".format(**batch))
我不知道在dtype=
后面放什么,也不知道如何将数据转换为可以保存到. npz的正确格式。
唯一一个和需要的有点类似的东西是question,但是当我运行它的时候,它给了我这样一个表key particle_type position step_context 0 0...[b'\xe6\x81\x88\x88\x88\x88]什么?...我不知道... 11...什么?你知道吗?我不知道... 2 2...我不知道... 3 3...我不知道... 4 4...(7?我的天我不知道......
我很抱歉这个问题很长,但这是我第一次问,我想确保有所有需要的信息。任何帮助或想法都非常感谢。
1条答案
按热度按时间k4emjkb11#
tf.io.parse_single_example
调用中使用的tf.io.FixedLenFeature
的dtype
参数应该是tf.dtypes
的类型,而不是tf.train
的python类型。尝试类似这样的操作(注意,您希望字典键与从list_record_features
打印的字符串名称相匹配):