diff --git a/cavachon/utils/TensorUtils.py b/cavachon/utils/TensorUtils.py index 89227dd..f6f7908 100644 --- a/cavachon/utils/TensorUtils.py +++ b/cavachon/utils/TensorUtils.py @@ -1,11 +1,13 @@ +from typing import Dict, Iterable, List, Optional, Tuple + import numpy as np import pandas as pd import scipy import tensorflow as tf +from sklearn.preprocessing import LabelEncoder from cavachon.utils.DataFrameUtils import DataFrameUtils -from sklearn.preprocessing import LabelEncoder -from typing import Dict, Iterable, List, Optional, Tuple + class TensorUtils: """TensorUtils @@ -37,7 +39,7 @@ def max_n_neurons(layers: Iterable[tf.keras.layers.Layer]) -> int: return current_max @staticmethod - def remove_nan_gradients(gradients: List[tf.Tensor], clip_value=0.1) -> List[tf.Tensor]: + def remove_nan_gradients(gradients: List[tf.Tensor], clip_value=10) -> List[tf.Tensor]: """Replace nan, inf with 0 and perform gradient clipping for the gradients computed by tf.GradientTape.gradient().