Fix n_valid default (1000 -> 10000); also change n_jobs default from -1 to None

This commit is contained in:
Alec 2019-05-03 11:32:28 -07:00
parent 32c71904ad
commit 284240f15d

View File

@@ -25,7 +25,7 @@ def load_split(data_dir, source, split, n=np.inf):
labels = [0]*len(webtext)+[1]*len(gen)
return texts, labels
-def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=1000, n_jobs=-1, verbose=False):
+def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000, n_jobs=None, verbose=False):
train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train)
valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid)
test_texts, test_labels = load_split(data_dir, source, 'test')