将numpy/scipy函数Map到tensorflow.data.Dataset

uz75evzq  于 2024-01-08  发布在  其他
关注(0)|答案(1)|浏览(171)

我试图提取使用tf.keras.preprocessing.timeseries_dataset_from_array窗口化的1D信号的某些特征,从中我获得了一个tf.data.Dataset对象ds(参见下面的代码)。理想情况下,我希望使用数据集的内置map方法将特征函数(使用numpy和scipy)Map到数据集上。
然而,当我天真地尝试这样做时:

  1. import numpy as np
  2. import scipy as sc
  3. import tensorflow as tf
  4. def feat1_func(x, sf, axis):
  5. x = np.asarray(x)
  6. feat1_value = np.apply_along_axis(
  7. lambda vals: sc.integrate.trapezoid(abs(vals), dx=1 / sf), axis=axis, arr=x
  8. )
  9. return feat1_value
  10. features = ['feat1']
  11. feature_map = {'feat1': feat1_func}
  12. x = np.random.rand(100, 5)
  13. y = np.random.randint(low=0, high=2, size=100)
  14. sequence_length = 10
  15. sequence_stride = 3
  16. ds = tf.keras.preprocessing.timeseries_dataset_from_array(
  17. data=x,
  18. targets=y,
  19. sequence_length=sequence_length,
  20. sequence_stride=sequence_stride,
  21. batch_size=None,
  22. shuffle=False,
  23. )
  24. feat_lambda = lambda x, y: (np.array([feature_map[ft](x, sf=1000, axis=0) for ft in features]), y)
  25. ds = ds.map(feat_lambda)

字符串
我收到以下错误消息:

  1. NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.


这个问题最简单的解决方法是什么?当Map发生时,是否可以将符号Tensor转换为渴望Tensor?

mbskvtky

mbskvtky1#

解决的办法是改变路线

  1. feat_lambda = lambda x, y: (np.array([feature_map[ft](x, sf=1000, axis=0) for ft in features]), y)

字符串

  1. feat_lambda = lambda x, y: ([tf.numpy_function(feature_map[ft], [x, 1000, 0], tf.float64) for ft in features], y)


tf.numpy_function接受一个处理numpy数组的函数,并以渴望模式处理Dataset的Tensor(也就是具有真实的值的Tensor,而不是符号Tensor)。
ds = ds.map(feat_lambda)tf.py_function上也没有错误,但是当我试图循环数据集时,我得到了错误:

  1. # This didn't work
  2. feat_lambda = lambda x, y: ([tf.py_function(feature_map[ft], [x, 1000, 0], tf.float64) for ft in features], y)
  3. ds2 = ds.map(feat_lambda)
  4. for elem in ds2:
  5. print(elem) # here I got an error with tf.py_function, not with tf.numpy_function

展开查看全部

相关问题