Skip to content

Commit

Permalink
fix attribution target (use mean)
Browse files Browse the repository at this point in the history
  • Loading branch information
dn070017 committed Dec 14, 2023
1 parent 01e9c7c commit a2c9519
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from cavachon.environment.Constants import Constants
from cavachon.modules.parameterizers.Parameterizer import Parameterizer
from typing import Mapping

import tensorflow as tf

from cavachon.environment.Constants import Constants
from cavachon.modules.parameterizers.Parameterizer import Parameterizer


class IndependentZeroInflatedNegativeBinomial(Parameterizer):
"""IndependentZeroInflatedNegativeBinomial
Expand Down Expand Up @@ -66,7 +68,7 @@ def compute_attribution_target(self, inputs: tf.Tensor):
outputs = self.layer(inputs.get(Constants.TENSOR_NAME_X))
probs, means, dispersion = tf.split(outputs, 3, axis=-1)
probs = tf.keras.activations.sigmoid(probs)
return tf.concat([probs, means, dispersion], axis=-1)
return means

@classmethod
def modify_outputs(
Expand Down

0 comments on commit a2c9519

Please sign in to comment.