From 045a4066d433faa30e1f6d596f1badc8c5575815 Mon Sep 17 00:00:00 2001 From: menouar Date: Fri, 15 Sep 2023 13:01:49 +0200 Subject: [PATCH] Correct a unit test --- tests/models/test_models_builder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/test_models_builder.py b/tests/models/test_models_builder.py index 21dea3d..f63d218 100644 --- a/tests/models/test_models_builder.py +++ b/tests/models/test_models_builder.py @@ -33,9 +33,10 @@ def test_create_models(self): model_creator = ModelCreator( [(RNN_ENCODER_DECODER, 1), (FFN, 2), (CNN, 2), (RNN_BIDIRECTIONAL, 1), (CONV_LSTM1D, 1), (LSTM, 3), (SELF_ATTENTION, 3)], - hyperparams_rnn=(3, 45, 46), - hyperparams_cnn=(64, 65, 3, 4, 1), - hyperparams_ffn=(3, 64, 128), save_models_as_dot_format=False, root_dir=None) + hyperparams_rnn=(3, 45, 46, "tanh"), + hyperparams_cnn=(64, 65, 3, 4, 1, "relu"), + hyperparams_ffn=(3, 64, 128, "sigmoid"), save_models_as_dot_format=False, root_dir=None, dropout=0.3, + last_act_func="sigmoid", hyperparams_transformer=(256, 4, 1, True, "relu")) model_creator.create_models(inputs=self.inputs_rnn)