diff --git a/baseline.py b/baseline.py index 795af1e..32da755 100644 --- a/baseline.py +++ b/baseline.py @@ -25,7 +25,7 @@ def load_split(data_dir, source, split, n=np.inf): labels = [0]*len(webtext)+[1]*len(gen) return texts, labels -def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=1000, n_jobs=-1, verbose=False): +def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000, n_jobs=None, verbose=False): train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train) valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid) test_texts, test_labels = load_split(data_dir, source, 'test')