Change input/output dimension for train-finetune (#1)
* implement FCN (ctu-geoforall-lab-projects#45)

* BaseModel classifier: support turning off onehot encoding

* fix typo in error message

* propagate summary definitions

allow changing the summary width

* VGG: fix wrong kwargs in get_config

* detect.py: support more TF versions for setting seed

* support changing input and output channel dimensions for finetune models

* remove non-relevant changes from PR

* cleanup and comment changes

* check out code state 'metz'

* fix setting model weights

* fix load model weights

* AW review

---------

Co-authored-by: Ondrej Pesek <pesej.ondrek@gmail.com>
linakrisztian and pesekon2 committed Aug 14, 2024
1 parent 730e2bf commit 4d46d15
Showing 1 changed file with 92 additions and 8 deletions: src/train.py
@@ -2,6 +2,7 @@

 import os
 import argparse
+import sys
 
 import numpy as np
 import tensorflow as tf
@@ -21,15 +22,77 @@ def rescale_image(input_image, input_mask):

     return input_image, input_mask
 
+def load_pretrained_model(model, id2code, nr_bands, tensor_shape,
+                          loss_function, tversky_alpha, tversky_beta,
+                          dropout_rate_input, dropout_rate_hidden, backbone,
+                          name, in_weights_path, model_new,
+                          finetune_old_inp_dim, finetune_old_out_dim):
+    # if the input or output dimension changed w.r.t. the pretrained model
+    if finetune_old_inp_dim or finetune_old_out_dim:
+        if model == "U-Net":
+            # set the dimensions for creating the pretrained model
+            if finetune_old_inp_dim:
+                nr_bands = finetune_old_inp_dim
+            if finetune_old_out_dim:
+                num_class = finetune_old_out_dim
+            else:
+                num_class = len(id2code)
+
+            # create a model with the dimensions of the pretrained model
+            # NOTE: do not call create_model with verbose=False:
+            # model.summary() has to run once, otherwise the model
+            # dimensions are not set
+            print("-------------------------------------")
+            print("-- Start: dimensions of OLD model: --")
+            print("-------------------------------------")
+            model_old = create_model(
+                model, num_class, nr_bands, tensor_shape, nr_filters=32,
+                loss=loss_function, alpha=tversky_alpha, beta=tversky_beta,
+                dropout_rate_input=dropout_rate_input,
+                dropout_rate_hidden=dropout_rate_hidden, backbone=backbone,
+                name=name)
+            print("-----------------------------------")
+            print("-- End: dimensions of OLD model: --")
+            print("-----------------------------------")
+            # load the weights of the pretrained model
+            model_old.load_weights(in_weights_path)
+
+            # set the weights of the new model to those of the pretrained one
+            # NOTE: model.layers returns the list of model layers, BUT not
+            # necessarily in the correct order; thus the indices of the first
+            # and the last layer have to be checked explicitly
+            # get all layer names
+            layer_names = [layer.name for layer in model_new.layers]
+            # get the layer index of the first downsampling block
+            chlayer_first = model_new.ds_blocks[0].name
+            ind_chlayer_first = layer_names.index(chlayer_first)
+            # get the layer index of the last layer of the model
+            chlayer_last = "classifier_layer"
+            ind_chlayer_last = layer_names.index(chlayer_last)
+            # iterate over all layers to set the weights
+            for ind in range(len(model_new.layers)):
+                # if the input dimension changed, do not set the weights
+                # for this layer in the new model
+                if ind == ind_chlayer_first and finetune_old_inp_dim:
+                    continue
+                # if the output dimension changed, do not set the weights
+                # for this layer in the new model
+                if ind == ind_chlayer_last and finetune_old_out_dim:
+                    continue
+                # set the weights from the pretrained model for all
+                # remaining layers
+                model_new.layers[ind].set_weights(
+                    model_old.layers[ind].get_weights())
+        else:
+            sys.exit("ERROR: Changing the input or output dimensions w.r.t. "
+                     "the pretrained model is supported only for U-Net so far "
+                     "(parameters --finetune_old_inp_dim and "
+                     "--finetune_old_out_dim)")
+    else:
+        # if the model dimensions did not change, load the weights of the
+        # complete model
+        model_new.load_weights(in_weights_path)

 def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
          visualization_path='/tmp', nr_epochs=1, initial_epoch=0, batch_size=1,
          loss_function='dice', seed=1, patience=100, tensor_shape=(256, 256),
          monitored_value='val_accuracy', force_dataset_generation=False,
          fit_memory=False, augment=False, tversky_alpha=0.5,
          tversky_beta=0.5, dropout_rate_input=None, dropout_rate_hidden=None,
-         val_set_pct=0.2, filter_by_class=None, backbone=None, name='model',
-         verbose=1):
+         val_set_pct=0.2, filter_by_class=None, backbone=None,
+         finetune_old_inp_dim=None, finetune_old_out_dim=None,
+         name='model', verbose=1,
+         ):
     if verbose > 0:
         utils.print_device_info()
 
@@ -48,7 +111,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
     tf.keras.utils.set_random_seed(seed)
 
     # tinyunet: nr_filters=32
-    model = create_model(
+    model_new = create_model(
         model, len(id2code), nr_bands, tensor_shape, nr_filters=32, loss=loss_function,
         alpha=tversky_alpha, beta=tversky_beta,
         dropout_rate_input=dropout_rate_input,
@@ -78,9 +141,15 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
             num_parallel_calls=tf.data.AUTOTUNE)
         .repeat())
 
-    # load weights if the model is supposed to do so
+    # load weights if the model is supposed to do so (i.e. fine-tune mode)
     if operation == 'fine-tune':
-        model.load_weights(in_weights_path)
+        load_pretrained_model(
+            model, id2code, nr_bands, tensor_shape,
+            loss_function, tversky_alpha, tversky_beta,
+            dropout_rate_input, dropout_rate_hidden, backbone, name,
+            in_weights_path, model_new,
+            finetune_old_inp_dim, finetune_old_out_dim)

     #train_generator = AugmentGenerator(
     #    data_dir, batch_size, 'train', fit_memory=fit_memory,
@@ -105,7 +174,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None,
         .map(Augment())
         .prefetch(buffer_size=tf.data.AUTOTUNE))
 
-    train(model, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size,
+    train(model_new, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size,
           output_dir, visualization_path, model_fn, nr_epochs,
           initial_epoch, seed=seed, patience=patience,
           monitored_value=monitored_value, verbose=verbose)
@@ -294,13 +363,27 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
         '--backbone', type=str, default=None,
         choices=('ResNet50', 'ResNet101', 'ResNet152'),
         help='Backbone architecture')
 
+    parser.add_argument(
+        '--finetune_old_inp_dim', type=int, default=None,
+        help='Input dimension of the pretrained model, used for fine-tuning. '
+             'Set it if the dimension changed in the newly trained model.')
+    parser.add_argument(
+        '--finetune_old_out_dim', type=int, default=None,
+        help='Output dimension of the pretrained model, used for fine-tuning. '
+             'Set it if the dimension changed in the newly trained model.')
     args = parser.parse_args()
 
     # check required arguments by individual operations
     if args.operation == 'fine-tune' and args.weights_path is None:
         raise parser.error(
             'Argument weights_path required for operation == fine-tune')
+    if (args.finetune_old_inp_dim or args.finetune_old_out_dim) \
+            and args.operation != 'fine-tune':
+        raise parser.error(
+            'Arguments finetune_old_inp_dim and finetune_old_out_dim '
+            'require operation == fine-tune')
     if args.operation == 'train' and args.initial_epoch != 0:
         raise parser.error(
             'Argument initial_epoch must be 0 for operation == train')
@@ -325,4 +408,5 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample
         args.augment_training_dataset, args.tversky_alpha,
         args.tversky_beta, args.dropout_rate_input,
         args.dropout_rate_hidden, args.validation_set_percentage,
-        args.filter_by_classes, args.backbone)
+        args.filter_by_classes, args.backbone,
+        args.finetune_old_inp_dim, args.finetune_old_out_dim)

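The NOTE in load_pretrained_model about not running create_model with verbose=False reflects a general Keras behaviour: a subclassed model owns no weight variables until it has been built, so loading or copying weights into an unbuilt model fails. A minimal sketch of that pitfall, assuming a hypothetical TinyModel rather than the repository's architecture (in the repository's create_model, the verbose summary path presumably triggers this build step):

import tensorflow as tf

class TinyModel(tf.keras.Model):  # hypothetical stand-in
    def __init__(self):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(8, 3, padding='same')

    def call(self, inputs):
        return self.conv(inputs)

model = TinyModel()
print(len(model.weights))           # 0: no variables exist before building
model(tf.zeros((1, 256, 256, 3)))   # one forward pass builds all layers
model.summary()                     # would fail if called before building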