PyTorch SimCSE loss implementation

Posted by n3schb8v on 2023-10-20 in Other

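I want to implement the supervised Simple Contrastive Learning of Sentence Embeddings (SimCSE) loss in PyTorch, using positive and hard-negative pairs.

Is there a way to vectorize the naive loop-based implementation below with broadcasting and/or matrix multiplication?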

import torch
import torch.nn.functional as F

batch_size = 4
feature_dim = 1024
h = -2 * torch.randn(batch_size, 3, feature_dim) + 1  # (batch dim, [anchor, positive, hard negative], feature dim)
temp = 10  # temperature

# Numerator: similarity of each anchor with its own positive.
num = torch.exp(F.cosine_similarity(h[:, 0, :], h[:, 1, :], dim=1) / temp)

# Denominator: for each anchor j, sum over every positive and every hard negative in the batch.
denom = torch.empty_like(num)
for j in range(batch_size):
    denomjj = 0
    for jj in range(batch_size):
        denomjj += torch.exp(F.cosine_similarity(h[j, 0, :], h[jj, 1, :], dim=0) / temp)
        denomjj += torch.exp(F.cosine_similarity(h[j, 0, :], h[jj, 2, :], dim=0) / temp)
    denom[j] = denomjj

loss = -torch.log(num / denom)  # per-sample loss, shape (batch_size,)
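Written out, the loop above computes the following per-anchor loss, where sim is cosine similarity, τ = temp, N = batch_size, h_i = h[i, 0, :], h_i^+ = h[i, 1, :] and h_i^- = h[i, 2, :]. As far as I can tell this matches the supervised objective in the SimCSE paper. Both sums in the denominator run over every j in the batch, which is why the denominator can be rewritten as row sums of two N × N similarity matrices.

```latex
\ell_i = -\log \frac{e^{\operatorname{sim}(h_i,\,h_i^{+})/\tau}}
{\sum_{j=1}^{N}\left(e^{\operatorname{sim}(h_i,\,h_j^{+})/\tau} + e^{\operatorname{sim}(h_i,\,h_j^{-})/\tau}\right)},
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u\rVert\,\lVert v\rVert}
```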
Answer #1 by jyztefdp

Sure.
Your numerator looks fine, so I'll vectorize the denominator.
I'll stick to your notation as closely as possible:

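# L2 norm of every anchor, positive and hard-negative vector, each of shape (batch_size,)
norm_hi = torch.sqrt(torch.sum(torch.square(h[:, 0, :]), dim=1))
norm_hj_plus = torch.sqrt(torch.sum(torch.square(h[:, 1, :]), dim=1))
norm_hj_minus = torch.sqrt(torch.sum(torch.square(h[:, 2, :]), dim=1))

# Entry [j, jj] is ||h_j|| * ||h_jj|| * temp, shape (batch_size, batch_size)
sim_denom1 = torch.outer(norm_hi, norm_hj_plus) * temp
sim_denom2 = torch.outer(norm_hi, norm_hj_minus) * temp

# All pairwise dot products at once; dividing gives cos(h_j, h_jj) / temp
v1 = h[:, 0, :] @ h[:, 1, :].t() / sim_denom1
v2 = h[:, 0, :] @ h[:, 2, :].t() / sim_denom2

# Summing each row over jj replaces the inner loop
vec_denom = torch.sum(torch.exp(v1) + torch.exp(v2), dim=1)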
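Since the question also asks about broadcasting, here is a sketch of an alternative (not the method in this answer) that lets F.cosine_similarity broadcast an (N, 1, D) view of the anchors against (1, N, D) views of the positives and hard negatives. It assumes a PyTorch version in which cosine_similarity accepts broadcastable inputs, as recent releases document; the seed and the names sim_pos, sim_neg, num_b, denom_b and loss_b are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative: fixed seed so the sketch is reproducible
batch_size, feature_dim, temp = 4, 1024, 10
h = -2 * torch.randn(batch_size, 3, feature_dim) + 1  # (batch, [anchor, positive, hard negative], features)

anchors = h[:, 0, :].unsqueeze(1)    # (N, 1, D)
positives = h[:, 1, :].unsqueeze(0)  # (1, N, D)
negatives = h[:, 2, :].unsqueeze(0)  # (1, N, D)

# Broadcasting (N, 1, D) against (1, N, D) yields an (N, N) matrix whose
# entry [i, j] is cos(h_i, h_j^+) (respectively cos(h_i, h_j^-)).
sim_pos = F.cosine_similarity(anchors, positives, dim=-1)
sim_neg = F.cosine_similarity(anchors, negatives, dim=-1)

num_b = torch.exp(torch.diagonal(sim_pos) / temp)  # anchor i paired with its own positive i
denom_b = (torch.exp(sim_pos / temp) + torch.exp(sim_neg / temp)).sum(dim=1)  # row sums over all j
loss_b = -torch.log(num_b / denom_b)  # per-sample loss, shape (N,)
```

This is shorter, but it likely materializes an N × N × D intermediate, so the matrix-multiplication version scales better in memory for large batches.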

You can verify that vec_denom matches the denominator from your loop like this:

print(torch.allclose(loss, -torch.log(num / vec_denom)))
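Going one step further, the whole per-sample loss (numerator included) can be phrased as a cross-entropy. This is a sketch, not the answerer's code: simcse_supervised_loss is a hypothetical helper name, and it L2-normalizes rows with F.normalize instead of dividing by explicit norms. Stacking all positives and then all hard negatives as 2N candidate columns puts anchor i's positive in column i, so F.cross_entropy with target i gives exactly -log(num / denom).

```python
import torch
import torch.nn.functional as F

def simcse_supervised_loss(h: torch.Tensor, temp: float) -> torch.Tensor:
    """Hypothetical helper: per-sample supervised SimCSE loss for h of shape (N, 3, D).

    h[:, 0] are anchors, h[:, 1] positives, h[:, 2] hard negatives.
    """
    z = F.normalize(h, dim=-1)  # unit-length rows, so dot products are cosine similarities
    anchors = z[:, 0, :]  # (N, D)
    candidates = torch.cat([z[:, 1, :], z[:, 2, :]], dim=0)  # (2N, D): all positives, then all negatives
    logits = anchors @ candidates.t() / temp  # (N, 2N) scaled cosine similarities
    labels = torch.arange(h.size(0), device=h.device)  # anchor i's positive sits in column i
    # cross_entropy(logits, i) = -log(exp(logits[i, i]) / sum_k exp(logits[i, k])),
    # i.e. -log(num / denom) with the denominator over every positive and hard negative.
    return F.cross_entropy(logits, labels, reduction="none")

torch.manual_seed(0)  # illustrative seed
h = -2 * torch.randn(4, 3, 1024) + 1
loss_ce = simcse_supervised_loss(h, temp=10)
print(loss_ce.mean())  # scalar typically used for backpropagation
```

Because cross_entropy works in log space, it should also avoid overflowing exp() when temp is small (the SimCSE paper uses τ = 0.05) or when running in half precision.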
