-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_multimnist.py
38 lines (32 loc) · 1.12 KB
/
train_multimnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import json
import numpy as np
import tensorflow as tf
from helper import *
from capsule import CapsuleMultiMNIST
# --- Command-line flags -----------------------------------------------------
# Each call registers one string flag as (name, default, help text) on the
# global tf.app.flags registry.

# JSON file holding the network/training configuration (read by main()).
tf.app.flags.DEFINE_string(
    'arch',
    'architecture-multimnist.json',
    'network architecture',
)
# Log-directory flags; main() hands the whole flag object to
# validate_log_dirs(), which presumably resolves these -- confirm in helper.
tf.app.flags.DEFINE_string(
    'logdir_root',
    None,
    'root of log dir',
)
tf.app.flags.DEFINE_string(
    'logdir',
    None,
    'log dir',
)
# Per its help text, this expects a directory rather than a *.ckpt path.
tf.app.flags.DEFINE_string(
    'restore_from',
    None,
    'restore from dir (not from *.ckpt)',
)
tf.app.flags.DEFINE_string(
    'msg',
    '-MultiMNIST',
    'Additional message',
)

# Handle to the parsed flag values, used throughout this script.
args = tf.app.flags.FLAGS
def main(argv=None):
    """Build the MultiMNIST capsule network from a JSON spec and train it.

    Reads the configuration from the JSON file named by the ``--arch`` flag,
    then builds a ``MultiMNISTIndexReader`` input pipeline over
    ``MultiMNIST_index_train.npf`` and resolves log directories with
    ``validate_log_dirs(args)``. The resolved ``logdir`` is stored in the
    config. Finally it constructs ``CapsuleMultiMNIST`` and trains it on the
    resulting loss.

    Args:
        argv: Unused. Accepted so the script can also be launched through
            ``tf.app.run()``, which calls ``main(argv)``. Calling ``main()``
            with no arguments behaves exactly as before.
    """
    del argv  # Unused; flag values are read through the module-level ``args``.

    with open(args.arch) as fp:
        arch = json.load(fp)

    # Training input queue. Only the batch size comes from the JSON spec;
    # the queue capacity and minimum fill level are fixed here.
    data = MultiMNISTIndexReader(
        train_index='MultiMNIST_index_train.npf',
        batch_size=arch['training']['batch_size'],
        data_format='channels_last',
        capacity=2**10,
        min_after_dequeue=2**9,
    )

    dirs = validate_log_dirs(args)
    # Hand the resolved log directory to the network through its config
    # (presumably where it writes summaries/checkpoints -- confirm in
    # CapsuleMultiMNIST).
    arch['logdir'] = dirs['logdir']

    net = CapsuleMultiMNIST(arch=arch)
    loss = net.loss(data.x, data.y, data.xi, data.xj)
    # NOTE: no held-out evaluation loss is wired up, so loss_t is None.
    # Earlier drafts sketched net.inspect(data.example) and
    # net.loss(data.x_t, data.y_t) for this purpose.
    net.train(loss, loss_t=None)
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()