看来我没有正确设计我的编码器,这就是为什么我需要Maven的意见,因为我是初学者变压器和DL模型设计。
我在编码器中有两种不同类型的变压器网络,如下所示:
的数据
每个分支的嵌入维数为256,它们通过线性层融合
self.fusion_head = nn.Linear(2*self.num_features, self.num_features) #self_num_features = 256
字符串
我在编码器中有一个向前的功能
def transformer_forward(self,x):
"""
:param x: The embeddings + pos_embed
:return:
"""
x_t1 = self.transformer_type1(x) # torch.Size([1, 1280, 256])
x_t2 = self.transformer_type2.forward(x) # torch.Size([1, 1280, 256])
# x = x_t1 + x_t2
x = torch.cat([x_t1,x_t2],dim=2)
x = self.fusion_head(x)
return x
型
然而,在训练模型并加载检查点之后,我意识到self.fusion_head
位于transformer_type1
模块之后
. 3.0.fn.to ','模块.编码器.Transformer_type1.3.layers.3.0.fn.to ','模块.编码器.Transformer_type1.3.layers.3.0.fn.to',' module.encoder.transformer_type1.3.layers.3.1.norm.weight ','module.encoder.transformer_type1.3.layers.3.1.norm.bias',' module.encoder.transformer_type1.3.layers.3.1.fn.net.0.weight ','module.encoder.transformer_type1.3.layers.3.1.fn.net.0.bias',' module.encoder.transformer_type1.3.layers.3.1.fn.net.3.weight ','module.encoder.transformer_type1.3.layers.3.1.fn.net.3.bias',' module.encoder.mlp_head.0.weight ',' module.encoder.mlp_head.0.bias','module.encoder.mlp_head.1.weight','module.encoder.mlp_head.1.bias',' module.encoder.fusion_head. weight ',' module.encoder.fusion_head. bias','module.encoder.transformer_type2.pos_embed','module.encoder.transformer_type2.patch_embed.proj.weight','module.encoder.transformer_type2.patch_embed.proj.bias',' module.encoder.transformer_type2.patch_embed.norm.weight ',' module.encoder.transformer_type2.patch_embed.norm.bias','module.encoder.transformer_type2.blocks.0.norm1.weight','module.encoder.transformer_type2.blocks.0.norm1.bias','module.encoder.transformer_type2.blocks.0.filter.complex_weight','module.encoder.transformer_type2.blocks.0.norm2.weight','module.encoder.transformer_type2.blocks.0.norm2.bias',' module.encoder.transformer_type2.blocks.0.mlp.fc1.weight ',.
这个级联层的位置(即,fusion_head
在forward函数中正确吗?为什么它被放置在transformet_type1
之后?fusion_head
层不应该在transformet_type1
和transformer_type2
之后吗?
1条答案
按热度按时间ia2d9nvy1#
你在这里看到的是nn.Module的
__repr__
的实现。它打印了你在网络的__init__
方法中注册的模块。forward
方法没有定义顺序,这是有意义的,因为你可以在forward
中多次调用模块,或者根本不调用模块。字符串
输出量:
型