I'm trying to tune the hyperparameters of a CNN with grid search:
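# Imports assumed (the traceback shows the scikeras wrapper is in use)
from sklearn.model_selection import GridSearchCV
from scikeras.wrappers import KerasClassifier
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

def create_model():
    model = Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(178, 268, 1)))
    ...
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['acc'])
    return model

# Pass the function itself; create_model() would hand over an already-built model
model = KerasClassifier(build_fn=create_model, verbose=1)
epochs = [10, 20, 30]
batch_size = [40, 60, 80, 100]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=3,
                    error_score="raise", cv=3, scoring="accuracy")
results = grid.fit(x_train, y_train)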
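I keep getting this error and I can't tell what I'm doing wrong. I have to pass both x and y, but the error says the function takes only one positional argument: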
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [57], in <cell line: 1>()
----> 1 results = grid.fit(x_train, y_train)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/sklearn/model_selection/_search.py:875, in BaseSearchCV.fit(self, X, y, groups, **fit_params)
869 results = self._format_results(
870 all_candidate_params, n_splits, all_out, all_more_results
871 )
873 return results
--> 875 self._run_search(evaluate_candidates)
...
-> 1389 evaluate_candidates(ParameterGrid(self.param_grid))
...
BaseSearchCV.fit.<locals>.evaluate_candidates(candidate_params, cv, more_results)
...
--> 822 out = parallel(
823 delayed(_fit_and_score)(
824 clone(base_estimator),
825 X,
826 y,
827 train=train,
828 test=test,
829 parameters=parameters,
830 split_progress=(split_idx, n_splits),
831 candidate_progress=(cand_idx, n_candidates),
832 **fit_and_score_kwargs,
833 )
...
--> 1098 self.retrieve()
...
--> 975 self._output.extend(job.get(timeout=self.timeout))
976 else:
977 self._output.extend(job.get())
...
--> 567 return future.result(timeout=timeout)
...
--> 439 return self.__get_result()
...
--> 391 raise self._exception
392 finally:
393 # Break a reference cycle with the exception in self._exception
394 self = None
TypeError: TargetReshaper.transform() takes 1 positional argument but 2 were given
Can anyone point out what's wrong? Thanks.
Answer
Check your scikeras version. The latest release, 0.10.0, is not compatible with Python 3.11.3, so you may need to downgrade your Python installation.
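One way to do that check from the same environment that runs the grid search is sketched below. It is a minimal sketch using only the standard library (`importlib.metadata.version`); the helper name `installed_versions` and the choice of packages to report are illustrative assumptions, not part of scikeras.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("scikeras", "scikit-learn", "tensorflow")):
    """Hypothetical helper: map each package to its installed version (None if missing)."""
    result = {"python": sys.version.split()[0]}  # e.g. "3.11.3"
    for name in packages:
        try:
            result[name] = version(name)  # version string from the installed distribution metadata
        except PackageNotFoundError:
            result[name] = None  # package is not installed in this environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Run it with the same interpreter (or notebook kernel) that raised the error, so the reported versions are the ones GridSearchCV actually used.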