如何使用pytorch多头关注进行分类任务?

acruukt9  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(106)

我有一个数据集,其中x形状是(10000,102,300),例如(samples,feature-length,dimension),y(10000,)是我的二进制标签。我想使用PyTorch来使用多头关注。我看过here的PyTorch文档,但没有关于如何使用它的说明。我如何使用我的数据集来使用多头关注进行分类?

vc9ivgsu

vc9ivgsu1#

我将为分类编写一个简单而漂亮的代码,这将很好地工作,如果您需要实现详细信息,则此部分与Transformer中的Encoder层相同,但在最后一部分中,您需要GlobalAveragePooling层和Dense层来进行分类

attention_layer = nn.MultiHeadAttion(300 , 300%num_of_heads==0,dropout=0.1)
neural_net_output = point_wise_neural_network(attention_layer)
normalize = LayerNormalization(input + neural_net_output)
globale_average_pooling = nn.GlobalAveragePooling(normalize)
nn.Linear(input , num_of_classes)(global_average_pooling)

相关问题