Keras粗糙Tensor迭代的正确方法

tjjdgumg  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(163)

我有一个输入的Tensorflow粗糙Tensor,结构类似于[batch num_images width height channels],我需要迭代维度num_images,以提取一些与下游应用相关的特征。示例代码如下:

from tensorflow.keras.applications.efficientnet import EfficientNetB7
from tensorflow.keras.layers import Input
import tensorflow as tf

eff_net = EfficientNetB7(weights='imagenet', include_top=False)
input_claim = Input(shape=(None, 600, 600, 3), name='input_1', ragged=True)
eff_out = tf.map_fn(fn=eff_net, 
                    elems=input_claim, fn_output_signature=tf.float32)

第一个Input维度被设置为None,因为它可以在数据点之间不同,并且由于这个原因,输入接收tf.RaggedTensor的示例。
这段代码以TypeError的方式中断TypeError: Could not build a TypeSpec for KerasTensor(type_spec=RaggedTensorSpec(TensorShape([None, None, 600, 600, 3]), tf.float32, 1, tf.int64), name='input_1', description="created by layer 'input_1'") of unsupported type <class 'keras.engine.keras_tensor.RaggedKerasTensor'>.我怀疑有更好的方法来执行这种类型的预处理
更新:需要num_images,因为(尽管这里没有描述)我正在此维度上执行以下reduce操作

63lcw9qa

63lcw9qa1#

您可以使用tf.ragged.map_flat_values来实现相同的
创建如下所示的模型:

def eff_net(x): #dummy eff_net for testing that returns [batch, dim]
    return tf.random.normal(shape=tf.shape(x)[:2])

input_claim = keras.Input(shape=(None, 600, 600, 3), name='input_1', ragged=True)

class RaggedMapLayer(layers.Layer):
    def call(self, x):
        return tf.ragged.map_flat_values(eff_net, x)

outputs = RaggedMapLayer()(input_claim)

model = keras.Model(inputs=input_claim, outputs=outputs)

测试,

inputs = tf.RaggedTensor.from_row_splits( tf.random.normal(shape=(10, 600, 600, 3)), row_splits=[0, 2, 5,10])
#shape [3, None, 600, 600, 3]

model(inputs).shape
#[3, None, 600]

相关问题