From 6f40009b8942b42b1d1a52e91eb72e2c131acada Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Tue, 5 Nov 2019 07:32:08 +0100
Subject: [PATCH] detector model

---
 detector/README.md   |  49 +++++++
 detector/dataset.py  |  86 ++++++++++++
 detector/download.py |  49 +++++++
 detector/index.html  | 154 ++++++++++++++++++++++
 detector/server.py   | 120 +++++++++++++++++
 detector/train.py    | 305 +++++++++++++++++++++++++++++++++++++++++++
 detector/utils.py    |  62 +++++++++
 requirements.txt     |   6 +
 8 files changed, 831 insertions(+)
 create mode 100644 detector/README.md
 create mode 100644 detector/dataset.py
 create mode 100644 detector/download.py
 create mode 100644 detector/index.html
 create mode 100644 detector/server.py
 create mode 100644 detector/train.py
 create mode 100644 detector/utils.py
 create mode 100644 requirements.txt

diff --git a/detector/README.md b/detector/README.md
new file mode 100644
index 0000000..503a97a
--- /dev/null
+++ b/detector/README.md
@@ -0,0 +1,49 @@
+GPT-2 Output Detector
+=====================
+
+This directory contains the code for working with the GPT-2 output detector model, obtained by fine-tuning a
+[RoBERTa model](https://ai.facebook.com/blog/roberta-an-optimized-method-for-pretraining-self-supervised-nlp-systems/)
+with [the outputs of the 1.5B-parameter GPT-2 model](https://github.com/openai/gpt-2-output-dataset).
+For motivations and discussions regarding the release of this detector model, please check out
+[our blog post](https://openai.com/blog/gpt-2-6-month-follow-up/) and [report](https://arxiv.org/abs/1908.09203).
+
+## Downloading a pre-trained detector model
+
+Download the weights for the fine-tuned `roberta-base` model (478 MB):
+
+```bash
+wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-base.pt
+```
+
+or the `roberta-large` model (1.5 GB):
+
+```bash
+wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-large.pt
+```
+
+These RoBERTa-based models are fine-tuned with a mixture of temperature-1 and nucleus sampling outputs,
+which should generalize well to outputs generated using different sampling methods.
+
+## Running a detector model
+
+You can launch a web UI in which you can enter text and see the detector model's prediction
+on whether or not it was generated by a GPT-2 model.
+
+```bash
+# (in the top-level directory of this repository)
+pip install -r requirements.txt
+python -m detector.server detector-base.pt
+```
+
+## Training a new detector model
+
+You can use the provided training script to train a detector model on a new set of datasets.
+We recommend using a GPU machine for this task.
+
+```bash
+# (in the top-level directory of this repository)
+pip install -r requirements.txt
+python -m detector.train
+```
+
+The training script supports a number of different options; append `--help` to the command above for usage.
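The server also exposes the prediction as JSON, so a running detector can be queried from the command line. A minimal example based on the request handler in `detector/server.py` below — any non-empty, URL-encoded query string is treated as the input text, and 8080 is the server's default port:

```bash
curl -s "http://localhost:8080/?URL-encoded%20text%20to%20classify"
# responds with JSON of the form:
# {"all_tokens": ..., "used_tokens": ..., "real_probability": ..., "fake_probability": ...}
```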
diff --git a/detector/dataset.py b/detector/dataset.py
new file mode 100644
index 0000000..e5a739a
--- /dev/null
+++ b/detector/dataset.py
@@ -0,0 +1,86 @@
+import json
+import numpy as np
+from typing import List
+
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import PreTrainedTokenizer
+
+from .download import download
+
+
+def load_texts(data_file, expected_size=None):
+    texts = []
+
+    for line in tqdm(open(data_file), total=expected_size, desc=f'Loading {data_file}'):
+        texts.append(json.loads(line)['text'])
+
+    return texts
+
+
+class Corpus:
+    def __init__(self, name, data_dir='data', skip_train=False):
+        download(name, data_dir=data_dir)
+        self.name = name
+        self.train = load_texts(f'{data_dir}/{name}.train.jsonl', expected_size=250000) if not skip_train else None
+        self.test = load_texts(f'{data_dir}/{name}.test.jsonl', expected_size=5000)
+        self.valid = load_texts(f'{data_dir}/{name}.valid.jsonl', expected_size=5000)
+
+
+class EncodedDataset(Dataset):
+    def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: PreTrainedTokenizer,
+                 max_sequence_length: int = None, min_sequence_length: int = None, epoch_size: int = None,
+                 token_dropout: float = None, seed: int = None):
+        self.real_texts = real_texts
+        self.fake_texts = fake_texts
+        self.tokenizer = tokenizer
+        self.max_sequence_length = max_sequence_length
+        self.min_sequence_length = min_sequence_length
+        self.epoch_size = epoch_size
+        self.token_dropout = token_dropout
+        self.random = np.random.RandomState(seed)
+
+    def __len__(self):
+        return self.epoch_size or len(self.real_texts) + len(self.fake_texts)
+
+    def __getitem__(self, index):
+        if self.epoch_size is not None:
+            # sample a random example each time; label 1 = real, 0 = fake
+            label = self.random.randint(2)
+            texts = [self.fake_texts, self.real_texts][label]
+            text = texts[self.random.randint(len(texts))]
+        else:
+            if index < len(self.real_texts):
+                text = self.real_texts[index]
+                label = 1
+            else:
+                text = self.fake_texts[index - len(self.real_texts)]
+                label = 0
+
+        tokens = self.tokenizer.encode(text)
+
+        if self.max_sequence_length is None:
+            tokens = tokens[:self.tokenizer.max_len - 2]
+        else:
+            output_length = min(len(tokens), self.max_sequence_length)
+            if self.min_sequence_length:
+                output_length = self.random.randint(min(self.min_sequence_length, len(tokens)), output_length + 1)
+            start_index = 0 if len(tokens) <= output_length else self.random.randint(0, len(tokens) - output_length + 1)
+            end_index = start_index + output_length
+            tokens = tokens[start_index:end_index]
+
+        if self.token_dropout:
+            dropout_mask = self.random.binomial(1, self.token_dropout, len(tokens)).astype(bool)
+            tokens = np.array(tokens)
+            tokens[dropout_mask] = self.tokenizer.unk_token_id
+            tokens = tokens.tolist()
+
+        if self.max_sequence_length is None or len(tokens) == self.max_sequence_length:
+            mask = torch.ones(len(tokens) + 2)
+            return torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]), mask, label
+
+        padding = [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(tokens))
+        tokens = torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] + padding)
+        mask = torch.ones(tokens.shape[0])
+        mask[-len(padding):] = 0
+        return tokens, mask, label
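As a reference for how the pieces in `detector/dataset.py` fit together, here is a minimal sketch that drives `EncodedDataset` directly; the two single-item text lists are stand-ins for the real corpora that `Corpus` would load:

```python
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer

from detector.dataset import EncodedDataset

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
real_texts = ['A human-written paragraph ...']    # stand-in for e.g. Corpus('webtext').train
fake_texts = ['A model-generated paragraph ...']  # stand-in for e.g. Corpus('xl-1542M-nucleus').train

dataset = EncodedDataset(real_texts, fake_texts, tokenizer, max_sequence_length=128)
loader = DataLoader(dataset, batch_size=2)

for tokens, mask, label in loader:
    # tokens: <s> ... </s> plus padding; mask: 1 on real tokens, 0 on padding; label: 1 = real, 0 = fake
    print(tokens.shape, mask.shape, label)
```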
diff --git a/detector/download.py b/detector/download.py
new file mode 100644
index 0000000..3d4f931
--- /dev/null
+++ b/detector/download.py
@@ -0,0 +1,49 @@
+import os
+
+import requests
+import torch.distributed as dist
+from tqdm import tqdm
+
+from .utils import distributed
+
+ALL_DATASETS = [
+    'webtext',
+    'small-117M', 'small-117M-k40', 'small-117M-nucleus',
+    'medium-345M', 'medium-345M-k40', 'medium-345M-nucleus',
+    'large-762M', 'large-762M-k40', 'large-762M-nucleus',
+    'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus'
+]
+
+
+def download(*datasets, data_dir='data'):
+    os.makedirs(data_dir, exist_ok=True)
+
+    # in distributed runs, only rank 0 downloads; the other ranks wait at the barrier
+    if distributed() and dist.get_rank() > 0:
+        dist.barrier()
+
+    for ds in datasets:
+        assert ds in ALL_DATASETS, f'Unknown dataset {ds}'
+
+        for split in ['train', 'valid', 'test']:
+            filename = ds + "." + split + '.jsonl'
+            output_file = os.path.join(data_dir, filename)
+            if os.path.isfile(output_file):
+                continue
+
+            r = requests.get("https://storage.googleapis.com/gpt-2/output-dataset/v1/" + filename, stream=True)
+
+            with open(output_file, 'wb') as f:
+                file_size = int(r.headers["content-length"])
+                chunk_size = 1000
+                with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
+                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
+                    for chunk in r.iter_content(chunk_size=chunk_size):
+                        f.write(chunk)
+                        pbar.update(len(chunk))
+
+    if distributed() and dist.get_rank() == 0:
+        dist.barrier()
+
+
+if __name__ == '__main__':
+    download(*ALL_DATASETS)
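Since `download()` accepts any subset of `ALL_DATASETS`, the data for a specific experiment can be fetched ahead of time; for instance, the pair used by the default training configuration:

```python
from detector.download import download

# fetches webtext.{train,valid,test}.jsonl and xl-1542M-nucleus.{train,valid,test}.jsonl into ./data
download('webtext', 'xl-1542M-nucleus', data_dir='data')
```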
diff --git a/detector/index.html b/detector/index.html
new file mode 100644
index 0000000..f9cef0c
--- /dev/null
+++ b/detector/index.html
@@ -0,0 +1,154 @@
[154 lines of HTML/CSS/JS for the demo page; the markup did not survive extraction. The visible text it renders:
the title "GPT-2 Output Detector", the heading "GPT-2 Output Detector Demo", the description "This is an online
demo of the GPT-2 output detector model. Enter some text in the text box; the predicted probabilities will be
displayed below. The results start to get reliable after around 50 tokens.", and a probability readout labeled
"Real" and "Fake".]
diff --git a/detector/server.py b/detector/server.py
new file mode 100644
index 0000000..90cdcf8
--- /dev/null
+++ b/detector/server.py
@@ -0,0 +1,120 @@
+import os
+import sys
+from http.server import HTTPServer, SimpleHTTPRequestHandler
+from multiprocessing import Process
+import subprocess
+from transformers import RobertaForSequenceClassification, RobertaTokenizer
+import json
+import fire
+import torch
+from urllib.parse import urlparse, unquote
+
+
+model: RobertaForSequenceClassification = None
+tokenizer: RobertaTokenizer = None
+device: str = None
+
+
+def log(*args):
+    print(f"[{os.environ.get('RANK', '')}]", *args, file=sys.stderr)
+
+
+class RequestHandler(SimpleHTTPRequestHandler):
+
+    def do_GET(self):
+        query = unquote(urlparse(self.path).query)
+
+        if not query:
+            self.begin_content('text/html')
+
+            html = os.path.join(os.path.dirname(__file__), 'index.html')
+            self.wfile.write(open(html).read().encode())
+            return
+
+        self.begin_content('application/json;charset=UTF-8')
+
+        tokens = tokenizer.encode(query)
+        all_tokens = len(tokens)
+        tokens = tokens[:tokenizer.max_len - 2]
+        used_tokens = len(tokens)
+        tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
+        mask = torch.ones_like(tokens)
+
+        with torch.no_grad():
+            logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
+            probs = logits.softmax(dim=-1)
+
+        fake, real = probs.detach().cpu().flatten().numpy().tolist()
+
+        self.wfile.write(json.dumps(dict(
+            all_tokens=all_tokens,
+            used_tokens=used_tokens,
+            real_probability=real,
+            fake_probability=fake
+        )).encode())
+
+    def begin_content(self, content_type):
+        self.send_response(200)
+        self.send_header('Content-Type', content_type)
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.end_headers()
+
+    def log_message(self, format, *args):
+        log(format % args)
+
+
+def serve_forever(server, model, tokenizer, device):
+    log('Process has started; loading the model ...')
+    globals()['model'] = model.to(device)
+    globals()['tokenizer'] = tokenizer
+    globals()['device'] = device
+
+    log('Ready to serve')
+    server.serve_forever()
+
+
+def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
+    if checkpoint.startswith('gs://'):
+        print(f'Downloading {checkpoint}', file=sys.stderr)
+        subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
+        checkpoint = os.path.basename(checkpoint)
+        assert os.path.isfile(checkpoint)
+
+    print(f'Loading checkpoint from {checkpoint}', file=sys.stderr)
+    data = torch.load(checkpoint, map_location='cpu')
+
+    model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
+    model = RobertaForSequenceClassification.from_pretrained(model_name)
+    tokenizer = RobertaTokenizer.from_pretrained(model_name)
+
+    model.load_state_dict(data['model_state_dict'])
+    model.eval()
+
+    print(f'Starting HTTP server on port {port}', file=sys.stderr)
+    server = HTTPServer(('0.0.0.0', port), RequestHandler)
+
+    # avoid calling CUDA API before forking; doing so in a subprocess is fine.
+    num_workers = int(subprocess.check_output(['python', '-c', 'import torch; print(torch.cuda.device_count())']))
+
+    if num_workers <= 1:
+        serve_forever(server, model, tokenizer, device)
+    else:
+        print(f'Launching {num_workers} worker processes...')
+
+        subprocesses = []
+
+        for i in range(num_workers):
+            os.environ['RANK'] = f'{i}'
+            os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
+            process = Process(target=serve_forever, args=(server, model, tokenizer, device))
+            process.start()
+            subprocesses.append(process)
+
+        del os.environ['RANK']
+        del os.environ['CUDA_VISIBLE_DEVICES']
+
+        for process in subprocesses:
+            process.join()
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
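The checkpoint format saved by `detector/train.py` (below) can also be scored offline, without the HTTP server. A minimal sketch that follows the same loading and inference steps as `main()` and `do_GET()` above, assuming the `detector-base.pt` checkpoint from the README:

```python
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer

data = torch.load('detector-base.pt', map_location='cpu')
model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
model = RobertaForSequenceClassification.from_pretrained(model_name)
model.load_state_dict(data['model_state_dict'])
model.eval()
tokenizer = RobertaTokenizer.from_pretrained(model_name)

text = 'Some text to classify'
tokens = tokenizer.encode(text)[:tokenizer.max_len - 2]
tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
with torch.no_grad():
    logits = model(tokens, attention_mask=torch.ones_like(tokens))[0]
fake, real = logits.softmax(dim=-1).flatten().tolist()  # class 0 = fake, class 1 = real
print(f'real: {real:.3f}  fake: {fake:.3f}')
```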
diff --git a/detector/train.py b/detector/train.py
new file mode 100644
index 0000000..ff8d79a
--- /dev/null
+++ b/detector/train.py
@@ -0,0 +1,305 @@
+"""Training code for the detector model"""
+
+import argparse
+import os
+import subprocess
+import sys
+from itertools import count
+from multiprocessing import Process
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import Adam
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from tqdm import tqdm
+from transformers import RobertaForSequenceClassification, RobertaTokenizer, tokenization_utils
+
+from .dataset import Corpus, EncodedDataset
+from .download import download
+from .utils import summary, distributed
+
+
+def setup_distributed(port=29500):
+    if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
+        return 0, 1
+
+    if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
+        from mpi4py import MPI
+        mpi_rank = MPI.COMM_WORLD.Get_rank()
+        mpi_size = MPI.COMM_WORLD.Get_size()
+
+        os.environ["MASTER_ADDR"] = '127.0.0.1'
+        os.environ["MASTER_PORT"] = str(port)
+
+        dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
+        return mpi_rank, mpi_size
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+    return dist.get_rank(), dist.get_world_size()
+
+
+def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
+                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
+    if fake_dataset == 'TWO':
+        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
+    elif fake_dataset == 'THREE':
+        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
+    else:
+        download(real_dataset, fake_dataset, data_dir=data_dir)
+
+    real_corpus = Corpus(real_dataset, data_dir=data_dir)
+
+    if fake_dataset == "TWO":
+        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
+        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
+        fake_train = sum([corpus.train for corpus in fake_corpora], [])
+        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
+    elif fake_dataset == "THREE":
+        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
+        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
+                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
+        fake_train = sum([corpus.train for corpus in fake_corpora], [])
+        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
+    else:
+        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
+
+        real_train, real_valid = real_corpus.train, real_corpus.valid
+        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
+
+    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler
+
+    min_sequence_length = 10 if random_sequence_length else None
+    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
+                                   epoch_size, token_dropout, seed)
+    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)
+
+    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
+    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))
+
+    return train_loader, validation_loader
+
+
+def accuracy_sum(logits, labels):
+    if list(logits.shape) == list(labels.shape) + [2]:
+        # 2-d outputs
+        classification = (logits[..., 0] < logits[..., 1]).long().flatten()
+    else:
+        classification = (logits > 0).long().flatten()
+    assert classification.shape == labels.shape
+    return (classification == labels).float().sum().item()
+
+
+def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
+    model.train()
+
+    train_accuracy = 0
+    train_epoch_size = 0
+    train_loss = 0
+
+    with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
+        for texts, masks, labels in loop:
+
+            texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
+            batch_size = texts.shape[0]
+
+            optimizer.zero_grad()
+            loss, logits = model(texts, attention_mask=masks, labels=labels)
+            loss.backward()
+            optimizer.step()
+
+            batch_accuracy = accuracy_sum(logits, labels)
+            train_accuracy += batch_accuracy
+            train_epoch_size += batch_size
+            train_loss += loss.item() * batch_size
+
+            loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)
+
+    return {
+        "train/accuracy": train_accuracy,
+        "train/epoch_size": train_epoch_size,
+        "train/loss": train_loss
+    }
+
+
+def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
+    model.eval()
+
+    validation_accuracy = 0
+    validation_epoch_size = 0
+    validation_loss = 0
+
+    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
+                                                               disable=distributed() and dist.get_rank() > 0)]
+    records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]
+
+    with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
+        for example in loop:
+            losses = []
+            logit_votes = []
+
+            for texts, masks, labels in example:
+                texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
+                batch_size = texts.shape[0]
+
+                loss, logits = model(texts, attention_mask=masks, labels=labels)
+                losses.append(loss)
+                logit_votes.append(logits)
+
+            loss = torch.stack(losses).mean(dim=0)
+            logits = torch.stack(logit_votes).mean(dim=0)
+
+            batch_accuracy = accuracy_sum(logits, labels)
+            validation_accuracy += batch_accuracy
+            validation_epoch_size += batch_size
+            validation_loss += loss.item() * batch_size
+
+            loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)
+
+    return {
+        "validation/accuracy": validation_accuracy,
+        "validation/epoch_size": validation_epoch_size,
+        "validation/loss": validation_loss
+    }
+
+
+def _all_reduce_dict(d, device):
+    # wrap in tensor and use reduce to gpu0 tensor
+    output_d = {}
+    for (key, value) in sorted(d.items()):
+        tensor_input = torch.tensor([[value]]).to(device)
+        torch.distributed.all_reduce(tensor_input)
+        output_d[key] = tensor_input.item()
+    return output_d
+
+
+def run(max_epochs=None,
+        device=None,
+        batch_size=24,
+        max_sequence_length=128,
+        random_sequence_length=False,
+        epoch_size=None,
+        seed=None,
+        data_dir='data',
+        real_dataset='webtext',
+        fake_dataset='xl-1542M-nucleus',
+        token_dropout=None,
+        large=False,
+        learning_rate=2e-5,
+        weight_decay=0,
+        **kwargs):
+    args = locals()
+    rank, world_size = setup_distributed()
+
+    if device is None:
+        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'
+
+    print('rank:', rank, 'world_size:', world_size, 'device:', device)
+
+    if distributed() and rank > 0:
+        dist.barrier()
+
+    model_name = 'roberta-large' if large else 'roberta-base'
+    tokenization_utils.logger.setLevel('ERROR')
+    tokenizer = RobertaTokenizer.from_pretrained(model_name)
+    model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)
+
+    if rank == 0:
+        summary(model)
+        if distributed():
+            dist.barrier()
+
+    if world_size > 1:
+        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)
+
+    train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
+                                                    max_sequence_length, random_sequence_length, epoch_size,
+                                                    token_dropout, seed)
+
+    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)
+
+    logdir = os.environ.get("OPENAI_LOGDIR", "logs")
+    os.makedirs(logdir, exist_ok=True)
+
+    from torch.utils.tensorboard import SummaryWriter
+    writer = SummaryWriter(logdir) if rank == 0 else None
+    best_validation_accuracy = 0
+
+    for epoch in epoch_loop:
+        if world_size > 1:
+            train_loader.sampler.set_epoch(epoch)
+            validation_loader.sampler.set_epoch(epoch)
+
+        train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
+        validation_metrics = validate(model, device, validation_loader)
+
+        combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)
+
+        combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
+        combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
combined_metrics["train/epoch_size"] + combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"] + combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"] + + if rank == 0: + for key, value in combined_metrics.items(): + writer.add_scalar(key, value, global_step=epoch) + + if combined_metrics["validation/accuracy"] > best_validation_accuracy: + best_validation_accuracy = combined_metrics["validation/accuracy"] + + model_to_save = model.module if hasattr(model, 'module') else model + torch.save(dict( + epoch=epoch, + model_state_dict=model_to_save.state_dict(), + optimizer_state_dict=optimizer.state_dict(), + args=args + ), + os.path.join(logdir, "best-model.pt") + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--max-epochs', type=int, default=None) + parser.add_argument('--device', type=str, default=None) + parser.add_argument('--batch-size', type=int, default=24) + parser.add_argument('--max-sequence-length', type=int, default=128) + parser.add_argument('--random-sequence-length', action='store_true') + parser.add_argument('--epoch-size', type=int, default=None) + parser.add_argument('--seed', type=int, default=None) + parser.add_argument('--data-dir', type=str, default='data') + parser.add_argument('--real-dataset', type=str, default='webtext') + parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40') + parser.add_argument('--token-dropout', type=float, default=None) + + parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base') + parser.add_argument('--learning-rate', type=float, default=2e-5) + parser.add_argument('--weight-decay', type=float, default=0) + args = parser.parse_args() + + nproc = int(subprocess.check_output(['python', '-c', "import torch;" + "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"])) + if nproc > 1: + print(f'Launching {nproc} processes ...', file=sys.stderr) + + os.environ["MASTER_ADDR"] = '127.0.0.1' + os.environ["MASTER_PORT"] = str(29500) + os.environ['WORLD_SIZE'] = str(nproc) + os.environ['OMP_NUM_THREAD'] = str(1) + subprocesses = [] + + for i in range(nproc): + os.environ['RANK'] = str(i) + os.environ['LOCAL_RANK'] = str(i) + process = Process(target=run, kwargs=vars(args)) + process.start() + subprocesses.append(process) + + for process in subprocesses: + process.join() + else: + run(**vars(args)) diff --git a/detector/utils.py b/detector/utils.py new file mode 100644 index 0000000..7b636d6 --- /dev/null +++ b/detector/utils.py @@ -0,0 +1,62 @@ +import sys +from functools import reduce + +from torch import nn +import torch.distributed as dist + + +def summary(model: nn.Module, file=sys.stdout): + def repr(model): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = repr(module) + mod_str = nn.modules.module._addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for name, p in model._parameters.items(): + if hasattr(p, 'shape'): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if 
diff --git a/detector/utils.py b/detector/utils.py
new file mode 100644
index 0000000..7b636d6
--- /dev/null
+++ b/detector/utils.py
@@ -0,0 +1,62 @@
+import sys
+from functools import reduce
+
+from torch import nn
+import torch.distributed as dist
+
+
+def summary(model: nn.Module, file=sys.stdout):
+    def repr(model):
+        # We treat the extra repr like the sub-module, one item per line
+        extra_lines = []
+        extra_repr = model.extra_repr()
+        # empty string will be split into list ['']
+        if extra_repr:
+            extra_lines = extra_repr.split('\n')
+        child_lines = []
+        total_params = 0
+        for key, module in model._modules.items():
+            mod_str, num_params = repr(module)
+            mod_str = nn.modules.module._addindent(mod_str, 2)
+            child_lines.append('(' + key + '): ' + mod_str)
+            total_params += num_params
+        lines = extra_lines + child_lines
+
+        for name, p in model._parameters.items():
+            if hasattr(p, 'shape'):
+                total_params += reduce(lambda x, y: x * y, p.shape)
+
+        main_str = model._get_name() + '('
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+        main_str += ')'
+        if file is sys.stdout:
+            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+        else:
+            main_str += ', {:,} params'.format(total_params)
+        return main_str, total_params
+
+    string, count = repr(model)
+    if file is not None:
+        if isinstance(file, str):
+            file = open(file, 'w')
+        print(string, file=file)
+        file.flush()
+
+    return count
+
+
+def grad_norm(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        param_norm = p.grad.data.norm(2)
+        total_norm += param_norm.item() ** 2
+    return total_norm ** 0.5
+
+
+def distributed():
+    return dist.is_available() and dist.is_initialized()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..af666b1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+transformers>=2.0.0
+fire>=0.2.1
+requests>=2.22.0
+tqdm>=4.32.2
+torch>=1.2.0
+tensorboard>=1.14.0