pytorch CNN+LSTM用于视频分类

hgqdbh6s  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(129)

我试图产生一个模型,将接受多个视频帧作为输入,并提供一个标签作为输出(a.k.a.视频分类)。我在几个地方看到过类似于下面的代码来执行这个任务。但是我有一个困惑,因为在'out,hidden = self.lstm(x.unsqueeze(0))'行中,一旦for循环完成,out最终只会保存最后一帧的输出,因此在前向传递结束时返回的x将仅基于最后一帧,是吗?这种架构与单独处理最后一帧有何不同?首先,CNN-LSTM模型是解决这类问题的合适架构吗?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101

class CNNLSTM(nn.Module):
    def __init__(self, num_classes=2):
        super(CNNLSTM, self).__init__()
        self.resnet = resnet101(pretrained=True)
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)
       
    def forward(self, x_3d):
        hidden = None
        for t in range(x_3d.size(1)):
            with torch.no_grad():
                x = self.resnet(x_3d[:, t])  
            out, hidden = self.lstm(x.unsqueeze(0))         

        x = self.fc1(out.squeeze())
        x = F.relu(x)
        x = self.fc2(x)
        return x

字符串

wi3ka0sx

wi3ka0sx1#

我在另一个论坛上得到了一些帮助。下面是一个架构,它将解决我所担心的问题。

class CNNLSTM(nn.Module):
def __init__(self, num_classes=2):
    super(CNNLSTM, self).__init__()
    self.resnet = resnet101(pretrained=True)
    self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
    self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3)
    self.fc1 = nn.Linear(256, 128)
    self.fc2 = nn.Linear(128, num_classes)
   
def forward(self, x_3d):
    hidden = None

    # Iterate over each frame of a video in a video of batch * frames * channels * height * width
    for t in range(x_3d.size(1)):
        with torch.no_grad():
            x = self.resnet(x_3d[:, t])  
        # Pass latent representation of frame through lstm and update hidden state
        # Hidden state keeps record of information learned from prior frames
        out, hidden = self.lstm(x.unsqueeze(0), hidden)         

    # Get the last hidden state (hidden is a tuple with both hidden and cell state in it)
    x = self.fc1(hidden[0][-1])
    x = F.relu(x)
    x = self.fc2(x)

    return x

字符串

相关问题