-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
25 lines (20 loc) · 969 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from configs import available_models, available_corpus, available_embeddings
from configs import get_corpus_params, get_hyperparams, get_embedding_params
from models import ModelFactory
from utils import tools
def train(model_name=None, corpus_name=None, embedding_name=None,
          model_url='', error_text='她也就是说爱撒娇。', *,
          observe=True, beam_width=3, beamsearch_interval=1, is_latin=False):
    """Build a seq2seq model from the project configs and train it.

    Calling ``train()`` with no arguments reproduces the original hard-coded
    configuration: ``available_models[0]``, ``available_corpus[1]`` and
    ``available_embeddings[1]``, with the default keyword values below.

    Args:
        model_name: Name passed to ``ModelFactory.make_model`` and
            ``get_hyperparams``. Defaults to ``available_models[0]``.
        corpus_name: Name passed to ``get_corpus_params``. Defaults to
            ``available_corpus[1]``.
        embedding_name: Name passed to ``get_embedding_params``. Defaults to
            ``available_embeddings[1]``.
        model_url: Forwarded positionally to ``tools.train_model``.
            NOTE(review): presumably a load/save location for the model;
            the empty string is the original default. Confirm in
            ``utils.tools``.
        error_text: Sample sentence forwarded to ``tools.train_model``.
            NOTE(review): presumably decoded during training when
            ``observe`` is true. Confirm in ``utils.tools``.
        observe, beam_width, beamsearch_interval, is_latin: Forwarded
            unchanged to ``tools.train_model``. Their semantics are defined
            there. ``is_latin=False`` matches the non-Latin (Chinese)
            default ``error_text``.
    """
    # Resolve each name just before its first use, keeping the same
    # evaluation order as the original hard-coded script.
    if model_name is None:
        model_name = available_models[0]
    seq2seq_model = ModelFactory.make_model(model_name)
    hyperparams = get_hyperparams(model_name)

    if corpus_name is None:
        corpus_name = available_corpus[1]
    corpus_params = get_corpus_params(corpus_name)

    if embedding_name is None:
        embedding_name = available_embeddings[1]
    embedding_params = get_embedding_params(embedding_name)

    tools.train_model(seq2seq_model, hyperparams, corpus_params, embedding_params, model_url,
                      observe=observe, error_text=error_text,
                      beam_width=beam_width, beamsearch_interval=beamsearch_interval,
                      is_latin=is_latin)
    # TODO: once the hyperparameters are tuned, retrain on the combined
    # train + validation data.
if __name__ == "__main__":
    # Only launch training when run as a script, not when imported.
    train()