如何绘制Keras/Tensorflow子类化API模型?

uttx8gqw  于 2023-05-23  发布在  其他
关注(0)|答案(4)|浏览(192)

我使用Keras Subclassing API创建了一个正确运行的模型。model.summary()也能正常工作。当尝试使用tf.keras.utils.plot_model()来可视化模型的架构时,它只会输出以下图像:

这几乎感觉像是Keras开发团队的一个笑话。这是完整的架构:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x

skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')
kd3sttzy

kd3sttzy1#

我发现了一些变通方法,可以使用模型子类化API进行绘图。由于明显的原因,子类API不支持顺序或函数API,如model.summary()和使用plot_model的良好可视化。在这里,我将展示两者。

class my_model(keras.Model):
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.Base  = keras.keras.applications.VGG16(
            input_shape=(dim), 
            include_top = False, 
            weights = 'imagenet'
        )
        self.GAP   = L.GlobalAveragePooling2D()
        self.BAT   = L.BatchNormalization()
        self.DROP  = L.Dropout(rate=0.1)
        self.DENS  = L.Dense(256, activation='relu', name = 'dense_A')
        self.OUT   = L.Dense(1, activation='sigmoid')
    
    def call(self, inputs):
        x  = self.Base(inputs)
        g  = self.GAP(x)
        b  = self.BAT(g)
        d  = self.DROP(b)
        d  = self.DENS(d)
        return self.OUT(d)
    
    # AFAIK: The most convenient method to print model.summary() 
    # similar to the sequential or functional API like.
    def build_graph(self):
        x = Input(shape=(dim))
        return Model(inputs=[x], outputs=self.call(x))

dim = (124,124,3)
model = my_model((dim))
model.build((None, *dim))
model.build_graph().summary()

它将产生如下内容:

Layer (type)                 Output Shape              Param #   
=================================================================
input_67 (InputLayer)        [(None, 124, 124, 3)]     0         
_________________________________________________________________
vgg16 (Functional)           (None, 3, 3, 512)         14714688  
_________________________________________________________________
global_average_pooling2d_32  (None, 512)               0         
_________________________________________________________________
batch_normalization_7 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_A (Dense)              (None, 256)               402192    
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 785       
=================================================================
Total params: 14,848,321
Trainable params: 14,847,297
Non-trainable params: 1,024

现在,通过使用build_graph函数,我们可以简单地绘制整个架构。

# Just showing all possible argument for newcomer.  
tf.keras.utils.plot_model(
    model.build_graph(),                      # here is the trick (for now)
    to_file='model.png', dpi=96,              # saving  
    show_shapes=True, show_layer_names=True,  # show shapes and layer name
    expand_nested=False                       # will show nested block
)

它将产生如下内容:- )

类似QnA:

  1. Retrieving Keras Layer Properties from a tf.keras.Model
  2. Visualize nested keras.Model (SubClassed API) GAN model
yftpprvb

yftpprvb2#

另一个解决方法:使用tf2onnx将savemodel格式模型转换为onnx,然后使用netron查看模型架构。
下面是netron中模型的一部分:

l2osamch

l2osamch3#

更新(2021年1月4日):这似乎是可能的;见@M.Innat的answer

这是不可能做到的,因为基本上模型子类化,因为它是在TensorFlow中实现的,与使用Functional/Sequential API(在TF术语中称为Graph networks)创建的模型相比,在特性和功能方面受到限制。如果检查plot_model源代码,您将在model_to_dot函数(由plot_model调用)中看到以下检查:

if not model._is_graph_network:
  node = pydot.Node(str(id(model)), label=model.name)
  dot.add_node(node)
  return dot

正如我提到的,子类模型不是图网络,因此只有包含模型名称的节点才会为这些模型绘制(即你看到的一样)。
这已经在Github issue中讨论过了,TensorFlow的一位开发人员通过给出以下参数证实了这种行为:
@omalleyt12评论:
是的,一般来说,我们不能假设任何关于子类模型的结构。如果你的模型可以被看作是层的块,并且你希望像这样可视化它,我们建议你查看Functional API

ef1yzkbh

ef1yzkbh4#

我创建了一个github仓库来演示我的解决方案:https://github.com/Meidozuki/light-keras-plot
同样的问题我遇到过好几次。首先,我也使用Model(inputs=[x], outputs=self.call(x))。但随着时间的推移,每次我想绘制一个新的模型,我需要改变输入的形状,所以我找到了一种方法来自动捕捉输入的形状。
我让它只显示一次。
使用方式

@plotable()
def build(self,input_shape):
    super().build(input_shape)

何处

def plotable(silent=False):
    '''
    Used on model.build to call tf.keras.utils.plot_model
    '''
    
    def decorate(func):
        @wraps(func)
        def wrapper(self,input_shape):
            result=func(self,input_shape)

            if not silent:
                from tensorflow.keras import layers
                from IPython.display import display
                if isinstance(input_shape,(tuple,tf.TensorShape)):
                    inputs=layers.Input(input_shape[1:])
                elif isinstance(input_shape,list):
                    inputs=[layers.Input(s[1:]) for s in input_shape]
                else:
                    raise AssertionError

                outputs=self.call(inputs)
                model=tf.keras.Model(inputs=inputs,outputs=outputs)
                display(tf.keras.utils.plot_model(model,show_shapes=True))
            return result
        return wrapper
    return decorate

相关问题