我面临以下关于tf.reshape
的问题。
编码:
class ReshapeLayer(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(ReshapeLayer, self).__init__(**kwargs)
def call(self, x):
assert len(x.shape) == 3
b, n, d = x.shape
x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, (1, n, b*d))
return x
x = tf.random.normal((20, 10, 5))
ReshapeLayer()(x).shape
Tensor形状([1,10,100])
然而:
第一个
我知道模型创建期间的批处理大小是None
。我尝试了tf.shape(x)
来将尺寸视为变量,但它不起作用。我该如何解决这个异常?
最后,我希望在按照建议更改层定义后能够执行类似的操作。
class ReshapeLayer(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(ReshapeLayer, self).__init__(**kwargs)
def call(self, x):
assert len(x.shape) == 3
n = tf.shape(x)[1] # x.shape[1] should work as well
x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, (1, n, -1))
return x
x = tf.random.normal((20, 10, 5))
ReshapeLayer()(x).shape
现在,我想使用整形层的结果作为另一个层的输入,可能是自定义的层。
第一个
如何处理此错误?
干杯
1条答案
按热度按时间esbemjvw1#
这是可行的
你只需要访问
tf.shape
的结果有点不同,因为它返回一个Tensor。请注意,
model = tf.keras.Model(input=inputs, output=output)
崩溃,您需要使用model = tf.keras.Model(inputs=inputs, outputs=output)
(kwargs是复数)。最后,您还可以这样做
您可以在整形调用中使用
-1
,基本上告诉Tensorflow“自己想办法”。