tensorflow 如何解决“索引错误:元组索引超出范围”?

lhcgjxsq  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(182)

我尝试在变压器中做一个时间序列图预测。输入大小是(无,30)。但是,这里发生了一个错误。

x = layers.MultiHeadAttention(
      5 key_dim=1, num_heads=1, dropout=dropout
----> 6 )(inputs, inputs)
      7 x = layers.Dropout(dropout)(x)
      8 x = layers.LayerNormalization(epsilon=1e-6)(x)

此处发生错误。IndexError:元组索引超出范围
第一个
我尝试在变压器中做一个时间序列图预测。输入大小是(无,30)。但是,这里发生了一个错误。

dy1byipe

dy1byipe1#

进行以下更改,

X_train = tf.expand_dims(X_train, -1) #change your input
input_shape = X_train.shape[1:] #input shape should change to (30,1)
model_mlp = build_model(
    input_shape,
    head_size=256,
    num_heads=1,
    ff_dim=1,
    num_transformer_blocks=4,
    mlp_units=[128],
    mlp_dropout=0.4,
    dropout=0.25,
)

model_mlp.compile(optimizer = adam, loss = root_mean_squared_error)
model_mlp.summary()

然后检查,

model_mlp(X_train)

相关问题