Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
RoundOffError committed Oct 11, 2023
1 parent 1e06153 commit b452285
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_equine_protonet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def test_train_episodes_shared_reg(random_dataset):
X, Y = dataset.tensors
num_deep_features = 32
embed_model = BasicEmbeddingModel(X.shape[1], num_deep_features)
model = eq.EquineProtonet(embed_model, num_deep_features, , cov_reg_type="shared")
model = eq.EquineProtonet(embed_model, num_deep_features)
model.cov_reg_type = "shared"
model.model.cov_reg_type = "shared"
model.train_model(
dataset,
way=way,
Expand Down

0 comments on commit b452285

Please sign in to comment.