tensorflow 阅读并解码.tfrecord文件,并将其保存到.npz文件或在numpy/pytorch中读取

goqiplq2  于 2023-06-24  发布在  其他
关注(0)|答案(1)|浏览(116)

我目前正在实现this tensorflow项目的pytorch版本。他们发表了一篇论文,可以找到here

    • 数据集格式和在哪里可以找到它们可用的数据集是. tfrecord文件,可以通过他们提供的脚本下载,该脚本可以从storage. google-apis站点获得它们。我能够下载文件(如果你不想运行他们的脚本,你可以在this驱动器链接中找到一个数据集文件的例子。有一些数据集here的. npz版本,我想让所有这些都以这种格式用于工作流目的)。数据集的. npz版本可以作为字典加载,关键字为“simulation_trajectory_0”,“simulation_trajectory_1”...“simulation_trajectory_999”。这些键中的每一个包含形状[320,x,2]的阵列,其中x从大约100到2000变化。这是有意义的,因为它们是x粒子在2维中320个时间步的轨迹的模拟。* * 我的目标我的目标是读取这些tfrecords,并理想地将它们转换为. npz文件,然后将其加载到numpy中并转换为torchTensor。我甚至满足于用tensorflow读取它们并直接将它们转换为torchTensor,但由于我正在使用pytorch进行团队项目,我想自己处理数据预处理,而不是强迫我的同事也使用tensorflow。
    • 问题,以及我尝试了什么**我对tensorflow非常陌生,所以我无法正确读取文件。据我所知,tfrecord文件是大型数据集的序列化,通过流式方法读取,通常不会一次处理所有数据集。我读了很多关于stackoverflow的问题,也读了TF文档,看起来我需要学习我的. tf文件的特性,创建一个解析函数,并将文件+解析函数交给tf.data.TFRecordDataset。我能够使用以下代码从. tfrecord文件中提取特征:
import tensorflow as tf

def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key, value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value, kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match, use None otherwise
                kind2, size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind, size)
    return features

this question中找到

features = list_record_features("valid.tfrecord")
features

给我这个输出:'key':('int64_list',1),'particle_type':('bytes_list',1)}
现在如果我编写一个decode函数并将其提供给TFRecordDataset

def decode_fn(record_bytes):
  return tf.io.parse_single_example(
      # Data
      record_bytes,

      # Schema
      {"x": tf.io.FixedLenFeature([], dtype=tf.train.Int64List),
       "y": tf.io.FixedLenFeature([], dtype=tf.train.BytesList)} #these two types give error
  )

for batch in tf.data.TFRecordDataset("valid.tfrecord").map(decode_fn):
  print("x = {x:.4f},  y = {y:.4f}".format(**batch))

我不知道在dtype=后面放什么,也不知道如何将数据转换为可以保存到. npz的正确格式。
唯一一个和需要的有点类似的东西是question,但是当我运行它的时候,它给了我这样一个表key particle_type position step_context 0 0...[b'\xe6\x81\x88\x88\x88\x88]什么?...我不知道... 11...什么?你知道吗?我不知道... 2 2...我不知道... 3 3...我不知道... 4 4...(7?我的天我不知道......
我很抱歉这个问题很长,但这是我第一次问,我想确保有所有需要的信息。任何帮助或想法都非常感谢。

k4emjkb1

k4emjkb11#

tf.io.parse_single_example调用中使用的tf.io.FixedLenFeaturedtype参数应该是tf.dtypes的类型,而不是tf.train的python类型。尝试类似这样的操作(注意,您希望字典键与从list_record_features打印的字符串名称相匹配):

tf.io.parse_single_example(
      record_bytes,
      {"key": tf.io.FixedLenFeature([], dtype=tf.int64),
       "particle_type": tf.io.FixedLenFeature([], dtype=tf.string)})

相关问题