pytorch 了解nlp中的torch.nn.LayerNorm

wfveoks0  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(263)

我试图理解torch.nn.LayerNorm在nlp模型中是如何工作的。假设输入数据是一批单词序列的嵌入:

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)

layer_norm = torch.nn.LayerNorm(dim)
print("y: ", layer_norm(embedding))

# outputs:

"""
x:  tensor([[[ 0.5909,  0.1326,  0.8100,  0.7631],
         [ 0.5831, -1.7923, -0.1453, -0.6882],
         [ 1.1280,  1.6121, -1.2383,  0.2150]],

        [[-0.2128, -0.5246, -0.0511,  0.2798],
         [ 0.8254,  1.2262, -0.0252, -1.9972],
         [-0.6092, -0.4709, -0.8038, -1.2711]]])
y:  tensor([[[ 0.0626, -1.6495,  0.8810,  0.7060],
         [ 1.2621, -1.4789,  0.4216, -0.2048],
         [ 0.6437,  1.0897, -1.5360, -0.1973]],

        [[-0.2950, -1.3698,  0.2621,  1.4027],
         [ 0.6585,  0.9811, -0.0262, -1.6134],
         [ 0.5934,  1.0505, -0.0497, -1.5942]]],
       grad_fn=<NativeLayerNormBackward0>)
"""

根据document's description,我的理解是平均值和标准差是由每个样本的所有嵌入值计算的,所以我尝试手动计算y[0, 0, :]

mean = torch.mean(embedding[0, :, :])
std = torch.std(embedding[0, :, :])
print((embedding[0, 0, :] - mean) / std)

它给出的是tensor([ 0.4310, -0.0319, 0.6523, 0.6050]),这不是正确的输出,我想知道计算y[0, 0, :]的正确方法是什么?

kpbwa7wx

kpbwa7wx1#

正如我所看到的,平均值和标准差是在(batch_size, seq_size, embedding_dim)形状的每个样本的所有嵌入值上计算的。但在您的情况下,通过torch.mean(embedding[0, :, :])和标准差,每个批次而不是每个样本有一个单一的平均值。
torch.nn.LayerNorm中使用elementwise_affine = False,在没有grad属性的情况下也会产生相同的结果。在这里可以找到一个类似的问题和答案,即layer Normalization in pytorch?

代码
import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)
print("x: ", embedding)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("y: ", layer_norm(embedding))

# print(embedding[0, :, :].shape)

# print(embedding[0, :, :])

eps: float = 0.00001
mean = torch.mean(embedding[0, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[0, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean)
print("y_custom: ", (embedding[0, :, :] - mean) / torch.sqrt(var + eps))

# std = torch.std(embedding[0, :, :], dim=(-1), keepdim=True)

# print((embedding[0, :, :] - mean) / std)
输出
x:  tensor([[[-0.0465,  1.8510, -0.6282,  0.4107],
         [ 0.0424, -1.1447,  0.8040, -0.9578],
         [ 0.8864,  0.4136,  0.7011,  0.5104]],

        [[ 1.1509, -0.2508,  1.0221, -1.3924],
         [ 1.1902,  1.2089,  1.5641, -0.6508],
         [-0.7237,  1.1343, -0.6231, -0.4966]]])
y:  tensor([[[-0.4835,  1.5862, -1.1180,  0.0152],
         [ 0.4525, -1.0546,  1.4195, -0.8173],
         [ 1.4233, -1.1797,  0.4034, -0.6471]],

        [[ 0.9822, -0.3696,  0.8580, -1.4705],
         [ 0.4177,  0.4394,  0.8492, -1.7063],
         [-0.7176,  1.7223, -0.5854, -0.4193]]])
mean:  tensor([[ 0.3968],
        [-0.3140],
        [ 0.6279]])
y_custom:  tensor([[-0.4835,  1.5862, -1.1180,  0.0152],
        [ 0.4525, -1.0546,  1.4195, -0.8173],
        [ 1.4233, -1.1797,  0.4034, -0.6471]])
自定义层定额实施示例
import torch

batch_size, seq_size, seq_dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, seq_dim)
print("x: ", embedding)
print(embedding.shape)
print()

layer_norm = torch.nn.LayerNorm(seq_dim, elementwise_affine=False)
layer_norm_output = layer_norm(embedding)
print("y: ", layer_norm_output)
print(layer_norm_output.shape)
print()

def custom_layer_norm(
        x: torch.Tensor, dim: int = -1, eps: float = 0.00001
) -> torch.Tensor:
    mean = torch.mean(x, dim=(dim,), keepdim=True)
    var = torch.square(x - mean).mean(dim=(dim,), keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

custom_layer_norm_output = custom_layer_norm(embedding)
print("y_custom: ", custom_layer_norm_output)
print(custom_layer_norm_output.shape)

assert torch.allclose(layer_norm_output, custom_layer_norm_output), 'Tensors do not match.'
输出
x:  tensor([[[-0.9086,  0.4347,  0.1375,  1.5539],
         [ 0.6767,  1.3798, -0.8703, -2.0393],
         [-0.2406,  1.5617, -0.5317, -0.3856]],

        [[ 0.6483, -2.7647, -0.5750, -0.2885],
         [ 1.4656,  0.5840, -0.4808,  0.2106],
         [-0.2964,  1.6529,  1.6285, -0.5499]]])
torch.Size([2, 3, 4])

y:  tensor([[[-1.3829,  0.1485, -0.1902,  1.4246],
         [ 0.6682,  1.1960, -0.4933, -1.3710],
         [-0.4020,  1.7193, -0.7446, -0.5727]],

        [[ 1.1140, -1.6148,  0.1359,  0.3649],
         [ 1.4534,  0.1981, -1.3180, -0.3336],
         [-0.8738,  1.0080,  0.9844, -1.1186]]])
torch.Size([2, 3, 4])

y_custom:  tensor([[[-1.3829,  0.1485, -0.1902,  1.4246],
         [ 0.6682,  1.1960, -0.4933, -1.3710],
         [-0.4020,  1.7193, -0.7446, -0.5727]],

        [[ 1.1140, -1.6148,  0.1359,  0.3649],
         [ 1.4534,  0.1981, -1.3180, -0.3336],
         [-0.8738,  1.0080,  0.9844, -1.1186]]])
torch.Size([2, 3, 4])

相关问题