mirror of
https://github.com/openai/gpt-2-output-dataset
synced 2025-08-22 18:07:53 +00:00
using sys.executable for subprocess calls (fixes #8)
This commit is contained in:
parent
12459ab3ed
commit
6d90da539b
@ -68,7 +68,7 @@ def serve_forever(server, model, tokenizer, device):
|
|||||||
globals()['tokenizer'] = tokenizer
|
globals()['tokenizer'] = tokenizer
|
||||||
globals()['device'] = device
|
globals()['device'] = device
|
||||||
|
|
||||||
log('Ready to serve')
|
log(f'Ready to serve at http://localhost:{server.server_address[1]}')
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else
|
|||||||
server = HTTPServer(('0.0.0.0', port), RequestHandler)
|
server = HTTPServer(('0.0.0.0', port), RequestHandler)
|
||||||
|
|
||||||
# avoid calling CUDA API before forking; doing so in a subprocess is fine.
|
# avoid calling CUDA API before forking; doing so in a subprocess is fine.
|
||||||
num_workers = int(subprocess.check_output(['python', '-c', 'import torch; print(torch.cuda.device_count())']))
|
num_workers = int(subprocess.check_output([sys.executable, '-c', 'import torch; print(torch.cuda.device_count())']))
|
||||||
|
|
||||||
if num_workers <= 1:
|
if num_workers <= 1:
|
||||||
serve_forever(server, model, tokenizer, device)
|
serve_forever(server, model, tokenizer, device)
|
||||||
|
@ -281,7 +281,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--weight-decay', type=float, default=0)
|
parser.add_argument('--weight-decay', type=float, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
nproc = int(subprocess.check_output(['python', '-c', "import torch;"
|
nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
|
||||||
"print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
|
"print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
|
||||||
if nproc > 1:
|
if nproc > 1:
|
||||||
print(f'Launching {nproc} processes ...', file=sys.stderr)
|
print(f'Launching {nproc} processes ...', file=sys.stderr)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user