-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
36 lines (30 loc) · 987 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
import numpy as np
import tensorflow as tf
from helper import *
from capsule import CapsuleNet
# Command-line flags for this script. Values are read through `args`
# (for example `args.arch`).
args = tf.app.flags.FLAGS

# Each row is (flag name, default value, help text); rows are registered
# in the order listed.
for _flag_spec in (
    ('arch', 'architecture-capsule.json', 'network architecture'),
    ('logdir_root', None, 'root of log dir'),
    ('logdir', None, 'log dir'),
    ('restore_from', None, 'restore from dir (not from *.ckpt)'),
    ('msg', '-Capsule', 'Additional message'),
):
    tf.app.flags.DEFINE_string(*_flag_spec)
del _flag_spec  # do not leave the loop temporary in the module namespace
def main():
    """Load the architecture file, build a ``CapsuleNet`` and start training.

    Steps, in order:

    1. Parse the JSON file named by the ``arch`` flag.
    2. Construct the ``MNIST`` helper with the batch size found under
       ``['training']['batch_size']`` and ``channels_last`` layout.
    3. Resolve the log directories with ``validate_log_dirs(args)`` and
       store the resulting ``'logdir'`` entry in the architecture dict.
    4. Build the network, create the two loss tensors, call
       ``inspect`` on the example batch, and hand both losses to
       ``net.train``.
    """
    with open(args.arch) as arch_file:
        config = json.load(arch_file)

    batch_size = config['training']['batch_size']
    dataset = MNIST(batch_size=batch_size, data_format='channels_last')

    # The network reads its log directory from the architecture dict.
    log_dirs = validate_log_dirs(args)
    config['logdir'] = log_dirs['logdir']

    model = CapsuleNet(arch=config)

    # Call order is preserved from the original: loss on (x, y), then
    # inspect, then loss on (x_t, y_t).
    # NOTE(review): x_t / y_t are presumably the test split -- confirm in
    # helper.MNIST.
    loss_xy = model.loss(dataset.x, dataset.y)
    model.inspect(dataset.example)
    loss_xy_t = model.loss(dataset.x_t, dataset.y_t)

    model.train(loss_xy, loss_xy_t)


if __name__ == '__main__':
    main()