mirror of
https://github.com/openai/gpt-2-output-dataset
synced 2025-08-29 13:28:06 +00:00
fix n_valid default
This commit is contained in:
parent
32c71904ad
commit
284240f15d
@ -25,7 +25,7 @@ def load_split(data_dir, source, split, n=np.inf):
|
|||||||
labels = [0]*len(webtext)+[1]*len(gen)
|
labels = [0]*len(webtext)+[1]*len(gen)
|
||||||
return texts, labels
|
return texts, labels
|
||||||
|
|
||||||
def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=1000, n_jobs=-1, verbose=False):
|
def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000, n_jobs=None, verbose=False):
|
||||||
train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train)
|
train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train)
|
||||||
valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid)
|
valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid)
|
||||||
test_texts, test_labels = load_split(data_dir, source, 'test')
|
test_texts, test_labels = load_split(data_dir, source, 'test')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user