Add a new attention fusion rule in tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp:
class fuse_multiheadattention_pass_19 : public fuse_multiheadattention_pass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 14
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.Linear op_0 1 1 input 2 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 2 3 shape=(%batch,%size,3,%num_heads,%feat_per_head)
torch.permute op_2 1 1 3 4 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 4 5 6 7 dim=0
torch.permute op_4 1 1 6 8 dims=(0,1,3,2)
torch.matmul op_5 2 1 4 8 9
pnnx.Expression op_6 2 1 9 attn_mask 10 expr=add(div(@0,%sqrt_feat_per_head),@1)
F.softmax op_7 1 1 10 11 dim=-1
torch.matmul op_8 2 1 11 7 12
torch.permute op_9 1 1 12 13 dims=(0,2,1,3)
Tensor.reshape op_10 1 1 13 14 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 14 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=True add_zero_attn=False add_bias_kv=False
pnnx.Output output 1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const int embed_dim = captured_params.at("embed_dim").i;
        const int qkv_out_features = captured_params.at("qkv_out_features").i;
        const int num_heads = captured_params.at("num_heads").i;
        const int feat_per_head = captured_params.at("feat_per_head").i;
        const float sqrt_feat_per_head = captured_params.at("sqrt_feat_per_head").f;

        if (qkv_out_features != embed_dim * 3)
            return false;

        if (embed_dim != num_heads * feat_per_head)
            return false;

        if (!NearlyEqual(sqrt_feat_per_head, sqrt(feat_per_head), 0.001))
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        fuse_multiheadattention_pass::write(ops, captured_params, captured_attrs);

        const int size = captured_params.at("size").i;
        const int head = captured_params.at("num_heads").i;

        Operator* op_attr = ops.at("attn_mask");
        fprintf(stderr, "op_attr->attrs[data] type %d\n", op_attr->attrs["data"].type);

        // hack attn_mask shape
        op_attr->attrs["data"].shape = {1, head, size, size};

        // hack attn_mask value
        std::vector<char>& data = op_attr->attrs["data"].data;
        size_t len = data.size();
        data.resize(len * size);
        for (int i = 1; i < size; i++)
        {
            memcpy(&data[len * i], &data[0], len);
        }
    }
};
Then register it in the fuse_multiheadattention function by adding:
fuse_multiheadattention_pass_19 r;
pnnx_graph_rewrite(graph, &r, opindex);
No other code changes; then recompile.
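For context, a rough sketch of where that registration sits inside fuse_multiheadattention(); the surrounding code is illustrative, only the two lines above are actually new:

void fuse_multiheadattention(Graph& graph)
{
    // the pass_level5 fuse passes share one rewrite counter
    int opindex = 0;

    // ... the existing fuse_multiheadattention_pass* instances are
    // registered here, one pnnx_graph_rewrite call per pattern ...

    // new pattern with a runtime attn_mask input
    fuse_multiheadattention_pass_19 r;
    pnnx_graph_rewrite(graph, &r, opindex);
}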
Run the test case:
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from einops import rearrange
from typing import Any, Optional, Tuple, Union
import math

class TSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, atten_mask):
        _, N, C = x.shape
        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = attn * self.scale
        attn = attn + atten_mask
        attn = F.softmax(attn, dim=-1)
        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
        x = self.proj(x)
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.attention_0_0 = TSelfAttention(embed_dim=768, num_heads=12)

    def forward(self, x, attention_mask):
        a = self.attention_0_0(x, attention_mask)
        return a

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2, 128, 768)
    attention_mask = torch.rand(2, 12, 128, 128)

    r = net(x, attention_mask)

    # export torchscript
    mod = torch.jit.trace(net, (x, attention_mask))
    mod.save("test_bert_fused.pt")

    # torchscript to pnnx
    import os
    os.system("../build/src/pnnx test_bert_fused.pt inputshape=[2,128,768],[2,12,128,128]")

    return True

if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)
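As an aside, the test scripts under tools/pnnx/tests usually also run the Python graph that pnnx writes back out and compare it with the original output. A sketch of that extra check (assuming the test_inference() entry point that pnnx generates into test_bert_fused_pnnx.py) would replace the final return True in test():

    # run the pnnx-regenerated graph and compare with the original output r
    import test_bert_fused_pnnx
    b = test_bert_fused_pnnx.test_inference()
    return torch.allclose(r, b, 1e-4, 1e-4)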
Running the test script errors out:
python3 test_bert_fused.py > test.log
pnnxparam = test_bert_fused.pnnx.param
pnnxbin = test_bert_fused.pnnx.bin
pnnxpy = test_bert_fused_pnnx.py
pnnxonnx = test_bert_fused.pnnx.onnx
ncnnparam = test_bert_fused.ncnn.param
ncnnbin = test_bert_fused.ncnn.bin
ncnnpy = test_bert_fused_ncnn.py
fp16 = 1
optlevel = 2
device = cpu
inputshape = [2,128,768]f32,[2,12,128,128]f32
inputshape2 =
customop =
moduleop =
############# pass_level0
inline module = TSelfAttention
inline module = TSelfAttention
----------------
############# pass_level1
############# pass_level2
############# pass_level3
############# pass_level4
############# pass_level5
pnnx build without onnx-zero support, skip saving onnx
############# pass_ncnn
fallback batch axis 233 for operand 0
fallback batch axis 233 for operand 1
fallback batch axis 233 for operand 2
fallback batch axis 233 for operand 3
fallback batch axis 233 for operand 4
fallback batch axis 233 for operand 5
fallback batch axis 233 for operand 6
fallback batch axis 233 for operand 7
fallback batch axis 233 for operand 8
fallback batch axis 233 for operand 9
fallback batch axis 233 for operand 11
fallback batch axis 233 for operand 12
fallback batch axis 233 for operand 13
fallback batch axis 233 for operand 14
fallback batch axis 233 for operand 15
fallback batch axis 233 for operand pnnx_expr_8_mul(9,1.250000e-01)
fallback batch axis 233 for operand pnnx_expr_8_add(mul(9,1.250000e-01),1)
permute 5-rank tensor is not supported yet!
It looks like this is because the fusion didn't happen?
5 answers
oyt4ldly1#
The output 5 here isn't used anywhere?
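In the match pattern, torch.unbind op_3 produces operands 5, 6 and 7, but torch.matmul op_5 consumes operand 4 (the pre-unbind tensor), so operand 5, i.e. q, is indeed left dangling. If the intent is q.matmul(k^T) as in the Python model, that line presumably needs to read something like:

torch.matmul op_5 2 1 5 8 9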
svujldwt2#
"The output 5 here isn't used anywhere?"
Sorry, I changed it a bit:
saving it like this in Python
is still not right.
3wabscal3#
How should I go about debugging this so that this kind of operator gets fused correctly?
inn6fuwd4#
After some debugging I found that this works:
Thanks a lot!
I wonder whether this can be supported later on,
or whether I could open an MR for it.
rlcwz9us5#
The reason for doing this is that, from what I can see, all the attention masks in the current fuse_multiheadattention.cpp are attributes.
In principle the attention mask should be an input, because the mask may be different for every inference.
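For illustration, the two approaches differ only in how the mask shows up in the match pattern; the attribute form below mirrors the convention of the existing passes (operator names are illustrative):

mask baked into the graph as a constant attribute (the existing passes):
pnnx.Attribute attn_mask 0 1 attn_mask @data

mask fed in as a runtime input (the pattern proposed in this thread):
pnnx.Input input_1 0 1 attn_mask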