From 8101eff63ec9f70ebaa1dbe2bd7b17243e753ef8 Mon Sep 17 00:00:00 2001 From: Alec Date: Fri, 3 May 2019 11:48:41 -0700 Subject: [PATCH] update to gridsearch --- baseline.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/baseline.py b/baseline.py index 32da755..a424aed 100644 --- a/baseline.py +++ b/baseline.py @@ -5,8 +5,8 @@ import fire import numpy as np from scipy import sparse -from sklearn.model_selection import PredefinedSplit -from sklearn.linear_model import LogisticRegressionCV +from sklearn.model_selection import PredefinedSplit, GridSearchCV +from sklearn.linear_model import LogisticRegression from sklearn.feature_extraction.text import TfidfVectorizer def _load_split(data_dir, source, split, n=np.inf): @@ -35,10 +35,13 @@ def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000 valid_features = vect.transform(valid_texts) test_features = vect.transform(test_texts) - Cs = [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64] + model = LogisticRegression(solver='liblinear') + params = {'C': [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64]} split = PredefinedSplit([-1]*n_train+[0]*n_valid) - model = LogisticRegressionCV(Cs=Cs, cv=split, solver='liblinear', n_jobs=n_jobs, verbose=verbose, refit=False) - model.fit(sparse.vstack([train_features, valid_features]), train_labels+valid_labels) + search = GridSearchCV(model, params, cv=split, n_jobs=n_jobs, verbose=verbose, refit=False) + search.fit(sparse.vstack([train_features, valid_features]), train_labels+valid_labels) + model = model.set_params(**search.best_params_) + model.fit(train_features, train_labels) valid_accuracy = model.score(valid_features, valid_labels)*100. test_accuracy = model.score(test_features, test_labels)*100. data = {