diff --git a/detector/README.md b/detector/README.md
new file mode 100644
index 0000000..503a97a
--- /dev/null
+++ b/detector/README.md
@@ -0,0 +1,66 @@
+GPT-2 Output Detector
+=====================
+
+This directory contains the code for working with the GPT-2 output detector model, obtained by fine-tuning a
+[RoBERTa model](https://ai.facebook.com/blog/roberta-an-optimized-method-for-pretraining-self-supervised-nlp-systems/)
+with [the outputs of the 1.5B-parameter GPT-2 model](https://github.com/openai/gpt-2-output-dataset).
+For motivations and discussions regarding the release of this detector model, please check out
+[our blog post](https://openai.com/blog/gpt-2-6-month-follow-up/) and [report](https://arxiv.org/abs/1908.09203).
+
+## Downloading a pre-trained detector model
+
+Download the weights for the fine-tuned `roberta-base` model (478 MB):
+
+```bash
+wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-base.pt
+```
+
+or `roberta-large` model (1.5 GB):
+
+```bash
+wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-large.pt
+```
+
+These RoBERTa-based models were fine-tuned on a mixture of temperature-1 and nucleus sampling outputs,
+and should therefore generalize well to outputs generated using other sampling methods.
+
+## Running a detector model
+
+You can launch a web UI where you can enter text and see the detector model's prediction
+of whether or not it was generated by a GPT-2 model.
+
+```bash
+# (from the top-level directory of this repository)
+pip install -r requirements.txt
+python -m detector.server detector-base.pt
+```
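+
+Once the server is running, you can also query it directly: the server treats the URL query string as
+the input text and responds with JSON. For example, assuming the default port 8080:
+
+```bash
+curl "http://localhost:8080/?This%20is%20a%20test."
+# {"all_tokens": ..., "used_tokens": ..., "real_probability": ..., "fake_probability": ...}
+```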
+
+## Training a new detector model
+
+You can use the provided training script to train a detector model on a new combination of datasets.
+We recommend using a GPU machine for this task.
+
+```bash
+# (from the top-level directory of this repository)
+pip install -r requirements.txt
+python -m detector.train
+```
+
+The training script supports a number of different options; append `--help` to the command above for usage.
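+
+For example, the following trains a `roberta-base` detector on nucleus samples from GPT-2 XL with
+randomized training sequence lengths; all flags shown are defined in `detector/train.py`, and the
+values are only illustrative:
+
+```bash
+python -m detector.train --batch-size 8 --max-sequence-length 128 --random-sequence-length \
+    --real-dataset webtext --fake-dataset xl-1542M-nucleus
+```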
diff --git a/detector/dataset.py b/detector/dataset.py
new file mode 100644
index 0000000..e5a739a
--- /dev/null
+++ b/detector/dataset.py
@@ -0,0 +1,90 @@
+import json
+import numpy as np
+from typing import List
+
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import PreTrainedTokenizer
+
+from .download import download
+
+
+def load_texts(data_file, expected_size=None):
+ texts = []
+
+    with open(data_file) as f:
+        for line in tqdm(f, total=expected_size, desc=f'Loading {data_file}'):
+            texts.append(json.loads(line)['text'])
+
+ return texts
+
+
+class Corpus:
+ def __init__(self, name, data_dir='data', skip_train=False):
+ download(name, data_dir=data_dir)
+ self.name = name
+ self.train = load_texts(f'{data_dir}/{name}.train.jsonl', expected_size=250000) if not skip_train else None
+ self.test = load_texts(f'{data_dir}/{name}.test.jsonl', expected_size=5000)
+ self.valid = load_texts(f'{data_dir}/{name}.valid.jsonl', expected_size=5000)
+
+
+class EncodedDataset(Dataset):
+ def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: PreTrainedTokenizer,
+ max_sequence_length: int = None, min_sequence_length: int = None, epoch_size: int = None,
+ token_dropout: float = None, seed: int = None):
+ self.real_texts = real_texts
+ self.fake_texts = fake_texts
+ self.tokenizer = tokenizer
+ self.max_sequence_length = max_sequence_length
+ self.min_sequence_length = min_sequence_length
+ self.epoch_size = epoch_size
+ self.token_dropout = token_dropout
+ self.random = np.random.RandomState(seed)
+
+ def __len__(self):
+ return self.epoch_size or len(self.real_texts) + len(self.fake_texts)
+
+ def __getitem__(self, index):
+ if self.epoch_size is not None:
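+            # balanced sampling: pick real or fake uniformly at random, then a random text from that class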
+ label = self.random.randint(2)
+ texts = [self.fake_texts, self.real_texts][label]
+ text = texts[self.random.randint(len(texts))]
+ else:
+ if index < len(self.real_texts):
+ text = self.real_texts[index]
+ label = 1
+ else:
+ text = self.fake_texts[index - len(self.real_texts)]
+ label = 0
+
+ tokens = self.tokenizer.encode(text)
+
+ if self.max_sequence_length is None:
+ tokens = tokens[:self.tokenizer.max_len - 2]
+ else:
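+            # crop to a window whose length (and start position) may be randomized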
+ output_length = min(len(tokens), self.max_sequence_length)
+ if self.min_sequence_length:
+ output_length = self.random.randint(min(self.min_sequence_length, len(tokens)), output_length + 1)
+ start_index = 0 if len(tokens) <= output_length else self.random.randint(0, len(tokens) - output_length + 1)
+ end_index = start_index + output_length
+ tokens = tokens[start_index:end_index]
+
+ if self.token_dropout:
+            dropout_mask = self.random.binomial(1, self.token_dropout, len(tokens)).astype(bool)
+ tokens = np.array(tokens)
+ tokens[dropout_mask] = self.tokenizer.unk_token_id
+ tokens = tokens.tolist()
+
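+        # fast path: no padding needed, so the attention mask is all ones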
+ if self.max_sequence_length is None or len(tokens) == self.max_sequence_length:
+ mask = torch.ones(len(tokens) + 2)
+ return torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]), mask, label
+
+ padding = [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(tokens))
+ tokens = torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] + padding)
+ mask = torch.ones(tokens.shape[0])
+ mask[-len(padding):] = 0
+ return tokens, mask, label
diff --git a/detector/download.py b/detector/download.py
new file mode 100644
index 0000000..3d4f931
--- /dev/null
+++ b/detector/download.py
@@ -0,0 +1,50 @@
+import os
+
+import requests
+import torch.distributed as dist
+from tqdm import tqdm
+
+from .utils import distributed
+
+ALL_DATASETS = [
+ 'webtext',
+ 'small-117M', 'small-117M-k40', 'small-117M-nucleus',
+ 'medium-345M', 'medium-345M-k40', 'medium-345M-nucleus',
+ 'large-762M', 'large-762M-k40', 'large-762M-nucleus',
+ 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus'
+]
+
+
+def download(*datasets, data_dir='data'):
+ os.makedirs(data_dir, exist_ok=True)
+
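+    # in distributed runs, ranks > 0 wait here so that rank 0 downloads each file exactly once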
+ if distributed() and dist.get_rank() > 0:
+ dist.barrier()
+
+ for ds in datasets:
+ assert ds in ALL_DATASETS, f'Unknown dataset {ds}'
+
+ for split in ['train', 'valid', 'test']:
+ filename = ds + "." + split + '.jsonl'
+ output_file = os.path.join(data_dir, filename)
+ if os.path.isfile(output_file):
+ continue
+
+ r = requests.get("https://storage.googleapis.com/gpt-2/output-dataset/v1/" + filename, stream=True)
+
+ with open(output_file, 'wb') as f:
+ file_size = int(r.headers["content-length"])
+            # download in ~1 KB chunks, close to the typical Ethernet packet size of ~1500 bytes
+            chunk_size = 1000
+            with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
+                for chunk in r.iter_content(chunk_size=chunk_size):
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
+ if distributed() and dist.get_rank() == 0:
+ dist.barrier()
+
+
+if __name__ == '__main__':
+ download(*ALL_DATASETS)
diff --git a/detector/index.html b/detector/index.html
new file mode 100644
index 0000000..f9cef0c
--- /dev/null
+++ b/detector/index.html
@@ -0,0 +1,45 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<title>GPT-2 Output Detector</title>
+</head>
+<body>
+<h1>GPT-2 Output Detector Demo</h1>
+<p>
+  This is an online demo of the
+  <a href="https://github.com/openai/gpt-2-output-dataset">GPT-2 output detector</a>
+  model. Enter some text in the text box; the predicted probabilities will be displayed below.
+  The results start to get reliable after around 50 tokens.
+</p>
+<textarea id="textbox" rows="10" cols="80" placeholder="Enter text here"></textarea>
+<table>
+  <tr>
+    <td>Real</td>
+    <td id="message"></td>
+    <td>Fake</td>
+  </tr>
+  <tr>
+    <td id="real-percentage"></td>
+    <td id="bar"></td>
+    <td id="fake-percentage"></td>
+  </tr>
+</table>
+<script>
+// The server (detector/server.py) treats the URL query string as the input text and
+// responds with JSON containing the real/fake probabilities and token counts.
+const textbox = document.getElementById('textbox');
+textbox.addEventListener('input', () => {
+  const text = textbox.value;
+  if (!text) return;
+  fetch('/?' + encodeURIComponent(text))
+    .then(response => response.json())
+    .then(data => {
+      document.getElementById('real-percentage').textContent = (100 * data.real_probability).toFixed(2) + '%';
+      document.getElementById('fake-percentage').textContent = (100 * data.fake_probability).toFixed(2) + '%';
+      document.getElementById('message').textContent = data.used_tokens + ' / ' + data.all_tokens + ' tokens used';
+    });
+});
+</script>
+</body>
+</html>
diff --git a/detector/server.py b/detector/server.py
new file mode 100644
index 0000000..90cdcf8
--- /dev/null
+++ b/detector/server.py
@@ -0,0 +1,124 @@
+import json
+import os
+import subprocess
+import sys
+from http.server import HTTPServer, SimpleHTTPRequestHandler
+from multiprocessing import Process
+from urllib.parse import urlparse, unquote
+
+import fire
+import torch
+from transformers import RobertaForSequenceClassification, RobertaTokenizer
+
+
+model: RobertaForSequenceClassification = None
+tokenizer: RobertaTokenizer = None
+device: str = None
+
+def log(*args):
+ print(f"[{os.environ.get('RANK', '')}]", *args, file=sys.stderr)
+
+
+class RequestHandler(SimpleHTTPRequestHandler):
+
+ def do_GET(self):
+ query = unquote(urlparse(self.path).query)
+
+ if not query:
+ self.begin_content('text/html')
+
+ html = os.path.join(os.path.dirname(__file__), 'index.html')
+            with open(html) as f:
+                self.wfile.write(f.read().encode())
+ return
+
+ self.begin_content('application/json;charset=UTF-8')
+
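+        # tokenize the query, leaving room for the BOS and EOS tokens within the model's input limit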
+ tokens = tokenizer.encode(query)
+ all_tokens = len(tokens)
+ tokens = tokens[:tokenizer.max_len - 2]
+ used_tokens = len(tokens)
+ tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
+ mask = torch.ones_like(tokens)
+
+ with torch.no_grad():
+ logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
+ probs = logits.softmax(dim=-1)
+
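+        # output index 0 is the fake class and index 1 the real class (matching the labels in detector/dataset.py)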
+ fake, real = probs.detach().cpu().flatten().numpy().tolist()
+
+ self.wfile.write(json.dumps(dict(
+ all_tokens=all_tokens,
+ used_tokens=used_tokens,
+ real_probability=real,
+ fake_probability=fake
+ )).encode())
+
+ def begin_content(self, content_type):
+ self.send_response(200)
+ self.send_header('Content-Type', content_type)
+ self.send_header('Access-Control-Allow-Origin', '*')
+ self.end_headers()
+
+ def log_message(self, format, *args):
+ log(format % args)
+
+
+def serve_forever(server, model, tokenizer, device):
+ log('Process has started; loading the model ...')
+ globals()['model'] = model.to(device)
+ globals()['tokenizer'] = tokenizer
+ globals()['device'] = device
+
+ log('Ready to serve')
+ server.serve_forever()
+
+
+def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ if checkpoint.startswith('gs://'):
+ print(f'Downloading {checkpoint}', file=sys.stderr)
+ subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
+ checkpoint = os.path.basename(checkpoint)
+ assert os.path.isfile(checkpoint)
+
+ print(f'Loading checkpoint from {checkpoint}')
+ data = torch.load(checkpoint, map_location='cpu')
+
+ model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
+ model = RobertaForSequenceClassification.from_pretrained(model_name)
+ tokenizer = RobertaTokenizer.from_pretrained(model_name)
+
+ model.load_state_dict(data['model_state_dict'])
+ model.eval()
+
+ print(f'Starting HTTP server on port {port}', file=sys.stderr)
+ server = HTTPServer(('0.0.0.0', port), RequestHandler)
+
+ # avoid calling CUDA API before forking; doing so in a subprocess is fine.
+ num_workers = int(subprocess.check_output(['python', '-c', 'import torch; print(torch.cuda.device_count())']))
+
+ if num_workers <= 1:
+ serve_forever(server, model, tokenizer, device)
+ else:
+ print(f'Launching {num_workers} worker processes...')
+
+ subprocesses = []
+
+ for i in range(num_workers):
+ os.environ['RANK'] = f'{i}'
+ os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
+ process = Process(target=serve_forever, args=(server, model, tokenizer, device))
+ process.start()
+ subprocesses.append(process)
+
+ del os.environ['RANK']
+ del os.environ['CUDA_VISIBLE_DEVICES']
+
+ for process in subprocesses:
+ process.join()
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
diff --git a/detector/train.py b/detector/train.py
new file mode 100644
index 0000000..ff8d79a
--- /dev/null
+++ b/detector/train.py
@@ -0,0 +1,307 @@
+"""Training code for the detector model"""
+
+import argparse
+import os
+import subprocess
+import sys
+from itertools import count
+from multiprocessing import Process
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import Adam
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from tqdm import tqdm
+from transformers import RobertaForSequenceClassification, RobertaTokenizer, tokenization_utils
+
+from .dataset import Corpus, EncodedDataset
+from .download import download
+from .utils import summary, distributed
+
+
+def setup_distributed(port=29500):
+ if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
+ return 0, 1
+
+ if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
+ from mpi4py import MPI
+ mpi_rank = MPI.COMM_WORLD.Get_rank()
+ mpi_size = MPI.COMM_WORLD.Get_size()
+
+ os.environ["MASTER_ADDR"] = '127.0.0.1'
+ os.environ["MASTER_PORT"] = str(port)
+
+ dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
+ return mpi_rank, mpi_size
+
+ dist.init_process_group(backend="nccl", init_method="env://")
+ return dist.get_rank(), dist.get_world_size()
+
+
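+# fake_dataset may also be 'TWO' (temperature-1 + nucleus samples from GPT-2 XL) or 'THREE' (those plus top-k 40)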
+def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
+ max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
+ if fake_dataset == 'TWO':
+ download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
+ elif fake_dataset == 'THREE':
+ download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
+ else:
+ download(real_dataset, fake_dataset, data_dir=data_dir)
+
+ real_corpus = Corpus(real_dataset, data_dir=data_dir)
+
+ if fake_dataset == "TWO":
+ real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
+ fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
+ fake_train = sum([corpus.train for corpus in fake_corpora], [])
+ fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
+ elif fake_dataset == "THREE":
+ real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
+ fake_corpora = [Corpus(name, data_dir=data_dir) for name in
+ ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
+ fake_train = sum([corpus.train for corpus in fake_corpora], [])
+ fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
+ else:
+ fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
+
+ real_train, real_valid = real_corpus.train, real_corpus.valid
+ fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
+
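+    # shard the data across ranks in distributed runs; otherwise sample randomly on the single process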
+ Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler
+
+ min_sequence_length = 10 if random_sequence_length else None
+ train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
+ epoch_size, token_dropout, seed)
+ train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)
+
+ validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
+ validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))
+
+ return train_loader, validation_loader
+
+
+def accuracy_sum(logits, labels):
+ if list(logits.shape) == list(labels.shape) + [2]:
+        # two logits per example: predict the class with the larger logit
+ classification = (logits[..., 0] < logits[..., 1]).long().flatten()
+ else:
+ classification = (logits > 0).long().flatten()
+ assert classification.shape == labels.shape
+ return (classification == labels).float().sum().item()
+
+
+def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
+ model.train()
+
+ train_accuracy = 0
+ train_epoch_size = 0
+ train_loss = 0
+
+ with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
+ for texts, masks, labels in loop:
+
+ texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
+ batch_size = texts.shape[0]
+
+ optimizer.zero_grad()
+ loss, logits = model(texts, attention_mask=masks, labels=labels)
+ loss.backward()
+ optimizer.step()
+
+ batch_accuracy = accuracy_sum(logits, labels)
+ train_accuracy += batch_accuracy
+ train_epoch_size += batch_size
+ train_loss += loss.item() * batch_size
+
+ loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)
+
+ return {
+ "train/accuracy": train_accuracy,
+ "train/epoch_size": train_epoch_size,
+ "train/loss": train_loss
+ }
+
+
+def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
+ model.eval()
+
+ validation_accuracy = 0
+ validation_epoch_size = 0
+ validation_loss = 0
+
+    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
+                                                               disable=distributed() and dist.get_rank() > 0)]
+ records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]
+
+ with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
+ for example in loop:
+ losses = []
+ logit_votes = []
+
+ for texts, masks, labels in example:
+ texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
+ batch_size = texts.shape[0]
+
+ loss, logits = model(texts, attention_mask=masks, labels=labels)
+ losses.append(loss)
+ logit_votes.append(logits)
+
+ loss = torch.stack(losses).mean(dim=0)
+ logits = torch.stack(logit_votes).mean(dim=0)
+
+ batch_accuracy = accuracy_sum(logits, labels)
+ validation_accuracy += batch_accuracy
+ validation_epoch_size += batch_size
+ validation_loss += loss.item() * batch_size
+
+ loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)
+
+ return {
+ "validation/accuracy": validation_accuracy,
+ "validation/epoch_size": validation_epoch_size,
+ "validation/loss": validation_loss
+ }
+
+
+def _all_reduce_dict(d, device):
+    # sum each metric across all ranks via all_reduce (each value is wrapped in a tensor first)
+ output_d = {}
+ for (key, value) in sorted(d.items()):
+ tensor_input = torch.tensor([[value]]).to(device)
+ torch.distributed.all_reduce(tensor_input)
+ output_d[key] = tensor_input.item()
+ return output_d
+
+
+def run(max_epochs=None,
+ device=None,
+ batch_size=24,
+ max_sequence_length=128,
+ random_sequence_length=False,
+ epoch_size=None,
+ seed=None,
+ data_dir='data',
+ real_dataset='webtext',
+ fake_dataset='xl-1542M-nucleus',
+ token_dropout=None,
+ large=False,
+ learning_rate=2e-5,
+ weight_decay=0,
+ **kwargs):
+ args = locals()
+ rank, world_size = setup_distributed()
+
+ if device is None:
+ device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'
+
+ print('rank:', rank, 'world_size:', world_size, 'device:', device)
+
+    # let rank 0 download the datasets and pretrained weights first; other ranks wait at the barrier
+    if distributed() and rank > 0:
+ dist.barrier()
+
+ model_name = 'roberta-large' if large else 'roberta-base'
+ tokenization_utils.logger.setLevel('ERROR')
+ tokenizer = RobertaTokenizer.from_pretrained(model_name)
+ model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)
+
+ if rank == 0:
+ summary(model)
+ if distributed():
+ dist.barrier()
+
+ if world_size > 1:
+ model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)
+
+ train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
+ max_sequence_length, random_sequence_length, epoch_size,
+ token_dropout, seed)
+
+ optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+ epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)
+
+ logdir = os.environ.get("OPENAI_LOGDIR", "logs")
+ os.makedirs(logdir, exist_ok=True)
+
+ from torch.utils.tensorboard import SummaryWriter
+ writer = SummaryWriter(logdir) if rank == 0 else None
+ best_validation_accuracy = 0
+
+ for epoch in epoch_loop:
+ if world_size > 1:
+ train_loader.sampler.set_epoch(epoch)
+ validation_loader.sampler.set_epoch(epoch)
+
+ train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
+ validation_metrics = validate(model, device, validation_loader)
+
+ combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)
+
+ combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
+ combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
+ combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
+ combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]
+
+ if rank == 0:
+ for key, value in combined_metrics.items():
+ writer.add_scalar(key, value, global_step=epoch)
+
+ if combined_metrics["validation/accuracy"] > best_validation_accuracy:
+ best_validation_accuracy = combined_metrics["validation/accuracy"]
+
+ model_to_save = model.module if hasattr(model, 'module') else model
+ torch.save(dict(
+ epoch=epoch,
+ model_state_dict=model_to_save.state_dict(),
+ optimizer_state_dict=optimizer.state_dict(),
+ args=args
+ ),
+ os.path.join(logdir, "best-model.pt")
+ )
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--max-epochs', type=int, default=None)
+ parser.add_argument('--device', type=str, default=None)
+ parser.add_argument('--batch-size', type=int, default=24)
+ parser.add_argument('--max-sequence-length', type=int, default=128)
+ parser.add_argument('--random-sequence-length', action='store_true')
+ parser.add_argument('--epoch-size', type=int, default=None)
+ parser.add_argument('--seed', type=int, default=None)
+ parser.add_argument('--data-dir', type=str, default='data')
+ parser.add_argument('--real-dataset', type=str, default='webtext')
+ parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40')
+ parser.add_argument('--token-dropout', type=float, default=None)
+
+ parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
+ parser.add_argument('--learning-rate', type=float, default=2e-5)
+ parser.add_argument('--weight-decay', type=float, default=0)
+ args = parser.parse_args()
+
+ nproc = int(subprocess.check_output(['python', '-c', "import torch;"
+ "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
+ if nproc > 1:
+ print(f'Launching {nproc} processes ...', file=sys.stderr)
+
+ os.environ["MASTER_ADDR"] = '127.0.0.1'
+ os.environ["MASTER_PORT"] = str(29500)
+ os.environ['WORLD_SIZE'] = str(nproc)
+        os.environ['OMP_NUM_THREADS'] = str(1)
+ subprocesses = []
+
+ for i in range(nproc):
+ os.environ['RANK'] = str(i)
+ os.environ['LOCAL_RANK'] = str(i)
+ process = Process(target=run, kwargs=vars(args))
+ process.start()
+ subprocesses.append(process)
+
+ for process in subprocesses:
+ process.join()
+ else:
+ run(**vars(args))
diff --git a/detector/utils.py b/detector/utils.py
new file mode 100644
index 0000000..7b636d6
--- /dev/null
+++ b/detector/utils.py
@@ -0,0 +1,62 @@
+import sys
+from functools import reduce
+
+from torch import nn
+import torch.distributed as dist
+
+
+def summary(model: nn.Module, file=sys.stdout):
+ def repr(model):
+ # We treat the extra repr like the sub-module, one item per line
+ extra_lines = []
+ extra_repr = model.extra_repr()
+ # empty string will be split into list ['']
+ if extra_repr:
+ extra_lines = extra_repr.split('\n')
+ child_lines = []
+ total_params = 0
+ for key, module in model._modules.items():
+ mod_str, num_params = repr(module)
+ mod_str = nn.modules.module._addindent(mod_str, 2)
+ child_lines.append('(' + key + '): ' + mod_str)
+ total_params += num_params
+ lines = extra_lines + child_lines
+
+ for name, p in model._parameters.items():
+ if hasattr(p, 'shape'):
+                total_params += reduce(lambda x, y: x * y, p.shape, 1)
+
+ main_str = model._get_name() + '('
+ if lines:
+ # simple one-liner info, which most builtin Modules will use
+ if len(extra_lines) == 1 and not child_lines:
+ main_str += extra_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(lines) + '\n'
+
+ main_str += ')'
+ if file is sys.stdout:
+ main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+ else:
+ main_str += ', {:,} params'.format(total_params)
+ return main_str, total_params
+
+ string, count = repr(model)
+ if file is not None:
+ if isinstance(file, str):
+ file = open(file, 'w')
+ print(string, file=file)
+ file.flush()
+
+ return count
+
+
+def grad_norm(model: nn.Module):
+ total_norm = 0
+ for p in model.parameters():
+ param_norm = p.grad.data.norm(2)
+ total_norm += param_norm.item() ** 2
+ return total_norm ** 0.5
+
+def distributed():
+ return dist.is_available() and dist.is_initialized()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..af666b1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+transformers>=2.0.0
+fire>=0.2.1
+requests>=2.22.0
+tqdm>=4.32.2
+torch>=1.2.0
+tensorboard>=1.14.0