Подтвердить что ты не робот

(Python - sklearn) Как передать параметры для настройки класса ModelTransformer с помощью gridsearchcv

Ниже мой конвейер, и кажется, что я не могу передать параметры моим моделям с помощью класса ModelTransformer, который я беру из ссылки (http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html)

Сообщение об ошибке имеет смысл для меня, но я не знаю, как это исправить. Любая идея, как это исправить? Спасибо.

# define a pipeline
pipeline = Pipeline([
('vect', DictVectorizer(sparse=False)),
('scale', preprocessing.MinMaxScaler()),
('ess', FeatureUnion(n_jobs=-1, 
                     transformer_list=[
     ('rfc', ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100))),
     ('svc', ModelTransformer(SVC(random_state=1))),],
                     transformer_weights=None)),
('es', EnsembleClassifier1()),
])

# define the parameters for the pipeline
parameters = {
'ess__rfc__n_estimators': (100, 200),
}

# ModelTransformer class. It takes it from the link
(http://zacstewart.com/2014/08/05/pipelines-of-featureunions-of-pipelines.html)
class ModelTransformer(TransformerMixin):
    def __init__(self, model):
        self.model = model
    def fit(self, *args, **kwargs):
        self.model.fit(*args, **kwargs)
        return self
    def transform(self, X, **transform_params):
        return DataFrame(self.model.predict(X))

grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, refit=True)

Сообщение об ошибке: ValueError: недопустимый параметр n_estimators для оценщика ModelTransformer.

4b9b3361

Ответ 1

GridSearchCV имеет специальное соглашение об именах для вложенных объектов. В вашем случае ess__rfc__n_estimators означает ess.rfc.n_estimators и, согласно определению pipeline, он указывает на свойство n_estimators of

ModelTransformer(RandomForestClassifier(n_jobs=-1, random_state=1,  n_estimators=100)))

Очевидно, что экземпляры ModelTransformer не имеют такого свойства.

Исправить легко: для доступа к базовому объекту ModelTransformer необходимо использовать поле model. Таким образом, параметры сетки становятся

parameters = {
  'ess__rfc__model__n_estimators': (100, 200),
}

P.S. Это не единственная проблема с вашим кодом. Чтобы использовать несколько заданий в GridSearchCV, вам нужно сделать все объекты, которые вы используете для копирования. Это достигается с помощью методов get_params и set_params, вы можете брать их из BaseEstimator mixin.