如何用uint 8数据训练tf/keras模型?

fdbelqdn  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(209)

我有uint 8格式的数据。尝试在此数据上训练tf/keras模型时给予以下错误:

  1. Failed to convert a NumPy array to a Tensor (Unsupported object type int)

其他人在这里提出的问题建议把数字转换成浮点数;例如使用:

  1. data.astype('float32')

然而,它会显著增加内存使用。有没有办法在不增加内存使用的情况下,将uint 8数据输入到一个tf/keras模型中进行训练?

11dmarpk

11dmarpk1#

这意味着数据集中的某些值必须为负数或超出unsignedInt8范围。
不过,您可以使用混合精确度定型,以最少的内存来定型模型。

  1. from tensorflow.keras import mixed_precision
  2. mixed_precision.set_global_policy('mixed_float16')

这样做的目的是让你的模型尽可能地在float16精度上训练,以优化内存使用。但是,只有当你的GPU具有7.0或更高的计算能力时,它才起作用。更多信息请访问https://www.tensorflow.org/guide/mixed_precision
或者,您可以手动将其硬编码为float16,以减少内存使用。

相关问题