我试图修改HF UNet的扩散模型。我是通过在down和up块中添加条件来实现的。这是问题的一个最小例子。看起来最后一个down_block在第一个down_block之前被触发。UNet模型的源代码中没有任何内容表明它应该以相反的顺序启动。
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DConditionModel
# config
SD_MODEL = "runwayml/stable-diffusion-v1-5"
DIM = 15
unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")
bs = 2
timestep = torch.randint(0, 100, (bs,))
noise = torch.randn((bs, 4, 64, 64))
text_encoding = torch.randn((bs, 77, 768))
condition = torch.randn((bs, DIM))
DownOutput = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
class ConditionResnet(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.call_count = 0
self.projector = nn.Linear(in_dim, out_dim)
self.conv1 = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
self.non_linearity = F.silu
def forward(self, out: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
self.call_count += 1
input_vector = out
out = self.conv1(out) + self.projector(condition)[:, :, None, None]
return input_vector + self.non_linearity(out)
# down blocks return tuples, so need slightly modified version
class ConditionResnetDown(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.condition_resnet = ConditionResnet(in_dim, out_dim)
def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
return self.condition_resnet(x[0], condition), x[1]
class UNetWithConditions(nn.Module):
def __init__(self, unet: nn.Module, col_channels: int, down_block_sizes: list[int], up_block_sizes: list[int]):
super().__init__()
self.unet = unet
self.down_block_condition_resnets = nn.ModuleList([ConditionResnetDown(col_channels, out_channel) for out_channel in down_block_sizes])
self.up_block_condition_resnets = nn.ModuleList([ConditionResnet(col_channels, out_channel) for out_channel in up_block_sizes])
self.condition = None
# forward hooks
for i in range(len(self.unet.down_blocks)):
self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
for i in range(len(self.unet.up_blocks)):
self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
def forward(self, noise, timestep, text_encoding, condition):
self.condition = condition
out = self.unet(noise, timestep, text_encoding).sample
self.condition = None
return out
unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
out2 = unet_with_conditions(noise, timestep, text_encoding, condition)
字符串
我知道最后一个down_block被触发的原因是因为我通过( [a.condition_resnet.call_count for a in unet_with_conditions.down_block_condition_resnets], [a.call_count for a in unet_with_conditions.up_block_condition_resnets], )
查看ConditionResnet
的call_count
,我得到了这个:([0, 0, 0, 1], [0, 0, 0, 0])
的值。
潜在原因
- 如果模型是编译的(也许是用JIT),是否可能触发了错误的钩子?如果我修改上面的模型代码并执行它,由于某种原因,它似乎会与旧代码一起崩溃。
- 不确定前向钩子是否允许传入额外的输入。
错误日志:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
~tmp/ipykernel_574/3305635741.py in <cell line: 2>()
1 unet_with_conditions = UNetWithConditions(unet, DIM, [320, 640, 1280, 1280], [1280, 1280, 640, 320])
----> 2 out2 = unet_with_conditions(noise, timestep, text_encoding, condition)
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
~tmp/ipykernel_574/2058376754.py in forward(self, noise, timestep, text_encoding, condition)
59 def forward(self, noise, timestep, text_encoding, condition):
60 self.condition = condition
---> 61 out = self.unet(noise, timestep, text_encoding).sample
62 self.condition = None
63 return out
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
~nix/store/vzqny68wq33dcg4hkdala51n5vqhpnwc-python3-3.9.12/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
795 for downsample_block in self.down_blocks:
796 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 797 sample, res_samples = downsample_block(
798 hidden_states=sample,
799 temb=emb,
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1213 if _global_forward_hooks or self._forward_hooks:
1214 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1215 hook_result = hook(self, input, result)
1216 if hook_result is not None:
1217 result = hook_result
~tmp/ipykernel_574/2058376754.py in <lambda>(module, inputs, outputs)
53 # forward hooks
54 for i in range(len(self.unet.down_blocks)):
---> 55 self.unet.down_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.down_block_condition_resnets[i](outputs, self.condition))
56 for i in range(len(self.unet.up_blocks)):
57 self.unet.up_blocks[i].register_forward_hook(lambda module, inputs, outputs: self.up_block_condition_resnets[i](outputs, self.condition))
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
~tmp/ipykernel_574/2058376754.py in forward(self, x, condition)
40
41 def forward(self, x: DownOutput, condition: torch.Tensor) -> DownOutput:
---> 42 return self.condition_resnet(x[0], condition), x[1]
43
44 class UNetWithConditions(nn.Module):
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
~tmp/ipykernel_574/2058376754.py in forward(self, out, condition)
30 self.call_count += 1
31 input_vector = out
---> 32 out = self.conv1(out) + self.projector(condition)[:, :, None, None]
33 return input_vector + self.non_linearity(out)
34
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in forward(self, input)
461
462 def forward(self, input: Tensor) -> Tensor:
--> 463 return self._conv_forward(input, self.weight, self.bias)
464
465 class Conv3d(_ConvNd):
~app/creator_content_publish_server/models/template_predict_title_trainer/torch_wrapper_layer.runfiles/pypi_torch/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
457 weight, bias, self.stride,
458 _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
460 self.padding, self.dilation, self.groups)
461
RuntimeError: Given groups=1, weight of size [1280, 1280, 3, 3], expected input[2, 320, 32, 32] to have 1280 channels, but got 320 channels instead
型
1条答案
按热度按时间kr98yfug1#
我看到您正在使用在注册forward hook的循环中定义的lambda functions。
字符串
但是
lambda
函数通过引用捕获变量i
。这意味着当执行lambda
函数时,它们将使用i
的当前值,而不是创建lambda
函数时的值(“后期绑定”)。由于
i
在循环的每次迭代中递增,因此在实际调用lambda
函数时,i
等于循环结束时的最终值。这就解释了为什么最后一个down_block
在第一个之前被触发,因为所有的lambda函数都使用i
的最终值。为了解决这个问题,你可以在
lambda
函数中使用一个默认参数来捕获循环每次迭代中i
的当前值,如下所示:型
这将为每个
lambda
函数捕获i
的当前值,确保在以后调用这些函数时使用正确的索引。