tensorflow 如何保存状态依赖于输入的有状态TFLite模型?

krcsximq  于 2023-01-26  发布在  其他
关注(0)|答案(1)|浏览(113)

我尝试在tensorflow-lite中做一些相当简单的事情,但我不确定这是否可行。
我想定义一个状态图,其中状态变量的形状是在模型加载时定义的,而不是在保存时。
举一个简单的例子--假设我只想计算一个时间差--即一个返回两个连续调用的输入之间的差的图。

func = load_tflite_model_func(tflite_model_file_path)
runtime_shape = 60, 80
rng = np.random.RandomState(1234)
ims = [rng.randn(*runtime_shape).astype(np.float32) for _ in range(3)]
assert np.allclose(func(ims[0]), ims[0])
assert np.allclose(func(ims[1]), ims[1]-ims[0])
assert np.allclose(func(ims[2]), ims[2]-ims[1])

现在,要创建和保存模型,我执行以下操作:

@dataclass
class TimeDelta(tf.Module):
    _last_val: Optional[tf.Tensor] = None
    def compute_delta(self, arr: tf.Tensor):
        if self._last_val is None:
            self._last_val = tf.Variable(tf.zeros(tf.shape(arr)))
        delta = arr-self._last_val
        self._last_val.assign(arr)
        return delta

compile_time_shape = 30, 40
# compile_time_shape = None, None  # Causes UnliftableError
tflite_model_file_path = tempfile.mktemp()
delta = TimeDelta()
save_signatures_to_tflite_model(
    {'delta': tf.function(delta.compute_delta, input_signature=[tf.TensorSpec(shape=compile_time_shape)])},
    path=tflite_model_file_path,
    parent_object=delta
)

当然,问题是如果我的编译时形状与运行时形状不同,它就会崩溃,试图用compile_time_shape = None, None动态地塑造图形也会失败,当我试图保存图形时会导致UnliftableError(因为它需要变量的具体维度)。
A full Colab-Notebook demonstrating the problem is here.
所以,总结一下,问题是:

    • 如何在tflite中保存有状态图,其中图的状态形状取决于输入的形状?**
62lalag4

62lalag41#

好吧,我找到了一个解决方案,这是不理想的,但做的工作:使变量的大小达到运行时所能想象到的最大值,然后取一个切片并赋值。Here is a modified notebook就是这样做的。
这种方法的一个缺点是,您最终会得到非常大的tflite文件(在我的例子中,24MB的零)。

相关问题