我正在实现一个TFT模型,遇到了这个表:https://pytorch-forecasting.readthedocs.io/en/stable/models.html
它指出,TFT模型可以用于分类任务,这对我来说似乎不直观,因为它用于时间序列预测,这通常是一个回归任务。
我脑子里有两个问题:
1.你怎么看:使用TFT模型进行分类是否有意义?
1.我使用BCEWithLogitsLoss
作为损失函数来实现它,并设置pos_weights
参数以更高地加权正标签,因为零值膨胀数据集:
positives = np.sum(train_data['fridge'].values == 1)
negatives = np.sum(train_data['fridge'].values == 0)
positive_weight = torch.tensor(negatives/positives, dtype=torch.float)
tft = TemporalFusionTransformer.from_dataset(
training,
learning_rate=LEARNING_RATE,
lstm_layers=2,
hidden_size=16,
attention_head_size=4,
dropout=0.2,
hidden_continuous_size=8,
output_size=1,
loss=convert_torchmetric_to_pytorch_forecasting_metric(
torch.nn.BCEWithLogitsLoss(
pos_weight=positive_weight)),
log_interval=10,
reduce_on_plateau_patience=4,
)
字符串
然而,现在它预测的是负值,因为TFT模型使用ReLu作为激活函数,我不能改变它,所以类似于sigmoid。你知道如何克服这个问题并从TFT模型中获得可用的分类吗?
1条答案
按热度按时间a0zr77ik1#
1-我注意到PyTorch Forecasting支持TFT的分类,但最初,TFT模型是为回归而设计的,而不是分类。一般来说,您应该将模型用于其预期目的。
2-PyTorch Forecasting提供了一个如何实现自定义分类模型的示例:
https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/building.html#Classification的网站。
您可以对TFT模型执行相同的操作:
https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/models/temporal_fusion_transformer.html#TemporalFusionTransformer的网站。
另外,确保为两个类设置
output_size=2
。