update to gridsearch

This commit is contained in:
Alec 2019-05-03 11:48:41 -07:00
parent 284240f15d
commit 8101eff63e

View File

@ -5,8 +5,8 @@ import fire
import numpy as np
from scipy import sparse
from sklearn.model_selection import PredefinedSplit
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import PredefinedSplit, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
def _load_split(data_dir, source, split, n=np.inf):
@ -35,10 +35,13 @@ def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000
valid_features = vect.transform(valid_texts)
test_features = vect.transform(test_texts)
Cs = [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64]
model = LogisticRegression(solver='liblinear')
params = {'C': [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64]}
split = PredefinedSplit([-1]*n_train+[0]*n_valid)
model = LogisticRegressionCV(Cs=Cs, cv=split, solver='liblinear', n_jobs=n_jobs, verbose=verbose, refit=False)
model.fit(sparse.vstack([train_features, valid_features]), train_labels+valid_labels)
search = GridSearchCV(model, params, cv=split, n_jobs=n_jobs, verbose=verbose, refit=False)
search.fit(sparse.vstack([train_features, valid_features]), train_labels+valid_labels)
model = model.set_params(**search.best_params_)
model.fit(train_features, train_labels)
valid_accuracy = model.score(valid_features, valid_labels)*100.
test_accuracy = model.score(test_features, test_labels)*100.
data = {