Keras: detected a call to `Model.fit` inside a `tf.function`; `Model.fit` is a high-level endpoint that manages its own `tf.function`

txu3uszq  posted on 2023-08-06 in Other

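Hello Stack Overflow community,
I ran into a problem while trying to train an LSTM model with Keras in a Google Colaboratory notebook. The goal is to predict certain outages ("moh") for "unit1" from time-series data. However, when I try to fit the model to the data, I get the following error: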

RuntimeError: Detected a call to `Model.fit` inside a `tf.function`. `Model.fit` is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.fit` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.

Here is the code I am using:

# Importing required libraries
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense
from tensorflow.keras.callbacks import EarlyStopping

# Define the LSTM model
def create_lstm_model(input_size, output_size, lstm_layer_sizes, dropout_rates):
    lstm_model = Sequential()
    for size, rate in zip(lstm_layer_sizes, dropout_rates):
        lstm_model.add(LSTM(units=size, return_sequences=True))
        lstm_model.add(Dropout(rate=rate))
    lstm_model.add(Dense(units=output_size))
    return lstm_model

# Set the parameters
input_size = 6
output_size = 3
unit = 'unit1'
outage = 'moh'
lstm_layer_sizes = (64,128,256,128,64)
dropout_rates = (0.05,0.05,0.05,0.05,0.05)

# Prepare the data (omitting data retrieval steps for brevity)
y = kinerja_df_extended_nanremoved_standardized[f'{unit}_{outage}s']
current_dates = kinerja_df_extended_nanremoved_standardized['date']
x = np.array([current_dates[i:i+input_size] for i in range(len(current_dates)-input_size+1)])
y = np.array([y[i:i+output_size] for i in range(len(y)-output_size+1)])

# Instantiate and compile the model
lstm_model = create_lstm_model(input_size=input_size, output_size=output_size, lstm_layer_sizes=lstm_layer_sizes, dropout_rates=dropout_rates)
lstm_model.compile(optimizer='adam', loss='mean_squared_error')

# The following line causes the error
history = lstm_model.fit(x=x, y=y, batch_size=1, epochs=128, validation_split=0.1, shuffle=False)

# Plot the training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()


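I have searched online for a solution, but I have not found one that addresses this specific error in my context. How can I resolve this problem and successfully train my LSTM model?
Any help would be greatly appreciated. Thank you.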


z5btuh9x (answer #1)

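The RuntimeError is raised because `Model.fit` is being called inside a `tf.function`. `Model.fit` is a high-level endpoint that manages its own `tf.function`, so it should be called outside of all enclosing `tf.function`s.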
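As a rough illustration of the rule described above, the sketch below shows the two supported patterns: calling `fit` from plain Python, and calling the model directly on tensors inside a `tf.function`. The toy `Dense` model, the random data, and the `predict_step` name are assumptions made only for this example and are not part of the question's code.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model and random data, standing in for the LSTM in the question.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mean_squared_error")
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

# Supported: fit is called from plain Python, not from inside a tf.function.
history = model.fit(x, y, epochs=1, batch_size=4, verbose=0)

# Also supported: calling the model directly on tensors inside a tf.function,
# which is the `model(x)` pattern the error message points to.
@tf.function
def predict_step(batch):
    return model(batch, training=False)

predictions = predict_step(tf.constant(x))

# Not supported: wrapping the fit call itself in a tf.function. This is the
# situation the RuntimeError in the question describes.
# @tf.function
# def train():
#     model.fit(x, y)
```

If the error still appears even though no `@tf.function` decorator is visible, it may be worth checking whether the cell containing `fit` is itself invoked from a decorated function defined elsewhere in the notebook.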
Try the following:

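# Solution
# Move the call to `Model.fit` outside of all enclosing `tf.function`s
# Imports, same as in the question
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense

# Define the LSTM model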
def create_lstm_model(input_size, output_size, lstm_layer_sizes, dropout_rates):
    lstm_model = Sequential()
    for size, rate in zip(lstm_layer_sizes, dropout_rates):
        lstm_model.add(LSTM(units=size, return_sequences=True))
        lstm_model.add(Dropout(rate=rate))
    lstm_model.add(Dense(units=output_size))
    return lstm_model

# Set the parameters
input_size = 6
output_size = 3
unit = 'unit1'
outage = 'moh'
lstm_layer_sizes = (64,128,256,128,64)
dropout_rates = (0.05,0.05,0.05,0.05,0.05)

# Prepare the data (omitting data retrieval steps for brevity)
y = kinerja_df_extended_nanremoved_standardized[f'{unit}_{outage}s']
current_dates = kinerja_df_extended_nanremoved_standardized['date']
x = np.array([current_dates[i:i+input_size] for i in range(len(current_dates)-input_size+1)])
y = np.array([y[i:i+output_size] for i in range(len(y)-output_size+1)])

# Instantiate and compile the model
lstm_model = create_lstm_model(input_size=input_size, output_size=output_size, lstm_layer_sizes=lstm_layer_sizes, dropout_rates=dropout_rates)
lstm_model.compile(optimizer='adam', loss='mean_squared_error')

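# Train the model at the top level of the script or notebook cell,
# not inside a function decorated with @tf.function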
history = lstm_model.fit(x=x, y=y, batch_size=1, epochs=128, validation_split=0.1, shuffle=False)

# Plot the training and validation loss
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()
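Separately from the `tf.function` error, the data-preparation step is worth a second look: the two list comprehensions slide windows of different lengths over the same series, so `x` and `y` end up with different numbers of samples, while `fit` expects them to match. The sketch below is one possible way to build aligned windows. It assumes the intent is to predict the next `output_size` values from the previous `input_size` values; `make_windows` and the `series` array are hypothetical stand-ins, not names from the original code.

```python
import numpy as np

input_size = 6   # same values as in the question
output_size = 3
series = np.arange(20, dtype="float32")  # hypothetical stand-in for the real column

# Windowing as written in the question: the sample counts differ.
x_q = np.array([series[i:i + input_size] for i in range(len(series) - input_size + 1)])
y_q = np.array([series[i:i + output_size] for i in range(len(series) - output_size + 1)])
print(x_q.shape, y_q.shape)  # (15, 6) (18, 3)

def make_windows(values, n_in, n_out):
    """Pair each input window with the n_out values that follow it (assumed intent)."""
    n_samples = len(values) - n_in - n_out + 1
    x = np.array([values[i:i + n_in] for i in range(n_samples)])
    y = np.array([values[i + n_in:i + n_in + n_out] for i in range(n_samples)])
    return x, y

x, y = make_windows(series, input_size, output_size)
print(x.shape, y.shape)  # (12, 6) (12, 3)

# LSTM layers take 3D input (samples, timesteps, features), so a single
# feature per timestep needs a trailing axis.
x = x[..., np.newaxis]
print(x.shape)  # (12, 6, 1)
```

Whether this pairing matches the real forecasting target depends on the dataset, so treat it as a starting point rather than a drop-in replacement.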
