-
Notifications
You must be signed in to change notification settings - Fork 0
/
utility.py
56 lines (46 loc) · 1.5 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
from datetime import datetime
from itertools import islice

import torch
import torch.nn as nn
import yaml
def get_map_char_to_int(chars):
    """Build a lookup table from each character to its position in ``chars``.

    If a character occurs more than once, its last position is kept.
    """
    return {char: index for index, char in enumerate(chars)}
def get_map_int_to_char(chars):
    """Build a lookup table from each position in ``chars`` to its character."""
    # enumerate already yields (index, char) pairs, which dict() consumes directly.
    return dict(enumerate(chars))
def encode(map_char_to_int, s):
    """Translate each character of ``s`` into its index.

    Raises ``KeyError`` for a character missing from ``map_char_to_int``.
    """
    lookup = map_char_to_int.__getitem__
    return list(map(lookup, s))
def decode(map_int_to_char, l):
    """Decode a sequence of indices back into a string.

    Args:
        map_int_to_char: Mapping from integer index to character, such as the
            dict built by ``get_map_int_to_char``.
        l: Iterable of integer indices. A 1-D tensor/array (anything exposing
            ``tolist()``) is also accepted and converted to plain ints first.

    Returns:
        The characters for each index, concatenated in order.

    Raises:
        KeyError: If an index is not a key of ``map_int_to_char``.
    """
    # Iterating a torch tensor yields 0-d tensors, which hash by identity and
    # therefore never match the int keys of the mapping; normalize to ints.
    if hasattr(l, 'tolist'):
        l = l.tolist()
    return ''.join([map_int_to_char[i] for i in l])
@torch.no_grad()
def estimate_loss(model, device, data_loader, eval_iters=100):
    """Estimate the mean loss of ``model`` over the first batches of a loader.

    Runs with gradients disabled. The model is put in eval mode for the
    evaluation and then restored to the mode it had before the call, even if
    an exception is raised.

    Args:
        model: Module called as ``model(x, y)`` that returns ``(logits, loss)``,
            where ``loss`` is a scalar tensor.
        device: Device each batch is moved to before the forward pass.
        data_loader: Iterable yielding ``(x, y)`` batches.
        eval_iters: Maximum number of batches to evaluate.

    Returns:
        0-d tensor holding the mean loss over the batches actually evaluated
        (``nan`` if no batch was evaluated, e.g. ``eval_iters == 0``).
    """
    # Allocate before touching the model so a bad eval_iters fails early.
    losses = torch.zeros(eval_iters)
    was_training = model.training
    model.eval()
    n_batches = 0
    try:
        # islice stops after eval_iters batches without fetching an extra one.
        for x, y in islice(data_loader, eval_iters):
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            losses[n_batches] = loss.item()
            n_batches += 1
    finally:
        model.train(was_training)
    # Average only the filled slots: a loader shorter than eval_iters must not
    # pull the mean toward zero with never-written entries.
    return losses[:n_batches].mean()
def save_model_and_config(model, config, base_dir='./models'):
    """Save a model and its config into a fresh timestamped directory.

    The directory is ``<base_dir>/<YYMMDDHHMM>``. If that name is already
    taken (e.g. two saves within the same minute), ``_1``, ``_2``, ... is
    appended until an unused directory can be created, instead of failing.

    Args:
        model: Object passed whole to ``torch.save``; written as ``model.pth``.
        config: Object serialized with ``yaml.dump``; written as ``config.yaml``.
        base_dir: Parent directory for the timestamped directories; it and any
            missing parents are created.

    Returns:
        Path of the directory both files were written to.
    """
    str_now = datetime.now().strftime("%y%m%d%H%M")
    save_path = os.path.join(base_dir, str_now)
    suffix = 0
    while True:
        try:
            # exist_ok is False, so mkdir itself detects a collision; no racy
            # separate existence check is needed.
            os.makedirs(save_path)
            break
        except FileExistsError:
            suffix += 1
            save_path = os.path.join(base_dir, f"{str_now}_{suffix}")
    # NOTE(review): this pickles the whole model object, so loading requires
    # its class to be importable; saving model.state_dict() is more portable
    # but would change what callers get back from torch.load.
    torch.save(model, os.path.join(save_path, 'model.pth'))
    config_path = os.path.join(save_path, 'config.yaml')
    with open(config_path, 'w', encoding='utf-8') as fh:
        yaml.dump(config, fh)
    return save_path