tensorflow [适应性]如果(显式/隐式)指定了SparseTensor的索引,则应根据当前索引进行适应,

flvlnr44  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(134)

系统信息

  • TensorFlow版本(您正在使用的):2.5.0
  • 您是否愿意为其做出贡献(是/否):

我正在使用TensorFlow版本2.5.0,并尝试将稀疏Tensor适应深度学习模型。然而,有趣的是,尽管索引已经由numpy指定,但即使我们尝试了,索引始终被转换为uint64,即使是从numpy或tensorflow(tf.convert_to_tensor || tf.constant)。由于我们的矩阵形状为(10000, 7465),建议使用uint16作为数据类型索引,而不是uint64。
如果索引作为Python列表或元组传递,则更喜欢将其转换为uint64,但是如果输入为numpy数组,则希望能够修改相应的数据类型,使其范围与dense array相当,而不是安全地将其转换为uint64作为输入。
此实现可以帮助减少计算资源的最大占用,因为即使矩阵不够稀疏,存储索引也可能相当昂贵。在我的情况下,幸运的是,数据是uint8,密度为2.24%,所以没有这个问题。
示例:
当前版本:内存成本:数据(uint8),索引(uint64)->内存成本:与密集数组相比38.08%
此实现1:数据(uint8),索引(uint16)->内存成本:与密集数组相比11.02%
此实现2:数据(uint8),索引(uint32)->内存成本:与密集数组相比20.16%
此外,还应该使keras.Input(sparse=True)与Scipy.sparse兼容,而不是发出警告

相关问题