keras transform() takes 1 positional argument but 2 were given

l7mqbcuq posted on 2023-04-21 in Other

I'm trying to tune the hyperparameters of a CNN with grid search:

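# Imports were not shown in the original post. KerasClassifier is assumed to come
# from scikeras, since TargetReshaper in the traceback is a scikeras class.
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

def create_model():
    model = Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(178, 268, 1)))
    ...
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['acc'])

    return model

# Note: create_model() passes an already-built model; passing create_model (no
# parentheses) would let the wrapper build a fresh model itself.
model = KerasClassifier(build_fn=create_model(), verbose=1)
epochs = [10, 20, 30]
batch_size = [40, 60, 80, 100]
param_grid = dict(batch_size=batch_size, epochs=epochs)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=3,
                    error_score="raise", cv=3, scoring="accuracy")
results = grid.fit(x_train, y_train)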
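
For scale: GridSearchCV tries every combination in param_grid and trains each combination once per CV fold. The sketch below reproduces that expansion in plain Python, mirroring what sklearn.model_selection.ParameterGrid does without depending on scikit-learn. expand_grid is a hypothetical helper written only for illustration.

```python
from itertools import product

# Same grid and fold count as in the question
param_grid = {"batch_size": [40, 60, 80, 100], "epochs": [10, 20, 30]}
cv = 3


def expand_grid(grid: dict) -> list:
    """Hypothetical helper: expand a dict of lists into one dict per combination."""
    keys = sorted(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


candidates = expand_grid(param_grid)
n_fits = len(candidates) * cv  # each candidate is trained once per fold
print(len(candidates), n_fits)  # 12 36
```

Since refit defaults to True in GridSearchCV, one more fit of the best candidate on the full training data follows. The question's error is raised inside one of these 36 cross-validation fits. With error_score="raise", the first failing fit aborts the whole search.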

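I keep getting this error and can't figure out what I'm doing wrong. I need to pass in the x and y variables, but the error says the function takes only one argument: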

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Input In [57], in <cell line: 1>()
----> 1 results = grid.fit(x_train, y_train)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/sklearn/model_selection/_search.py:875, in BaseSearchCV.fit(self, X, y, groups, **fit_params)
    869     results = self._format_results(
    870         all_candidate_params, n_splits, all_out, all_more_results
    871     )
    873     return results
--> 875 self._run_search(evaluate_candidates)
   ...
-> 1389     evaluate_candidates(ParameterGrid(self.param_grid))
   ...
BaseSearchCV.fit.<locals>.evaluate_candidates(candidate_params, cv, more_results)
   ...
--> 822 out = parallel(
    823     delayed(_fit_and_score)(
    824         clone(base_estimator),
    825         X,
    826         y,
    827         train=train,
    828         test=test,
    829         parameters=parameters,
    830         split_progress=(split_idx, n_splits),
    831         candidate_progress=(cand_idx, n_candidates),
    832         **fit_and_score_kwargs,
    833     )
   ...
--> 1098     self.retrieve()
   ...
--> 975         self._output.extend(job.get(timeout=self.timeout))
    976     else:
    977         self._output.extend(job.get())
   ...
--> 567     return future.result(timeout=timeout)
   ...
--> 439     return self.__get_result()
   ...
--> 391         raise self._exception
    392     finally:
    393         # Break a reference cycle with the exception in self._exception
    394         self = None

TypeError: TargetReshaper.transform() takes 1 positional argument but 2 were given
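
The wording of this message is easy to misread because the count includes the implicit self. "1 positional argument" means the method accepts only self. "2 were given" means the method was called with one explicit argument on top of self. The class below is a hypothetical stand-in that reproduces only the message. It is not scikeras's actual TargetReshaper.

```python
class TargetReshaper:
    """Hypothetical stand-in (not the scikeras class): transform() accepts only self."""

    def transform(self):
        return "transformed"


reshaper = TargetReshaper()
message = ""
try:
    # A bound-method call passes self implicitly, so this supplies 2 positional arguments
    reshaper.transform([0, 1, 1])
except TypeError as exc:
    message = str(exc)

print(message)
# On Python 3.10+: TargetReshaper.transform() takes 1 positional argument but 2 were given
```

Since grid.fit(x_train, y_train) is the ordinary call, this suggests the mismatch lies between library components calling each other. It is probably not in how x_train and y_train are passed.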

Can anyone point out what's wrong? Thanks.

Answer #1 (ut6juiuv)

Check your scikeras version. The latest release, 0.10.0, is not compatible with Python 3.11.3, so you may need to downgrade your Python installation.
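
To check the versions the answer refers to, you can ask the interpreter the notebook actually runs on. This avoids a possibly different pip on the PATH. The sketch below uses the standard-library importlib.metadata. It also lists scikit-learn and tensorflow, which is an assumption: they are the other libraries in the question's code, but the answer itself only mentions scikeras and Python.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def report_versions(packages=("scikeras", "scikit-learn", "tensorflow")) -> dict:
    """Collect the running Python version plus the versions of the given packages."""
    info = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for pkg in packages:
        info[pkg] = installed_version(pkg)
    return info


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Running this in a notebook cell shows the environment GridSearchCV will actually use, which may differ from what `pip list` reports in a terminal.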
