Keras KeyError: 96 when using TimeseriesGenerator

Asked by vdgimpew on 2022-11-13, posted in Other

I have a dataset with 2 input columns and 42 output columns. Here is my code:

data_columns=["value","average"]
prediction_columns=[]
for i in range(43):
    prediction_columns.append("s"+str(i))

from keras.preprocessing.sequence import TimeseriesGenerator

n_input = 24*4 #how many samples/rows/timesteps to look in the past in order to forecast the next sample
n_features= len(prediction_columns)#X_train.shape[1] # how many predictors/Xs/features we have to predict y
b_size = 7 # Number of timeseries samples in each batch
generator = TimeseriesGenerator(columns[data_columns], columns[prediction_columns], length=n_input, batch_size=b_size)

print(generator[0][0].shape)

This fails with the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /usr/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
   3620 try:
-> 3621     return self._engine.get_loc(casted_key)
   3622 except KeyError as err:

File /usr/lib/python3.10/site-packages/pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()

File /usr/lib/python3.10/site-packages/pandas/_libs/index.pyx:163, in pandas._libs.index.IndexEngine.get_loc()

File pandas/_libs/hashtable_class_helper.pxi:5198, in pandas._libs.hashtable.PyObjectHashTable.get_item()

File pandas/_libs/hashtable_class_helper.pxi:5206, in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 96

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Input In [38], in <cell line: 8>()
      5 b_size = 7 # Number of timeseries samples in each batch
      6 generator = TimeseriesGenerator(columns[data_columns], columns[prediction_columns], length=n_input, batch_size=b_size)
----> 8 print(generator[0][0].shape)

File /usr/lib/python3.10/site-packages/keras/preprocessing/sequence.py:176, in __getitem__(self, index)
    172     rows = np.random.randint(
    173         self.start_index, self.end_index + 1, size=self.batch_size
    174     )
    175 else:
--> 176     i = self.start_index + self.batch_size * self.stride * index
    177     rows = np.arange(
    178         i,
    179         min(i + self.batch_size * self.stride, self.end_index + 1),
    180         self.stride,
    181     )
    183 samples = np.array(
    184     [
    185         self.data[row - self.length : row : self.sampling_rate]
    186         for row in rows
    187     ]
    188 )

File /usr/lib/python3.10/site-packages/keras/preprocessing/sequence.py:176, in <listcomp>(.0)
    172     rows = np.random.randint(
    173         self.start_index, self.end_index + 1, size=self.batch_size
    174     )
    175 else:
--> 176     i = self.start_index + self.batch_size * self.stride * index
    177     rows = np.arange(
    178         i,
    179         min(i + self.batch_size * self.stride, self.end_index + 1),
    180         self.stride,
    181     )
    183 samples = np.array(
    184     [
    185         self.data[row - self.length : row : self.sampling_rate]
    186         for row in rows
    187     ]
    188 )

File /usr/lib/python3.10/site-packages/pandas/core/frame.py:3505, in DataFrame.__getitem__(self, key)
   3503 if self.columns.nlevels > 1:
   3504     return self._getitem_multilevel(key)
-> 3505 indexer = self.columns.get_loc(key)
   3506 if is_integer(indexer):
   3507     indexer = [indexer]

File /usr/lib/python3.10/site-packages/pandas/core/indexes/base.py:3623, in Index.get_loc(self, key, method, tolerance)
   3621     return self._engine.get_loc(casted_key)
   3622 except KeyError as err:
-> 3623     raise KeyError(key) from err
   3624 except TypeError:
   3625     # If we have a listlike key, _check_indexing_error will raise
   3626     #  InvalidIndexError. Otherwise we fall through and re-raise
   3627     #  the TypeError.
   3628     self._check_indexing_error(key)

KeyError: 96

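I thought I might have passed invalid column names, but both columns[data_columns].head() and columns[prediction_columns].head() run without errors.
columns.shape is (42749, 45).
What could be the root cause of the problem?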

Answer #1 (col17t5w)

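The problem goes away after converting the DataFrames to NumPy arrays. TimeseriesGenerator indexes its inputs with plain integers, and on a DataFrame an integer key such as 96 is looked up as a column label rather than a row position, which is consistent with the KeyError: 96 raised from DataFrame.__getitem__ in the traceback. In the line below, work appears to be the same DataFrame that the question calls columns: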

generator = TimeseriesGenerator(work[data_columns].to_numpy(), work[prediction_columns].to_numpy(), length=n_input, batch_size=b_size)
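
To see why the conversion helps, here is a minimal sketch that uses only pandas and NumPy, not Keras. It assumes the generator reads each target with a plain integer index (targets[row]), in the same way the list comprehension in the traceback slices self.data. The small random DataFrame and the names targets_df, targets_np, data_np and rows are hypothetical stand-ins for illustration, not the asker's data or the library's internals.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's DataFrame: 2 input columns plus
# s0..s42, with far fewer rows than the real (42749, 45) frame.
data_columns = ["value", "average"]
prediction_columns = ["s" + str(i) for i in range(43)]
rng = np.random.default_rng(0)
columns = pd.DataFrame(
    rng.random((200, len(data_columns) + len(prediction_columns))),
    columns=data_columns + prediction_columns,
)

n_input = 24 * 4  # window length, 96
b_size = 7        # samples per batch
targets_df = columns[prediction_columns]

# On a DataFrame, an integer key is treated as a column label. There is no
# column labelled 96, so pandas raises KeyError.
try:
    targets_df[n_input]
    raised = False
except KeyError:
    raised = True

# On a NumPy array the same expression selects row 96.
targets_np = targets_df.to_numpy()
row_96 = targets_np[n_input]

# Emulate the first batch by hand (assumption: start index equals the window
# length): each sample is the window of rows [row - n_input, row).
data_np = columns[data_columns].to_numpy()
rows = np.arange(n_input, n_input + b_size)
samples = np.array([data_np[r - n_input : r] for r in rows])
targets = np.array([targets_np[r] for r in rows])

print(raised)
print(samples.shape, targets.shape)
```

With these stand-in sizes the emulated batch has shape (7, 96, 2) for the samples and (7, 43) for the targets. Calling .to_numpy() (or .values) on both arguments before building the generator is what makes integer keys mean row positions.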
