detector model

Jong Wook Kim 2019-11-05 07:32:08 +01:00 committed by Jeff Wu
parent cba7812d49
commit 6f40009b89
8 changed files with 831 additions and 0 deletions

detector/README.md Normal file
@@ -0,0 +1,49 @@
GPT-2 Output Detector
=====================

This directory contains the code for working with the GPT-2 output detector model, obtained by fine-tuning a
[RoBERTa model](https://ai.facebook.com/blog/roberta-an-optimized-method-for-pretraining-self-supervised-nlp-systems/)
with [the outputs of the 1.5B-parameter GPT-2 model](https://github.com/openai/gpt-2-output-dataset).
For motivations and discussions regarding the release of this detector model, please check out
[our blog post](https://openai.com/blog/gpt-2-6-month-follow-up/) and [report](https://arxiv.org/abs/1908.09203).
## Downloading a pre-trained detector model

Download the weights for the fine-tuned `roberta-base` model (478 MB):

```bash
wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-base.pt
```

or the fine-tuned `roberta-large` model (1.5 GB):

```bash
wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-large.pt
```

These RoBERTa-based models are fine-tuned with a mixture of temperature-1 and nucleus sampling outputs,
which should generalize well to outputs generated using different sampling methods.
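
For reference, here is a minimal sketch of loading one of these checkpoints in Python; it mirrors what `detector/server.py` does, and the `args` and `model_state_dict` keys come from the checkpoint format saved by `detector/train.py`:

```python
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# the checkpoint stores the training args alongside the fine-tuned weights
data = torch.load('detector-base.pt', map_location='cpu')
model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'

model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model.load_state_dict(data['model_state_dict'])
model.eval()
```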
## Running a detector model

You can launch a web UI where you can enter text and see the detector model's prediction
on whether or not it was generated by a GPT-2 model.

```bash
# (from the top-level directory of this repository)
pip install -r requirements.txt
python -m detector.server detector-base.pt
```
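
Once the server is up, the same endpoint can be queried directly; the text is passed as the URL query string and the response is JSON (a sketch, assuming the default port of 8080):

```bash
# spaces and punctuation must be percent-encoded
curl -s 'http://localhost:8080/?A%20sample%20passage%20to%20classify'
# -> {"all_tokens": ..., "used_tokens": ..., "real_probability": ..., "fake_probability": ...}
```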
## Training a new detector model

You can use the provided training script to train a detector model on a new set of datasets.
We recommend using a GPU machine for this task.

```bash
# (from the top-level directory of this repository)
pip install -r requirements.txt
python -m detector.train
```

The training script supports a number of different options; append `--help` to the command above for usage.
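
For example, a hypothetical run that fine-tunes `roberta-large` against nucleus-sampled GPT-2 XL output (all of these flags are defined in `detector/train.py`):

```bash
python -m detector.train --large --fake-dataset xl-1542M-nucleus --batch-size 16
```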

detector/dataset.py Normal file
@@ -0,0 +1,86 @@
import json
import numpy as np
from typing import List

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .download import download


def load_texts(data_file, expected_size=None):
    texts = []
    for line in tqdm(open(data_file), total=expected_size, desc=f'Loading {data_file}'):
        texts.append(json.loads(line)['text'])
    return texts


class Corpus:
    def __init__(self, name, data_dir='data', skip_train=False):
        download(name, data_dir=data_dir)
        self.name = name
        self.train = load_texts(f'{data_dir}/{name}.train.jsonl', expected_size=250000) if not skip_train else None
        self.test = load_texts(f'{data_dir}/{name}.test.jsonl', expected_size=5000)
        self.valid = load_texts(f'{data_dir}/{name}.valid.jsonl', expected_size=5000)


class EncodedDataset(Dataset):
    def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: PreTrainedTokenizer,
                 max_sequence_length: int = None, min_sequence_length: int = None, epoch_size: int = None,
                 token_dropout: float = None, seed: int = None):
        self.real_texts = real_texts
        self.fake_texts = fake_texts
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.min_sequence_length = min_sequence_length
        self.epoch_size = epoch_size
        self.token_dropout = token_dropout
        self.random = np.random.RandomState(seed)

    def __len__(self):
        return self.epoch_size or len(self.real_texts) + len(self.fake_texts)

    def __getitem__(self, index):
        if self.epoch_size is not None:
            label = self.random.randint(2)
            texts = [self.fake_texts, self.real_texts][label]
            text = texts[self.random.randint(len(texts))]
        else:
            if index < len(self.real_texts):
                text = self.real_texts[index]
                label = 1
            else:
                text = self.fake_texts[index - len(self.real_texts)]
                label = 0

        tokens = self.tokenizer.encode(text)

        if self.max_sequence_length is None:
            tokens = tokens[:self.tokenizer.max_len - 2]
        else:
            output_length = min(len(tokens), self.max_sequence_length)
            if self.min_sequence_length:
                output_length = self.random.randint(min(self.min_sequence_length, len(tokens)), output_length + 1)
            start_index = 0 if len(tokens) <= output_length else self.random.randint(0, len(tokens) - output_length + 1)
            end_index = start_index + output_length
            tokens = tokens[start_index:end_index]

        if self.token_dropout:
            dropout_mask = self.random.binomial(1, self.token_dropout, len(tokens)).astype(bool)
            tokens = np.array(tokens)
            tokens[dropout_mask] = self.tokenizer.unk_token_id
            tokens = tokens.tolist()

        if self.max_sequence_length is None or len(tokens) == self.max_sequence_length:
            mask = torch.ones(len(tokens) + 2)
            return torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]), mask, label

        padding = [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(tokens))
        tokens = torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] + padding)
        mask = torch.ones(tokens.shape[0])
        mask[-len(padding):] = 0
        return tokens, mask, label
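
As a usage sketch (assuming the `webtext` and `xl-1542M-k40` datasets are available under `data/`), the dataset yields `(tokens, mask, label)` triples ready to feed a RoBERTa classifier:

```python
from transformers import RobertaTokenizer

from detector.dataset import Corpus, EncodedDataset

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
real, fake = Corpus('webtext'), Corpus('xl-1542M-k40')

dataset = EncodedDataset(real.train, fake.train, tokenizer, max_sequence_length=128)
tokens, mask, label = dataset[0]  # label is 1 for real (webtext) text, 0 for GPT-2 output
```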

detector/download.py Normal file
@@ -0,0 +1,49 @@
import os

import requests
import torch.distributed as dist
from tqdm import tqdm

from .utils import distributed

ALL_DATASETS = [
    'webtext',
    'small-117M', 'small-117M-k40', 'small-117M-nucleus',
    'medium-345M', 'medium-345M-k40', 'medium-345M-nucleus',
    'large-762M', 'large-762M-k40', 'large-762M-nucleus',
    'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus'
]


def download(*datasets, data_dir='data'):
    os.makedirs(data_dir, exist_ok=True)

    if distributed() and dist.get_rank() > 0:
        dist.barrier()

    for ds in datasets:
        assert ds in ALL_DATASETS, f'Unknown dataset {ds}'

        for split in ['train', 'valid', 'test']:
            filename = ds + "." + split + '.jsonl'
            output_file = os.path.join(data_dir, filename)
            if os.path.isfile(output_file):
                continue

            r = requests.get("https://storage.googleapis.com/gpt-2/output-dataset/v1/" + filename, stream=True)

            with open(output_file, 'wb') as f:
                file_size = int(r.headers["content-length"])
                chunk_size = 1000
                with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        pbar.update(chunk_size)

    if distributed() and dist.get_rank() == 0:
        dist.barrier()


if __name__ == '__main__':
    download(*ALL_DATASETS)
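
Run as a module, the script fetches every dataset in `ALL_DATASETS`; `download()` can also be called with specific names to fetch only what is needed:

```bash
# downloads all train/valid/test splits into ./data (a large download)
python -m detector.download
```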

detector/index.html Normal file
@@ -0,0 +1,154 @@
<!doctype html>
<html>
<head>
    <title>GPT-2 Output Detector</title>
    <style type="text/css">
        * {
            box-sizing: border-box;
        }
        body {
            font-family: sans-serif;
            margin: 0;
        }
        h1 {
            font-weight: lighter;
        }
        a {
            text-decoration: none;
            color: #666;
        }
        a:hover {
            text-decoration: underline;
        }
        #container {
            margin: auto;
            width: 960px;
        }
        #textbox {
            font-family: serif;
            font-size: 16pt;
            width: 100%;
            height: 480px;
            padding: 20px 30px;
            line-height: 1.6;
        }
        .bar-row {
            height: 30px;
        }
        #real-percentage {
            width: 80px;
            vertical-align: top;
        }
        #bar-container {
            width: 800px;
            background-color: #ff7674;
            line-height: 0.5;
            position: relative;
            top: 6px;
        }
        #fake-percentage {
            width: 80px;
            vertical-align: top;
        }
        #bar {
            display: inline-block;
            height: 30px;
            background-color: #83aaff;
        }
        em {
            font-family: monospace;
            font-style: normal;
        }
    </style>
</head>
<body>
<div id="container">
    <h1>GPT-2 Output Detector Demo</h1>
    <p>
        This is an online demo of the
        <a href="https://github.com/openai/gpt-2-output-dataset/tree/master/detector">GPT-2 output detector</a>
        model. Enter some text in the text box; the predicted probabilities will be displayed below.
        <u>The results start to get reliable after around 50 tokens.</u>
    </p>
    <textarea id="textbox" placeholder="Enter text here"></textarea>
    <div><table cellspacing="0" cellpadding="0">
        <tr class="bar-row" style="vertical-align: bottom;">
            <td style="text-align: left;">Real</td>
            <td id="message" style="text-align: center;"></td>
            <td style="text-align: right;">Fake</td>
        </tr>
        <tr class="bar-row">
            <td id="real-percentage" style="text-align: left; vertical-align: bottom;"></td>
            <td id="bar-container"><div id="bar" style="width: 50%;"></div></td>
            <td id="fake-percentage" style="text-align: right; vertical-align: bottom;"></td>
        </tr>
    </table></div>
</div>
<script>
    let textbox = document.getElementById('textbox');
    let last_submit = null;

    let real_percentage = document.getElementById('real-percentage');
    let fake_percentage = document.getElementById('fake-percentage');
    let bar = document.getElementById('bar');
    let message = document.getElementById('message');

    function update_graph(result) {
        if (result === null) {
            real_percentage.innerHTML = '';
            fake_percentage.innerHTML = '';
            bar.style.width = '50%';
            message.innerHTML = '';
        } else {
            let percentage = result.real_probability;
            real_percentage.innerHTML = (100 * percentage).toFixed(2) + '%';
            fake_percentage.innerHTML = (100 * (1 - percentage)).toFixed(2) + '%';
            bar.style.width = (100 * percentage).toFixed(2) + '%';
            if (result.used_tokens === result.all_tokens) {
                message.innerHTML = `Prediction based on ${result.used_tokens} tokens`;
            } else {
                message.innerHTML = `Prediction based on the first ${result.used_tokens} tokens out of the total ${result.all_tokens}`;
            }
        }
    }

    textbox.oninput = () => {
        if (last_submit) {
            clearTimeout(last_submit);
        }
        if (textbox.value.length === 0) {
            update_graph(null);
            return;
        }
        message.innerText = 'Predicting ...';
        last_submit = setTimeout(() => {
            let req = new XMLHttpRequest();
            if (textbox.value.length === 0) {
                update_graph(null);
                return;
            }
            // URL-encode the text so characters like '&' and '#' survive the query string
            req.open('GET', '/?' + encodeURIComponent(textbox.value), true);
            req.onreadystatechange = () => {
                if (req.readyState !== 4) return;
                if (req.status !== 200) throw new Error("HTTP status: " + req.status);
                let result = JSON.parse(req.responseText);
                update_graph(result);
            };
            req.send();
        }, 1000);
    };

    window.addEventListener('DOMContentLoaded', () => {
        textbox.focus();
    });
</script>
</body>
</html>

detector/server.py Normal file
@@ -0,0 +1,120 @@
import os
import sys
from http.server import HTTPServer, SimpleHTTPRequestHandler
from multiprocessing import Process
import subprocess

from transformers import RobertaForSequenceClassification, RobertaTokenizer
import json
import fire
import torch
from urllib.parse import urlparse, unquote

model: RobertaForSequenceClassification = None
tokenizer: RobertaTokenizer = None
device: str = None


def log(*args):
    print(f"[{os.environ.get('RANK', '')}]", *args, file=sys.stderr)


class RequestHandler(SimpleHTTPRequestHandler):
    def do_GET(self):
        query = unquote(urlparse(self.path).query)

        if not query:
            self.begin_content('text/html')

            html = os.path.join(os.path.dirname(__file__), 'index.html')
            self.wfile.write(open(html).read().encode())
            return

        self.begin_content('application/json;charset=UTF-8')

        tokens = tokenizer.encode(query)
        all_tokens = len(tokens)
        tokens = tokens[:tokenizer.max_len - 2]
        used_tokens = len(tokens)
        tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
        mask = torch.ones_like(tokens)

        with torch.no_grad():
            logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
            probs = logits.softmax(dim=-1)

        fake, real = probs.detach().cpu().flatten().numpy().tolist()

        self.wfile.write(json.dumps(dict(
            all_tokens=all_tokens,
            used_tokens=used_tokens,
            real_probability=real,
            fake_probability=fake
        )).encode())

    def begin_content(self, content_type):
        self.send_response(200)
        self.send_header('Content-Type', content_type)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()

    def log_message(self, format, *args):
        log(format % args)


def serve_forever(server, model, tokenizer, device):
    log('Process has started; loading the model ...')
    globals()['model'] = model.to(device)
    globals()['tokenizer'] = tokenizer
    globals()['device'] = device

    log('Ready to serve')
    server.serve_forever()


def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if checkpoint.startswith('gs://'):
        print(f'Downloading {checkpoint}', file=sys.stderr)
        subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
        checkpoint = os.path.basename(checkpoint)
        assert os.path.isfile(checkpoint)

    print(f'Loading checkpoint from {checkpoint}')
    data = torch.load(checkpoint, map_location='cpu')

    model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    model.load_state_dict(data['model_state_dict'])
    model.eval()

    print(f'Starting HTTP server on port {port}', file=sys.stderr)
    server = HTTPServer(('0.0.0.0', port), RequestHandler)

    # avoid calling CUDA API before forking; doing so in a subprocess is fine.
    num_workers = int(subprocess.check_output(['python', '-c', 'import torch; print(torch.cuda.device_count())']))

    if num_workers <= 1:
        serve_forever(server, model, tokenizer, device)
    else:
        print(f'Launching {num_workers} worker processes...')

        subprocesses = []

        for i in range(num_workers):
            os.environ['RANK'] = f'{i}'
            os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
            process = Process(target=serve_forever, args=(server, model, tokenizer, device))
            process.start()
            subprocesses.append(process)

        del os.environ['RANK']
        del os.environ['CUDA_VISIBLE_DEVICES']

        for process in subprocesses:
            process.join()


if __name__ == '__main__':
    fire.Fire(main)
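
A minimal Python client for this endpoint (a sketch, assuming the server is running locally on the default port):

```python
import requests
from urllib.parse import quote

text = "The quick brown fox jumps over the lazy dog."
# the handler treats the entire (percent-decoded) query string as the input text
result = requests.get('http://localhost:8080/?' + quote(text)).json()
print(result['real_probability'], result['fake_probability'])
```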

detector/train.py Normal file
@@ -0,0 +1,305 @@
"""Training code for the detector model"""
import argparse
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
from transformers import *
from .dataset import Corpus, EncodedDataset
from .download import download
from .utils import summary, distributed
def setup_distributed(port=29500):
if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
return 0, 1
if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
from mpi4py import MPI
mpi_rank = MPI.COMM_WORLD.Get_rank()
mpi_size = MPI.COMM_WORLD.Get_size()
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
return mpi_rank, mpi_size
dist.init_process_group(backend="nccl", init_method="env://")
return dist.get_rank(), dist.get_world_size()
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
if fake_dataset == 'TWO':
download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
elif fake_dataset == 'THREE':
download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
else:
download(real_dataset, fake_dataset, data_dir=data_dir)
real_corpus = Corpus(real_dataset, data_dir=data_dir)
if fake_dataset == "TWO":
real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
fake_train = sum([corpus.train for corpus in fake_corpora], [])
fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
elif fake_dataset == "THREE":
real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
fake_corpora = [Corpus(name, data_dir=data_dir) for name in
['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
fake_train = sum([corpus.train for corpus in fake_corpora], [])
fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
else:
fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
real_train, real_valid = real_corpus.train, real_corpus.valid
fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler
min_sequence_length = 10 if random_sequence_length else None
train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
epoch_size, token_dropout, seed)
train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)
validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))
return train_loader, validation_loader
def accuracy_sum(logits, labels):
if list(logits.shape) == list(labels.shape) + [2]:
# 2-d outputs
classification = (logits[..., 0] < logits[..., 1]).long().flatten()
else:
classification = (logits > 0).long().flatten()
assert classification.shape == labels.shape
return (classification == labels).float().sum().item()
def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
model.train()
train_accuracy = 0
train_epoch_size = 0
train_loss = 0
with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
for texts, masks, labels in loop:
texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
batch_size = texts.shape[0]
optimizer.zero_grad()
loss, logits = model(texts, attention_mask=masks, labels=labels)
loss.backward()
optimizer.step()
batch_accuracy = accuracy_sum(logits, labels)
train_accuracy += batch_accuracy
train_epoch_size += batch_size
train_loss += loss.item() * batch_size
loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)
return {
"train/accuracy": train_accuracy,
"train/epoch_size": train_epoch_size,
"train/loss": train_loss
}
def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
model.eval()
validation_accuracy = 0
validation_epoch_size = 0
validation_loss = 0
records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
disable=dist.is_available() and dist.get_rank() > 0)]
records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]
with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
for example in loop:
losses = []
logit_votes = []
for texts, masks, labels in example:
texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
batch_size = texts.shape[0]
loss, logits = model(texts, attention_mask=masks, labels=labels)
losses.append(loss)
logit_votes.append(logits)
loss = torch.stack(losses).mean(dim=0)
logits = torch.stack(logit_votes).mean(dim=0)
batch_accuracy = accuracy_sum(logits, labels)
validation_accuracy += batch_accuracy
validation_epoch_size += batch_size
validation_loss += loss.item() * batch_size
loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)
return {
"validation/accuracy": validation_accuracy,
"validation/epoch_size": validation_epoch_size,
"validation/loss": validation_loss
}
def _all_reduce_dict(d, device):
# wrap in tensor and use reduce to gpu0 tensor
output_d = {}
for (key, value) in sorted(d.items()):
tensor_input = torch.tensor([[value]]).to(device)
torch.distributed.all_reduce(tensor_input)
output_d[key] = tensor_input.item()
return output_d
def run(max_epochs=None,
device=None,
batch_size=24,
max_sequence_length=128,
random_sequence_length=False,
epoch_size=None,
seed=None,
data_dir='data',
real_dataset='webtext',
fake_dataset='xl-1542M-nucleus',
token_dropout=None,
large=False,
learning_rate=2e-5,
weight_decay=0,
**kwargs):
args = locals()
rank, world_size = setup_distributed()
if device is None:
device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'
print('rank:', rank, 'world_size:', world_size, 'device:', device)
import torch.distributed as dist
if distributed() and rank > 0:
dist.barrier()
model_name = 'roberta-large' if large else 'roberta-base'
tokenization_utils.logger.setLevel('ERROR')
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)
if rank == 0:
summary(model)
if distributed():
dist.barrier()
if world_size > 1:
model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)
train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
max_sequence_length, random_sequence_length, epoch_size,
token_dropout, seed)
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)
logdir = os.environ.get("OPENAI_LOGDIR", "logs")
os.makedirs(logdir, exist_ok=True)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(logdir) if rank == 0 else None
best_validation_accuracy = 0
for epoch in epoch_loop:
if world_size > 1:
train_loader.sampler.set_epoch(epoch)
validation_loader.sampler.set_epoch(epoch)
train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
validation_metrics = validate(model, device, validation_loader)
combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)
combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]
if rank == 0:
for key, value in combined_metrics.items():
writer.add_scalar(key, value, global_step=epoch)
if combined_metrics["validation/accuracy"] > best_validation_accuracy:
best_validation_accuracy = combined_metrics["validation/accuracy"]
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(dict(
epoch=epoch,
model_state_dict=model_to_save.state_dict(),
optimizer_state_dict=optimizer.state_dict(),
args=args
),
os.path.join(logdir, "best-model.pt")
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--batch-size', type=int, default=24)
parser.add_argument('--max-sequence-length', type=int, default=128)
parser.add_argument('--random-sequence-length', action='store_true')
parser.add_argument('--epoch-size', type=int, default=None)
parser.add_argument('--seed', type=int, default=None)
parser.add_argument('--data-dir', type=str, default='data')
parser.add_argument('--real-dataset', type=str, default='webtext')
parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40')
parser.add_argument('--token-dropout', type=float, default=None)
parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
parser.add_argument('--learning-rate', type=float, default=2e-5)
parser.add_argument('--weight-decay', type=float, default=0)
args = parser.parse_args()
nproc = int(subprocess.check_output(['python', '-c', "import torch;"
"print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
if nproc > 1:
print(f'Launching {nproc} processes ...', file=sys.stderr)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = str(29500)
os.environ['WORLD_SIZE'] = str(nproc)
os.environ['OMP_NUM_THREAD'] = str(1)
subprocesses = []
for i in range(nproc):
os.environ['RANK'] = str(i)
os.environ['LOCAL_RANK'] = str(i)
process = Process(target=run, kwargs=vars(args))
process.start()
subprocesses.append(process)
for process in subprocesses:
process.join()
else:
run(**vars(args))
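
Checkpoints and TensorBoard metrics go to the directory named by `OPENAI_LOGDIR` (default `logs`); as a sketch, to monitor a run:

```bash
OPENAI_LOGDIR=logs/detector-run python -m detector.train --large
tensorboard --logdir logs/detector-run
```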

detector/utils.py Normal file
@@ -0,0 +1,62 @@
import sys
from functools import reduce

from torch import nn
import torch.distributed as dist


def summary(model: nn.Module, file=sys.stdout):
    def repr(model):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = model.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        total_params = 0
        for key, module in model._modules.items():
            mod_str, num_params = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
            total_params += num_params
        lines = extra_lines + child_lines

        for name, p in model._parameters.items():
            if hasattr(p, 'shape'):
                total_params += reduce(lambda x, y: x * y, p.shape)

        main_str = model._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        if file is sys.stdout:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    string, count = repr(model)
    if file is not None:
        if isinstance(file, str):
            file = open(file, 'w')
        print(string, file=file)
        file.flush()
    return count


def grad_norm(model: nn.Module):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def distributed():
    return dist.is_available() and dist.is_initialized()

requirements.txt Normal file
@@ -0,0 +1,6 @@
transformers>=2.0.0
fire>=0.2.1
requests>=2.22.0
tqdm>=4.32.2
torch>=1.2.0
tensorboard>=1.14.0