PyTorch fires the wrong forward hook

vbopmzt1 · posted 2023-08-05 in Other

I'm trying to modify a Hugging Face diffusers UNet by adding an extra conditioning signal to the down and up blocks. Below is a minimal example of the problem. It looks as though the last down_block is triggered before the first down_block, even though nothing in the UNet source code suggests the blocks should run in reverse order.

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers import UNet2DConditionModel

# config
SD_MODEL = "runwayml/stable-diffusion-v1-5"
DIM = 15

unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")

bs = 2
timestep = torch.randint(0, 100, (bs,))
noise = torch.randn((bs, 4, 64, 64))
text_encoding = torch.randn((bs, 77, 768))
condition = torch.randn((bs, DIM))

DownOutput = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

class ConditionResnet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.call_count = 0
        self.projector = nn.Linear(in_dim, out_dim)
        self.conv1 = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
        self.non_linearity = F.silu
                        
    def forward(self, out: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        self.call_count += 1
        input_vector = out
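        # project the condition vector to out_dim channels, then broadcast it over H and W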
        out = self.conv1(out) + self.projector(condition)[:, :, None, None]
        return input_vector + self.non_linearity(out)

# down blocks return (sample, res_samples) tuples, so we need a slightly modified wrapper
class ConditionResnetDown(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.condition_resnet = ConditionResnet(in_dim, out_dim)
        
    def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
        return self.condition_resnet(x[0], condition), x[1]

class UNetWithConditions(nn.Module):
    def __init__(self, unet: nn.Module, col_channels: int, down_block_sizes: list[int], up_block_sizes: list[int]):
        super().__init__()
        self.unet = unet
        self.down_block_condition_resnets = nn.ModuleList([ConditionResnetDown(col_channels, out_channel) for out_channel in down_block_sizes])
        self.up_block_condition_resnets = nn.ModuleList([ConditionResnet(col_channels, out_channel) for out_channel in up_block_sizes])
        
        self.condition = None
        
        # forward hooks
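        # note: forward hooks receive only (module, inputs, outputs), so the extra
        # condition is passed out-of-band via self.condition; a hook that returns a
        # non-None value replaces the block's output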
        for i in range(len(self.unet.down_blocks)):
            self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
        for i in range(len(self.unet.up_blocks)):
            self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
        
    def forward(self, noise, timestep, text_encoding, condition):
        self.condition = condition
        out = self.unet(noise, timestep, text_encoding).sample
        self.condition = None
        return out

unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
out2 = unet_with_conditions(noise, timestep, text_encoding, condition)

The reason I believe it's the last down_block being triggered is that I inspected each ConditionResnet's call_count:

( [a.condition_resnet.call_count for a in unet_with_conditions.down_block_condition_resnets],
  [a.call_count for a in unet_with_conditions.up_block_condition_resnets], )

which returned ([0, 0, 0, 1], [0, 0, 0, 0]): only the last down-block resnet ran, once, before the crash.

Possible causes

  • If the model is compiled (perhaps with JIT), could the wrong hook be getting triggered? If I modify the model code above and re-execute it, for some reason it seems to crash as if the old code were still attached (one way this can happen is sketched after this list).
  • I'm not sure whether forward hooks allow passing in extra inputs.
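If the crash-with-old-code symptom comes from re-running cells in a notebook, one plausible explanation (an assumption, not something the traceback confirms) is that hooks registered on the same pretrained unet instance accumulate across runs: register_forward_hook returns a RemovableHandle, and an old hook keeps firing until that handle's remove() is called. A minimal, self-contained demonstration:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

# the returned handle is the only way to detach the hook later
handle = layer.register_forward_hook(lambda module, inputs, outputs: print("hook fired"))

layer(torch.randn(1, 4))  # prints "hook fired"
handle.remove()           # detach the hook
layer(torch.randn(1, 4))  # prints nothing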

Error log:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_574/3305635741.py in <cell line: 2>()
      1 unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
----> 2 out2 = unet_with_conditions(noise, timestep, text_encoding, condition)

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_574/2058376754.py in forward(self, noise, timestep, text_encoding, condition)
     59     def forward(self, noise, timestep, text_encoding, condition):
     60         self.condition = condition
---> 61         out = self.unet(noise, timestep, text_encoding).sample
     62         self.condition = None
     63         return out

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/nix/store/vzqny68wq33dcg4hkdala51n5vqhpnwc-python3-3.9.12/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
    795         for downsample_block in self.down_blocks:
    796             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 797                 sample, res_samples = downsample_block(
    798                     hidden_states=sample,
    799                     temb=emb,

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1213         if _global_forward_hooks or self._forward_hooks:
   1214             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1215                 hook_result = hook(self, input, result)
   1216                 if hook_result is not None:
   1217                     result = hook_result

/tmp/ipykernel_574/2058376754.py in <lambda>(module, inputs, outputs)
     53         # forward hooks
     54         for i in range(len(self.unet.down_blocks)):
---> 55             self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
     56         for i in range(len(self.unet.up_blocks)):
     57             self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_574/2058376754.py in forward(self, x, condition)
     40 
     41     def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
---> 42         return self.condition_resnet(x[0], condition), x[1]
     43 
     44 class UNetWithConditions(nn.Module):

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_574/2058376754.py in forward(self, out, condition)
     30         self.call_count += 1
     31         input_vector = out
---> 32         out = self.conv1(out) + self.projector(condition)[:, :, None, None]
     33         return input_vector + self.non_linearity(out)
     34 

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

/app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    457                             weight, bias, self.stride,
    458                             _pair(0), self.dilation, self.groups)
--> 459         return F.conv2d(input, weight, bias, self.stride,
    460                         self.padding, self.dilation, self.groups)
    461 

RuntimeError: Given groups=1, weight of size [1280, 1280, 3, 3], expected input[2, 320, 32, 32] to have 1280 channels, but got 320 channels instead

kr98yfug1#

I see that you are registering lambda functions defined inside loops as the forward hooks:

# forward hooks
for i in range(len(self.unet.down_blocks)):
    self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
for i in range(len(self.unet.up_blocks)):
    self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))

Lambda functions capture the variable i by reference, not by value. This means that when a lambda actually executes, it uses the current value of i, not the value i had when the lambda was created ("late binding").
Since i is incremented on each iteration of the loop, by the time any of the lambdas is called, i holds its final value from the end of the loop. That matches what you observed: every hook, including the one on the first down_block, dispatches to the last ConditionResnetDown (index 3), whose conv expects 1280 input channels and therefore crashes on the first down block's 320-channel output.
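This late-binding behavior is plain Python and has nothing to do with PyTorch or JIT compilation. A minimal sketch (the names here are illustrative only):

fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # prints [2, 2, 2], not [0, 1, 2]: every lambda reads the final i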
To fix this, you can use a default argument in each lambda to capture the value of i at that iteration of the loop, like so:

# forward hooks
for i in range(len(self.unet.down_blocks)):
    self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs, i=i: self.down_block_condition_resnets[i](outputs, self.condition))
for i in range(len(self.unet.up_blocks)):
    self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs, i=i: self.up_block_condition_resnets[i](outputs, self.condition))


The default argument is evaluated when the lambda is defined, so each hook keeps its own copy of i and the correct index is used when the hooks fire later.
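An equivalent alternative, if you prefer to avoid the default-argument idiom, is a small hook factory: each call to the factory creates a fresh scope, so every hook closes over its own index. A sketch against the code above (make_down_hook is a helper name I'm introducing, not part of the original code):

# inside UNetWithConditions.__init__
def make_down_hook(idx):
    def hook(module, inputs, outputs):
        # idx is bound in make_down_hook's scope, one fresh scope per call
        return self.down_block_condition_resnets[idx](outputs, self.condition)
    return hook

for i in range(len(self.unet.down_blocks)):
    self.unet.down_blocks[i].register_forward_hook(make_down_hook(i))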
