diff --git a/src/train.py b/src/train.py index d2d432f..4cc2e2f 100644 --- a/src/train.py +++ b/src/train.py @@ -2,6 +2,7 @@ import os import argparse +import sys import numpy as np import tensorflow as tf @@ -21,6 +22,66 @@ def rescale_image(input_image, input_mask): return input_image, input_mask +def load_pretrained_model(model, id2code, + tensor_shape, loss_function, tversky_alpha, tversky_beta, + dropout_rate_input, dropout_rate_hidden, backbone, name, + in_weights_path, model_new, + finetune_old_inp_dim, finetune_old_out_dim): + # if input or output dimension changed w.r.t pretrained model + if finetune_old_inp_dim or finetune_old_out_dim: + if model == "U-Net": + # set dimensions for creating pretrained model + if finetune_old_inp_dim: + nr_bands = finetune_old_inp_dim + if finetune_old_out_dim: + num_class = finetune_old_out_dim + else: + num_class = len(id2code) + + # creating model with dimensions of pretrained model + # NOTE: do not set create_model to verbose=False + # --> need once run model.summary() -> otherwise model dimensions are not set + print("------------------------------") + print("-- Start: Dimensions of OLD Model: --") + print("------------------------------") + model_old = create_model( + model, num_class , nr_bands, tensor_shape, nr_filters=32, loss=loss_function, + alpha=tversky_alpha, beta=tversky_beta, + dropout_rate_input=dropout_rate_input, + dropout_rate_hidden=dropout_rate_hidden, backbone=backbone, name=name) + print("----------------------------------") + print("-- End: Dimensions of OLD Model: --") + print("----------------------------------") + # load model weights of pretrained model + model_old.load_weights(in_weights_path) + + # Set weights of new model, with weights of pretrained model + # NOTE: model.layers returns list of model layers BUT not necessarily in the correct order + # Thus have to explicitely check for first and last layer index + # Get all layer names: + layer_names = [layer.name for layer in model_new.layers] + # Get layer index of first downsampling block + chlayer_first = model_new.ds_blocks[0].name + ind_chlayer_first = layer_names.index(chlayer_first) + # Get layer index of last layer od model + chlayer_last = "classifier_layer" + ind_chlayer_last = layer_names.index(chlayer_last) + # iterate over all layers to set the weights + for ind in range(0,len(model_new.layers)): + # if input dimension changed, don't set weigts for this layer in new model + if ind == ind_chlayer_first and finetune_old_inp_dim: + continue + # if output dimension changed, don't set weigts for this layer in new model + if ind == ind_chlayer_last and finetune_old_out_dim: + continue + # set weights from pretrained model, for all remaining layers + model_new.layers[ind].set_weights(model_old.layers[ind].get_weights()) + else: + sys.exit("ERROR: Change of input or output dimensions w.r.t pretrained models only " + "supported for U-Net so far (parameter --finetune_old_inp_dim or --finetune_old_out_dim)") + else: + # if model dimension did not chainged, load weights from complete model + model_new.load_weights(in_weights_path) def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None, visualization_path='/tmp', nr_epochs=1, initial_epoch=0, batch_size=1, @@ -28,8 +89,10 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None, monitored_value='val_accuracy', force_dataset_generation=False, fit_memory=False, augment=False, tversky_alpha=0.5, tversky_beta=0.5, dropout_rate_input=None, dropout_rate_hidden=None, - val_set_pct=0.2, filter_by_class=None, backbone=None, name='model', - verbose=1): + val_set_pct=0.2, filter_by_class=None, backbone=None, + finetune_old_inp_dim=None, finetune_old_out_dim=None, + name='model', verbose=1, + ): if verbose > 0: utils.print_device_info() @@ -48,7 +111,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None, tf.keras.utils.set_random_seed(seed) # tinyunet: nr_filters=32 - model = create_model( + model_new = create_model( model, len(id2code), nr_bands, tensor_shape, nr_filters=32, loss=loss_function, alpha=tversky_alpha, beta=tversky_beta, dropout_rate_input=dropout_rate_input, @@ -78,9 +141,15 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None, num_parallel_calls=tf.data.AUTOTUNE) .repeat()) - # load weights if the model is supposed to do so + # load weights if the model is supposed to do so (i.e. fine-tune mode) if operation == 'fine-tune': - model.load_weights(in_weights_path) + load_pretrained_model( + model, id2code, + tensor_shape, loss_function, tversky_alpha, tversky_beta, + dropout_rate_input, dropout_rate_hidden, backbone, name, + in_weights_path, model_new, + finetune_old_inp_dim, finetune_old_out_dim + ) #train_generator = AugmentGenerator( # data_dir, batch_size, 'train', fit_memory=fit_memory, @@ -105,7 +174,7 @@ def main(operation, data_dir, output_dir, model, model_fn, in_weights_path=None, .map(Augment()) .prefetch(buffer_size=tf.data.AUTOTUNE)) - train(model, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size, + train(model_new, train_generator, train_nr_samples, val_generator, val_nr_samples, id2code, batch_size, output_dir, visualization_path, model_fn, nr_epochs, initial_epoch, seed=seed, patience=patience, monitored_value=monitored_value, verbose=verbose) @@ -294,13 +363,27 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample '--backbone', type=str, default=None, choices=('ResNet50', 'ResNet101', 'ResNet152'), help='Backbone architecture') - + parser.add_argument( + "--finetune_old_inp_dim", type=int, default=None, + help="Input dimension of pretrained model, used for finetuning. " + "Set if dimension changed in new/currently trained model." + ) + parser.add_argument( + "--finetune_old_out_dim", type=int, default=None, + help="Output dimension of pretrained model, used for finetuning. " + "Set if dimension changed in new/currently trained model." + ) args = parser.parse_args() # check required arguments by individual operations if args.operation == 'fine-tune' and args.weights_path is None: raise parser.error( 'Argument weights_path required for operation == fine-tune') + if (args.finetune_old_inp_dim or args.finetune_old_out_dim) and args.operation != "fine-tune": + raise parser.error( + "Argument operation==fine-tune required for arguments " + "finetune_old_inp_dim or finetune_old_out_dim" + ) if args.operation == 'train' and args.initial_epoch != 0: raise parser.error( 'Argument initial_epoch must be 0 for operation == train') @@ -325,4 +408,5 @@ def train(model, train_generator, train_nr_samples, val_generator, val_nr_sample args.augment_training_dataset, args.tversky_alpha, args.tversky_beta, args.dropout_rate_input, args.dropout_rate_hidden, args.validation_set_percentage, - args.filter_by_classes, args.backbone) + args.filter_by_classes, args.backbone, + args.finetune_old_inp_dim, args.finetune_old_out_dim)