Skip to content

Commit

Permalink
BaseModel classifier: support turning off onehot encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
pesekon2 committed Sep 21, 2023
1 parent 834de85 commit 4e1a193
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class _BaseModel(Model, ABC):

def __init__(self, nr_classes, nr_bands=12, nr_filters=64, batch_norm=True,
dilation_rate=1, tensor_shape=(256, 256),
activation=k_layers.ReLU,
padding='same', dropout_rate_input=None,
dropout_rate_hidden=None, use_bias=True, name='model', **kwargs):
activation=k_layers.ReLU, padding='same',
dropout_rate_input=None, dropout_rate_hidden=None,
use_bias=True, onehot_encode=True, name='model', **kwargs):
"""Model constructor.
:param nr_classes: number of classes to be predicted
Expand All @@ -40,8 +40,9 @@ def __init__(self, nr_classes, nr_bands=12, nr_filters=64, batch_norm=True,
units of the input layer to drop
:param dropout_rate_hidden: float between 0 and 1. Fraction of
the input
:param name: The name of the model
:param use_bias: Boolean, whether the layer uses a bias vector
:param onehot_encode: boolean to onehot-encode masks in the last layer
:param name: The name of the model
"""
super(_BaseModel, self).__init__(name=name, **kwargs)

Expand All @@ -58,6 +59,7 @@ def __init__(self, nr_classes, nr_bands=12, nr_filters=64, batch_norm=True,
self.use_bias = use_bias
# TODO: Maybe use_bias should be by default == False, see:
# https://arxiv.org/pdf/1502.03167.pdf
self.onehot_encode = onehot_encode

self.check_parameters()

Expand Down Expand Up @@ -106,7 +108,11 @@ def get_classifier_layer(self):
:return: the classifier layer
"""
return Conv2D(self.nr_classes,
if self.onehot_encode is True:
nr_filters = self.nr_classes
else:
nr_filters = 1
return Conv2D(nr_filters,
(1, 1),
activation=self.get_classifier_function(),
padding=self.padding,
Expand Down

1 comment on commit 4e1a193

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src
   architectures.py3856184%75, 88, 100, 114, 150, 161, 561, 565, 590–608, 619–626, 718, 728, 798, 855–858, 915–921, 976, 999–1017, 1028–1034, 1073–1080, 1096, 1136, 1209, 1239–1241, 1243, 1277–1282
   cnn_lib.py3099868%48, 117, 128–142, 160, 171, 181–183, 195–201, 225, 270–273, 293–315, 409–420, 451–465, 529–534, 582–592, 651–655, 691–701, 763–777, 842–853, 894–903, 913–917, 938–945, 984–1003, 1014–1019, 1028
   data_preparation.py103991%35, 56, 60, 88, 107, 117, 125, 131, 183
   detect.py67670%3–186
   train.py834447%29, 41, 59, 96, 110, 147–272
   utils.py432151%53–60, 73–80, 94–105
   visualization.py916133%26–40, 75, 100–188
TOTAL108736167% 

Tests Skipped Failures Errors Time
3 0 💤 0 ❌ 0 🔥 194m 49s ⏱️

Please sign in to comment.