将python2dndarray加载到android中,以便在tflite上进行推理

xe55xuns  于 2021-06-30  发布在  Java
关注(0)|答案(1)|浏览(498)

我想测试一个tensorflow lite模型的推断,我已经加载到一个android项目中。
我有一些在python环境中生成的输入,我想保存到一个文件中,加载到我的android应用程序中,并用于tflite推断。我的输入有点大,一个例子是: <class 'numpy.ndarray'>, dtype: float32, shape: (1, 596, 80) 我需要一些方法来序列化这个数组并将其加载到android中。
关于tflite推断的更多信息可以在这里找到。本质上,这应该是原始浮点的多维数组,或者bytebuffer。
最简单的方法是:
在python端序列化这个数组
从文件中反序列化java端的这个blob
谢谢!

axzmvihb

axzmvihb1#

我最终发现了这一点,有一个名为javanpy的方便的java库,它允许您用java打开.npy文件,因此可以用android打开。
在Python方面,我保存了一个扁平的 .npy 以正常方式:

data_flat = data.flatten()
print(data_flat.shape)
np.save(file="data.npy", arr=data_flat)

在android中,我把它放到 assets 文件夹。
然后我把它加载到javanpy中:

InputStream stream = context.getAssets().open("data.npy")
Npy npy = new Npy(stream);
float[] npyData = npy.floatElements();

最后把它变成了张紧缓冲器:

int[] inputShape = new int[]{1, 596, 80};   //the data shape before I flattened it
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(inputShape, DataType.FLOAT32);
tensorBuffer.loadArray(npyData);

然后我用这个Tensor缓冲器来推断我的tflite模型。

相关问题