因此,我正在做一个知觉实验,为此我设计了一个中等复杂度的刺激,它需要大量的计算来创建。它工作正常,但生成速度很慢。
快速头脑风暴:嘿,试试在Tensorflow中实现刺激计算吧!一点点的工作得到了回报,刺激生成时间缩短了一半(实际上快到可以实时显示),而且Tensorflow模型证明了更紧凑的设计。太棒了!
但后来我开始梦想。我希望能够通过一个应用程序进行试验,但不想用Swift或Java或任何移动的/网络的流行语言重做所有代码。但如果所有繁重的代码都嵌入在Tensorflow精简模型中,那么在合理的时间框架内,一个适用于iOS/Android/(可能是javascript)的小 Package 器将是可行的。
这就是我的问题所在。在离线刺激生成模型中,我配置参数,让它生成视频文件,然后让我的受试者观看。如果我的理论应用只是采用tensorflow 模型而不是视频文件,那么我实际上只是缩短了下载时间。我真正希望能够做的是在应用中调整刺激参数,而不是猜测。生成并再次上传。
因此(从这里开始,我只是在我的TF技能范围内即兴发挥),我把我的配置参数变成了tf.Variables,把它们插入到模型中,瞧,我现在可以在Python CLI中动态调整我的刺激了。太好了!现在我只需要保存模型...
哦,不好意思
保存tf.Variables是如何工作的?下面是我的代码的一个简单子集来演示这个问题。从一个计算时间输入的正弦的层开始,它具有可配置的相位、频率和幅度:
class Sine(tf.keras.layers.Layer):
def __init__(self, *args, **kwargs):
super(Sine, self).__init__(*args, **kwargs)
self._twopi = tf.constant(np.pi * 2.0)
def call(self, parameters):
time = parameters[0]
scale = parameters[1]
frequency = parameters[2]
base = parameters[3]
phase = parameters[4]
# def call(self, time, scale, frequency, base, phase):
time = tf.cast(time, tf.float32)
return scale*tf.sin(self._twopi * frequency * time + phase) + base
- 请注意,这里我尝试将五个参数压缩到一个列表中,看看会发生什么 *。
这里有一个愚蠢的模型使用它:
class StupidModel:
def __init__(self, frequency, amplitude, base, phase):
self._frequency = tf.Variable(frequency, name="frequency", dtype=tf.float32)
self._amplitude = tf.Variable(amplitude, name="amplitude", dtype=tf.float32)
self._base = tf.Variable(base, name="base", dtype=tf.float32)
self._phase = tf.Variable(phase, name="phase", dtype=tf.float32)
self._model = self._build_model()
def _build_model(self):
input = tf.keras.layers.Input(1)
out = Sine()([input, self._frequency, self._amplitude, self._base, self._phase])
model = tf.keras.Model(inputs=[input], outputs = out)
model._myfrequency = self._frequency
model._myamplitude = self._amplitude
model._mybase = self._base
model._myphase = self._phase
return model
def __call__(self, time):
return self._model.predict(time)
当我试图保存它时,会发生以下情况:
>>> sm._model.save("foo.tf")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: foo.tf/assets
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/j/.miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/j/.miniforge3/lib/python3.9/json/encoder.py", line 199, in encode
chunks = self.iterencode(o, _one_shot=True)
File "/Users/j/.miniforge3/lib/python3.9/json/encoder.py", line 257, in iterencode
return _iterencode(o, 0)
TypeError: Unable to serialize <tf.Variable 'frequency:0' shape=() dtype=float32, numpy=0.5> to JSON. Unrecognized type <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>.
我相信如果我使用TF1.X,我会看到占位符和提要字典。但是我现在能做什么呢?我知道我正在做的可能超出了正常的tensorflow 使用范围,例如没有训练,飞行中调整等,但是它一直工作得很好,直到保存...
1条答案
按热度按时间uqxowvwt1#
进一步的探索已经取得了成果。它看起来很复杂,我不敢想象我需要做什么来转换到Tensorflow Lite,但到目前为止,我可以:
1.添加任意配置参数
1.保存模型
1.加载模型
1.检索配置参数名称
1.设置配置参数
我发现的关键是使StupidModel继承自tf. Module。我添加了一个中间类来封装参数逻辑:
然后我更新了StupidModel以继承
Parameterized
,并确保添加了适当的@tf.function装饰器:现在,当我重新加载时:
而且我可以 checkout 我的参数(按名称),设置它们,并查看更改:
还有最后一点要注意:提供的
tf.function
参数允许我使用任意输入。(如果没有它,我只能调用带有参数签名的__call__
,该参数签名 * 完全 * 匹配 * 在保存之前 * 进行的调用,例如,如果我调用sm([1,2,3])然后保存,如果我调用sm,我会得到一个错误返回([10,20,30])。在@tf.function中指定签名可防止出现这种情况。):这就是我所学到的。它看起来很复杂,但我绝对愿意接受更简单的方法。