Pytorch中的分类-预测时间融合Transformer

qgelzfjb  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(116)

我正在实现一个TFT模型,遇到了这个表:https://pytorch-forecasting.readthedocs.io/en/stable/models.html
它指出,TFT模型可以用于分类任务,这对我来说似乎不直观,因为它用于时间序列预测,这通常是一个回归任务。
我脑子里有两个问题:
1.你怎么看:使用TFT模型进行分类是否有意义?
1.我使用BCEWithLogitsLoss作为损失函数来实现它,并设置pos_weights参数以更高地加权正标签,因为零值膨胀数据集:

positives = np.sum(train_data['fridge'].values == 1)
negatives = np.sum(train_data['fridge'].values == 0)
positive_weight = torch.tensor(negatives/positives, dtype=torch.float)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LEARNING_RATE,
    lstm_layers=2,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.2,
    hidden_continuous_size=8,
    output_size=1,
    loss=convert_torchmetric_to_pytorch_forecasting_metric(
                                                           torch.nn.BCEWithLogitsLoss(
                                                           pos_weight=positive_weight)),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

字符串
然而,现在它预测的是负值,因为TFT模型使用ReLu作为激活函数,我不能改变它,所以类似于sigmoid。你知道如何克服这个问题并从TFT模型中获得可用的分类吗?

a0zr77ik

a0zr77ik1#

1-我注意到PyTorch Forecasting支持TFT的分类,但最初,TFT模型是为回归而设计的,而不是分类。一般来说,您应该将模型用于其预期目的。
2-PyTorch Forecasting提供了一个如何实现自定义分类模型的示例:
https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/building.html#Classification的网站。
您可以对TFT模型执行相同的操作:
https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/models/temporal_fusion_transformer.html#TemporalFusionTransformer的网站。
另外,确保为两个类设置output_size=2

相关问题