将CNN从keras代码翻译为pytorch

fcg9iug3  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(99)

我正在尝试将下面的3layer CNN架构从keras翻译为pytorch。该模型的用途是从DNA序列(input_shape_hot)预测表达值(input_shape_val)。序列是一个热编码。该架构最初旨在连续训练模型CNN(3层)-FC(2层),在所有层之后应用批量归一化和权重丢弃,并在CNN层之后应用最大池化(参考论文,ref-code)。

def POC_model2(input_shape_hot,input_shape_val,DR):

    X_input1 = Input(shape = input_shape_hot)
    X_input2 = Input(shape = input_shape_val)
    # L 1: CONV
    X1 = Conv1D(filters=32, kernel_size=10, strides=1, activation='relu')(X_input1) 
    X1 = BatchNormalization()(X1)
    X1 = Dropout(DR)(X1)
    X1 = MaxPooling1D(pool_size=4, strides=4)(X1)
    # L 3: CONV
    X1 = Conv1D(filters=64, kernel_size=10, strides=1, activation='relu')(X_input1) 
    X1 = BatchNormalization()(X1)
    X1 = Dropout(DR)(X1)
    X1 = MaxPooling1D(pool_size=4, strides=4)(X1)
    # L 2: CONV
    X1 = Conv1D(filters=128, kernel_size=10, strides=1, activation='relu')(X_input1) 
    X1 = BatchNormalization()(X1)
    X1 = Dropout(DR)(X1)
    X1 = MaxPooling1D(pool_size=4, strides=4)(X1)

    X1 = Flatten()(X1)

    X1 = Concatenate(axis=1)([X1,X_input2])
    # fully connected
    X = Dense(64, activation='relu')(X1)
    X = BatchNormalization()(X)
    X = Dropout(DR)(X)

    X = Dense(1)(X)

    model = Model(inputs = [X_input1,X_input2], outputs = X)

    return model

字符串
我试着用这个代码:

from typing import List
class DNA_CNN_test2(nn.Module):
    def __init__(self,
                 seq_len: int,
                 num_filters: List[int] = [32, 64,128],
                 kernel_size: int = 3,
                 p = 0.2):
        super().__init__()
        self.seq_len = seq_len
        # CNN module
        self.conv_net = nn.Sequential()
        num_filters = [4] + num_filters
        for idx in range(len(num_filters) - 1):
            self.conv_net.add_module(
                f"conv_{idx}",
                nn.Conv1d(num_filters[idx], num_filters[idx + 1],
                          kernel_size=kernel_size, padding='same')
            )
            self.conv_net.add_module(f"relu_{idx}", nn.ReLU(inplace=True))
            self.conv_net.add_module(f"batchNor_{idx}",nn.BatchNorm1d(num_filters[idx + 1]))
            self.conv_net.add_module(f"dropout_{idx}",nn.Dropout(0.2))
            self.conv_net.add_module(f"MaxP_{idx}",nn.MaxPool1d(4,stride= 4))
        self.conv_net.add_module("flatten", nn.Flatten())
        self.conv_net.add_module("linear",nn.Linear(num_filters[-1]*seq_len, 1))
        
    def forward(self, xb: torch.Tensor):
        """Forward pass."""
        xb = xb.permute(0, 2, 1) 
        out = self.conv_net(xb)
        return out


并收到错误消息:

mat1 and mat2 shapes cannot be multiplied (2048x1920 and 128000x1)

ztmd8pv5

ztmd8pv51#

您忘记了BN和dropout的add_module函数:

self.conv_net.add_module(
  f'bn_{idx}',
  nn.BatchNorm1d(num_filters[idx + 1]),
)
self.conv_net.add_module(f'dropout_{idx}', nn.Dropout(p))

字符串

相关问题