pytorch 我应该同时继承nn.Module和ABC吗?

ars1skjm  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(103)

我应该创建一个同时继承torch.nn.ModuleABC的类吗?调用ABC__init__()函数是否可以接受?(我想这是可以的,因为ABC类只是object的一个普通子类)
如果我应该使用NotImplemented方式,我如何决定何时使用哪种方式?
我使用AbstractModel初始化所有子模块的配置。

import torch.nn as nn
from abc import ABC, abstractmethod

class AbstractModel(nn.Module, ABC):
    def __init__(self, config):
        super().__init__()
        self.config = config
    
    @abstractmethod
    def generate(self):
        pass
    
class sub(AbstractMode):
    def __init__(self, config):
        super().__init__(config)

    def generate(self):
        print(self.config)

字符串

xxls0lw8

xxls0lw81#

我在一些基本模块中使用ABC,它也继承自nn.Module,到目前为止,我没有遇到任何问题。它们在提供广告方面没有冲突,正如您所说的,ABC是一个相当简单的类
你应该喜欢它而不是做NotImplemented
ABC表示
定义抽象基类(ABCs)的元类。使用此元类创建ABC。ABC可以直接子类化,然后作为一个混合类。
因此它充当了一个mixin类,这意味着它很小,轻量级,基本上为类添加了一个特定的功能,并且不应该与您继承的任何其他类发生冲突。

u7up0aaq

u7up0aaq2#

作为@Nopileos答案的补充,你应该使用NotImplementedErrorabc,因为它不允许继承类调用super()并使用这些方法。
一个例子可以是:

import abc

import torch

class Base(torch.nn.Module, abc.ABC):
    def __init__(self, out_features: int):
        super().__init__()
        self.out_channels = out_channels
        self.module = torch.nn.Sequential(
            torch.nn.LazyLinear(out_features * 2),
            torch.nn.GELU(),
            torch.nn.LazyLinear(out_features),
        )

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError

class ResnetLike(Base):
    def forward(self, x: torch.Tensor):
        # super().forward(x) would raise an error correctly
        return self.module(x) + x

class TestNetwork(Base):
    def forward(self, x: torch.Tensor):
        return self.module(x) * 2

字符串

相关问题