implement FCN #45

Merged Sep 13, 2023 (21 commits)
355 changes: 354 additions & 1 deletion src/architectures.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/cnn_lib.py
@@ -22,7 +22,8 @@ class AugmentGenerator:
     def __init__(self, data_dir, batch_size=5, operation='train',
                  tensor_shape=(256, 256), force_dataset_generation=False,
                  fit_memory=False, augment=False, onehot_encode=True,
-                 val_set_pct=0.2, filter_by_class=None, verbose=1):
+                 val_set_pct=0.2, filter_by_class=None, ignore_masks=False,
+                 verbose=1):
         """Initialize the generator.

         :param data_dir: path to the directory containing images
@@ -40,6 +41,7 @@ def __init__(self, data_dir, batch_size=5, operation='train',
         :param filter_by_class: classes of interest (for the case of dataset
             generation - if specified, only samples containing at least one of
             them will be created)
+        :param ignore_masks: do not create nor return masks
         :param verbose: verbosity (0=quiet, >0 verbose)
         """
         if operation not in ('train', 'val'):
@@ -54,7 +56,7 @@ def __init__(self, data_dir, batch_size=5, operation='train',
         do_exist = [os.path.isdir(i) is True for i in (images_dir, masks_dir)]
         if force_dataset_generation is True or all(do_exist) is False:
             generate_dataset_structure(data_dir, tensor_shape, val_set_pct,
-                                       filter_by_class, augment,
+                                       filter_by_class, augment, ignore_masks,
                                        verbose=verbose)

         # create variables useful throughout the entire class
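
A minimal usage sketch of the new flag (the directory path is illustrative, not taken from the PR):

from cnn_lib import AugmentGenerator

# with ignore_masks=True the generator neither creates nor returns masks,
# so it can be pointed at imagery that has no corresponding label rasters
val_generator = AugmentGenerator(data_dir='/data/scenes', batch_size=5,
                                 operation='val', tensor_shape=(256, 256),
                                 ignore_masks=True)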
109 changes: 59 additions & 50 deletions src/data_preparation.py
@@ -13,7 +13,7 @@

 def generate_dataset_structure(data_dir, tensor_shape=(256, 256),
                                val_set_pct=0.2, filter_by_class=None,
-                               augment=True, verbose=1):
+                               augment=True, ignore_masks=False, verbose=1):
     """Generate the expected dataset structure.

     Will generate directories train_images, train_masks, val_images and
@@ -25,10 +25,14 @@ def generate_dataset_structure(data_dir, tensor_shape=(256, 256),
     :param filter_by_class: classes of interest (if specified, only samples
         containing at least one of them will be created)
     :param augment: boolean saying whether to augment the dataset or not
+    :param ignore_masks: do not create masks
     :param verbose: verbosity (0=quiet, >0 verbose)
     """
     # Create folders to hold images and masks
-    dirs = ('train_images', 'train_masks', 'val_images', 'val_masks')
+    if ignore_masks is False:
+        dirs = ('train_images', 'train_masks', 'val_images', 'val_masks')
+    else:
+        dirs = ('train_images', 'val_images')

     for directory in dirs:
         dir_full_path = os.path.join(data_dir, directory)
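
For orientation, the resulting layout under data_dir then looks like this (a sketch derived from the tuple above):

data_dir/
├── train_images/
├── train_masks/    # skipped when ignore_masks is True
├── val_images/
└── val_masks/      # skipped when ignore_masks is True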
@@ -43,7 +47,7 @@ def generate_dataset_structure(data_dir, tensor_shape=(256, 256),
     source_images = sorted(glob.glob(os.path.join(data_dir, '*image.tif')))
     for i in source_images:
         tile(i, i.replace('image.tif', 'label.tif'), tensor_shape,
-             filter_by_class, augment, dir_names)
+             filter_by_class, augment, dir_names, ignore_masks)

     # check if there are some training data
     train_images_nr = len(os.listdir(os.path.join(data_dir, 'train_images')))
@@ -59,7 +63,7 @@ def generate_dataset_structure(data_dir, tensor_shape=(256, 256),


 def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
-         augment=True, dir_names=None):
+         augment=True, dir_names=None, ignore_masks=False):
     """Tile the big scene into smaller samples and write them.

     If filter_by_class is not None, only samples containing at least one of
@@ -74,6 +78,7 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
         containing at least one of them will be returned)
     :param augment: boolean saying whether to augment the dataset or not
     :param dir_names: a generator determining directory names (train/val)
+    :param ignore_masks: do not create masks
     """
     rows_step = tensor_shape[0]
     cols_step = tensor_shape[1]
@@ -87,34 +92,29 @@

     # the following variables are defined here to avoid creating them in the
     # loop later
-    if augment is True:
-        driver = gdal.GetDriverByName("GTiff")
-        scene = gdal.Open(scene_path, gdal.GA_ReadOnly)
-        nr_bands = scene.RasterCount
-        projection = scene.GetProjection()
-        data_type = scene.GetRasterBand(1).DataType
-        scene = None
-
-        if cols_step == rows_step:
-            rotations = (1, 2, 3)
-        else:
-            rotations = (2, )
+    driver = gdal.GetDriverByName("GTiff")
+    scene = gdal.Open(scene_path, gdal.GA_ReadOnly)
+    nr_bands = scene.RasterCount
+    projection = scene.GetProjection()
+    data_type = scene.GetRasterBand(1).DataType
+    nr_rows = scene.RasterYSize
+    nr_cols = scene.RasterXSize
+    scene = None
+
+    if cols_step == rows_step:
+        rotations = (1, 2, 3)
     else:
-        driver = None
-        nr_bands = None
-        projection = None
-        data_type = None
-
-        rotations = None
+        rotations = (2, )

     # do not write aux.xml files
     os.environ['GDAL_PAM_ENABLED'] = 'NO'

-    # get variables for the loop and checks
-    labels = gdal.Open(labels_path, gdal.GA_ReadOnly)
-    labels_np = labels.GetRasterBand(1).ReadAsArray()
-    nr_rows = labels.RasterYSize
-    nr_cols = labels.RasterXSize
+    # get variables for the loop checks
+    if ignore_masks is False:
+        labels = gdal.Open(labels_path, gdal.GA_ReadOnly)
+        labels_np = labels.GetRasterBand(1).ReadAsArray()
+    else:
+        labels_np = None

     scene_dir, scene_name = os.path.split(scene_path[:-10])

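The rotations tuple above encodes multiples of 90 degrees passed to np.rot90: a 90- or 270-degree rotation swaps the tile axes, so non-square tiles can only be augmented with the 180-degree rotation. A quick sketch of that constraint:

import numpy as np

tile = np.zeros((256, 128))                   # a non-square sample
assert np.rot90(tile, 2).shape == (256, 128)  # k=2 (180 deg) keeps the shape
assert np.rot90(tile, 1).shape == (128, 256)  # k=1 (90 deg) would swap axes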
@@ -131,7 +131,7 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
             j = nr_rows - rows_step

         # if filtering, check if it makes sense to continue
-        if filt is True:
+        if filt is True and ignore_masks is False:
             labels_cropped = labels_np[j:j + rows_step, i:i + cols_step]
             if not any(i in labels_cropped for i in filter_by_class):
                 # no occurrence of classes to filter by - continue with
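
The membership test relies on NumPy semantics: `x in arr` on a 2-D array is equivalent to `(arr == x).any()`, so a sample is kept as soon as any class of interest occurs anywhere in the cropped labels. An illustrative check:

import numpy as np

labels_cropped = np.array([[0, 0],
                           [0, 2]])
filter_by_class = [1, 2]
keep = any(i in labels_cropped for i in filter_by_class)  # True, class 2 occurs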
@@ -146,17 +146,20 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
             output_scene_path = os.path.join(scene_dir,
                                              '{}_images'.format(dir_name),
                                              scene_name + f'_{i}_{j}.tif')
-            output_mask_path = os.path.join(scene_dir,
-                                            '{}_masks'.format(dir_name),
-                                            scene_name + f'_{i}_{j}.tif')

             # crop
             gdal.Translate(output_scene_path,
                            scene_path,
                            srcWin=(i, j, cols_step, rows_step))
-            gdal.Translate(output_mask_path,
-                           labels_path,
-                           srcWin=(i, j, cols_step, rows_step))
+
+            if ignore_masks is False:
+                # do the same for masks
+                output_mask_path = os.path.join(scene_dir,
+                                                '{}_masks'.format(dir_name),
+                                                scene_name + f'_{i}_{j}.tif')
+                gdal.Translate(output_mask_path,
+                               labels_path,
+                               srcWin=(i, j, cols_step, rows_step))

             if augment is False:
                 # the following code is unnecessary then
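
gdal.Translate crops with srcWin given as (xoff, yoff, xsize, ysize) in pixels, so each call above cuts one tensor-sized window out of the scene. A standalone sketch with illustrative paths:

from osgeo import gdal

# cut a 256 x 256 window whose top-left corner sits at column i=0, row j=0
gdal.Translate('/tmp/scene_0_0.tif', 'scene_image.tif',
               srcWin=(0, 0, 256, 256))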
@@ -173,8 +176,11 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
                 src_bands.append(
                     src_scene.GetRasterBand(band_i).ReadAsArray())

-            src_mask = gdal.Open(output_mask_path, gdal.GA_ReadOnly)
-            src_mask_band = src_mask.GetRasterBand(1).ReadAsArray()
+            if ignore_masks is False:
+                src_mask = gdal.Open(output_mask_path, gdal.GA_ReadOnly)
+                src_mask_band = src_mask.GetRasterBand(1).ReadAsArray()
+            else:
+                src_mask_band = None

             src_scene = None
             src_mask = None
@@ -186,9 +192,6 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
                 rot_scene_path = os.path.join(
                     scene_dir, '{}_images'.format(dir_name),
                     scene_name + f'_{i}_{j}_rot{rot_k * 90}.tif')
-                rot_mask_path = os.path.join(
-                    scene_dir, '{}_masks'.format(dir_name),
-                    scene_name + f'_{i}_{j}_rot{rot_k * 90}.tif')

                 # create files
                 out_scene = driver.Create(
@@ -197,16 +200,9 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
                     rows_step,
                     nr_bands,
                     data_type)
-                out_mask = driver.Create(
-                    rot_mask_path,
-                    cols_step,
-                    rows_step,
-                    1,
-                    gdal.GDT_UInt16)

                 out_scene.SetGeoTransform(geo_transform)
-                out_mask.SetGeoTransform(geo_transform)
                 out_scene.SetProjection(projection)
-                out_mask.SetProjection(projection)

                 # write rotated arrays
                 for band_i in range(nr_bands):
@@ -215,9 +211,22 @@ def tile(scene_path, labels_path, tensor_shape, filter_by_class=None,
                     out_scene_band.WriteArray(
                         np.rot90(src_bands[band_i], rot_k), 0, 0)

-                out_mask_band = out_mask.GetRasterBand(1)
-                out_mask_band.WriteArray(
-                    np.rot90(src_mask_band, rot_k), 0, 0)
+                if ignore_masks is False:
+                    # do the same for masks
+                    rot_mask_path = os.path.join(
+                        scene_dir, '{}_masks'.format(dir_name),
+                        scene_name + f'_{i}_{j}_rot{rot_k * 90}.tif')
+                    out_mask = driver.Create(
+                        rot_mask_path,
+                        cols_step,
+                        rows_step,
+                        1,
+                        gdal.GDT_UInt16)
+                    out_mask.SetGeoTransform(geo_transform)
+                    out_mask.SetProjection(projection)
+                    out_mask_band = out_mask.GetRasterBand(1)
+                    out_mask_band.WriteArray(
+                        np.rot90(src_mask_band, rot_k), 0, 0)

                 out_scene = None
                 out_mask = None
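
The trailing `out_scene = None` / `out_mask = None` assignments are the usual GDAL Python idiom: a dataset is flushed and closed when its last reference is dropped. A self-contained illustration (the output path is a placeholder):

import numpy as np
from osgeo import gdal

ds = gdal.GetDriverByName('GTiff').Create('/tmp/demo.tif', 8, 8, 1,
                                          gdal.GDT_Byte)
ds.GetRasterBand(1).WriteArray(np.zeros((8, 8), dtype=np.uint8))
ds = None  # dropping the reference flushes caches and closes the file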
27 changes: 18 additions & 9 deletions src/detect.py
@@ -17,7 +17,7 @@

 def main(data_dir, model, in_weights_path, visualization_path, batch_size,
          seed, tensor_shape, force_dataset_generation, fit_memory, val_set_pct,
-         filter_by_class, backbone=None):
+         filter_by_class, backbone=None, ignore_masks=False):
     utils.print_device_info()

     # get nr of bands
@@ -35,18 +35,20 @@ def main(data_dir, model, in_weights_path, visualization_path, batch_size,
     # val generator used for both the training and the detection
     val_generator = AugmentGenerator(
         data_dir, batch_size, 'val', tensor_shape, force_dataset_generation,
-        fit_memory, val_set_pct=val_set_pct, filter_by_class=filter_by_class)
+        fit_memory, val_set_pct=val_set_pct, filter_by_class=filter_by_class,
+        ignore_masks=ignore_masks)

     # load weights if the model is supposed to do so
     model.load_weights(in_weights_path)
     model.set_weights(utils.model_replace_nans(model.get_weights()))

     detect(model, val_generator, id2code, [i for i in label_codes],
-           label_names, data_dir, seed, visualization_path)
+           label_names, data_dir, seed, visualization_path,
+           ignore_masks=ignore_masks)


 def detect(model, val_generator, id2code, label_codes, label_names,
-           data_dir, seed=1, out_dir='/tmp'):
+           data_dir, seed=1, out_dir='/tmp', ignore_masks=False):
     """Run detection.

     :param model: model to be used for the detection
@@ -57,14 +59,16 @@ def detect(model, val_generator, id2code, label_codes, label_names,
     :param data_dir: path to the directory containing images and labels
     :param seed: the generator seed
     :param out_dir: directory where the output visualizations will be saved
+    :param ignore_masks: if True, run only the prediction; if False, also
+        compute average statistics based on the ground truth masks
     """
     testing_gen = val_generator(id2code, seed)

     if not os.path.exists(out_dir):
         os.makedirs(out_dir)

     # get information needed to write referenced geotifs of detections
-    geoinfos = get_geoinfo(val_generator.masks_dir)
+    geoinfos = get_geoinfo(val_generator.images_dir)

     batch_size = val_generator.batch_size

@@ -78,7 +82,7 @@ def detect(model, val_generator, id2code, label_codes, label_names,
         # visualize the batch
         visualize_detections(batch_img, batch_mask, pred_all, id2code,
                              label_codes, label_names, batch_geoinfos,
-                             out_dir)
+                             out_dir, ignore_masks=ignore_masks)


 def get_geoinfo(data_dir):
@@ -111,7 +115,7 @@ def get_geoinfo(data_dir):
         help='Path to the directory containing images and labels')
     parser.add_argument(
         '--model', type=str, default='U-Net',
-        choices=('U-Net', 'SegNet', 'DeepLab'),
+        choices=('U-Net', 'SegNet', 'DeepLab', 'FCN'),
         help='Model architecture')
     parser.add_argument(
         '--weights_path', type=str, default=None,
@@ -154,8 +158,12 @@ def get_geoinfo(data_dir):
         'comma-separated (e.g. "1,2,6" to filter by classes 1, 2 and 6)')
     parser.add_argument(
         '--backbone', type=str, default=None,
-        choices=('ResNet50', 'ResNet101', 'ResNet152'),
+        choices=('ResNet50', 'ResNet101', 'ResNet152', 'VGG16'),
         help='Backbone architecture')
+    parser.add_argument(
+        '--ignore_masks', type=utils.str2bool, default=False,
+        help='Boolean deciding whether to also compute average statistics '
+             'based on ground truth data or to run only the prediction')

     args = parser.parse_args()

@@ -171,4 +179,5 @@ def get_geoinfo(data_dir):
     main(args.data_dir, args.model, args.weights_path, args.visualization_path,
          args.batch_size, args.seed, (args.tensor_height, args.tensor_width),
          args.force_dataset_generation, args.fit_dataset_in_memory,
-         args.validation_set_percentage, args.filter_by_classes, args.backbone)
+         args.validation_set_percentage, args.filter_by_classes,
+         args.backbone, args.ignore_masks)
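
A hypothetical invocation of the updated entry point (paths are placeholders; the FCN/VGG16 pairing follows the choices added above):

from detect import main

main(data_dir='/data/scenes', model='FCN', in_weights_path='/tmp/model.h5',
     visualization_path='/tmp/vis', batch_size=5, seed=1,
     tensor_shape=(256, 256), force_dataset_generation=False,
     fit_memory=False, val_set_pct=0.2, filter_by_class=None,
     backbone='VGG16', ignore_masks=True)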
@@ -0,0 +1,50 @@
Model: "fcn_drop0.5_VGG16_categorical_crossentropy"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input (InputLayer) [(None, 192, 192, 1 0 []
2)]

vgg16 (VGG) [(None, 24, 24, 256 31259520 ['input[0][0]']
),
(None, 12, 12, 512
),
(None, 6, 6, 1024)
]

block5_conv1 (ConvBlock) (None, 6, 6, 4096) 151015424 ['vgg16[0][2]']

block5_conv2 (ConvBlock) (None, 6, 6, 4096) 16797696 ['block5_conv1[0][0]']

block5_class (Conv2D) (None, 6, 6, 3) 12291 ['block5_conv2[0][0]']

upsampling_5_to_4 (UpSampling2 (None, 12, 12, 3) 0 ['block5_class[0][0]']
D)

block4_class (Conv2D) (None, 12, 12, 3) 1539 ['vgg16[0][1]']

concat_5_to_4 (Concatenate) (None, 12, 12, 6) 0 ['upsampling_5_to_4[0][0]',
'block4_class[0][0]']

upsampling_4_to_3 (UpSampling2 (None, 24, 24, 6) 0 ['concat_5_to_4[0][0]']
D)

block3_class (Conv2D) (None, 24, 24, 3) 771 ['vgg16[0][0]']

concat_4_to_3 (Concatenate) (None, 24, 24, 9) 0 ['upsampling_4_to_3[0][0]',
'block3_class[0][0]']

upsampling_final (UpSampling2D (None, 192, 192, 9) 0 ['concat_4_to_3[0][0]']
)

classifier_layer (Conv2D) (None, 192, 192, 3) 30 ['upsampling_final[0][0]']

==================================================================================================
Total params: 199,087,271
Trainable params: 199,059,367
Non-trainable params: 27,904
__________________________________________________________________________________________________

Epoch 00001: val_loss improved from inf to 1.23222, saving model to /tmp/output_fcn_drop0.5_VGG16_categorical_crossentropy/model.h5

Epoch 00002: val_loss improved from 1.23222 to 0.65152, saving model to /tmp/output_fcn_drop0.5_VGG16_categorical_crossentropy/model.h5
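
The summary shows the classic FCN skip pattern: each backbone stage gets a 1x1 class convolution, coarser predictions are upsampled and concatenated with finer ones, and a final upsampling restores the input resolution. A minimal Keras sketch of that wiring (the strided convolutions stand in for the VGG16 stages; names and channel counts are illustrative, not the repository's architectures.py code, which is not rendered above):

import tensorflow as tf
from tensorflow.keras import layers

inputs = layers.Input((192, 192, 12), name='input')

# strided convolutions as stand-ins for the three VGG16 stages
pool3 = layers.Conv2D(64, 3, strides=8, padding='same', name='pool3')(inputs)
pool4 = layers.Conv2D(64, 3, strides=2, padding='same', name='pool4')(pool3)
pool5 = layers.Conv2D(64, 3, strides=2, padding='same', name='pool5')(pool4)

c5 = layers.Conv2D(3, 1, name='block5_class')(pool5)          # (6, 6, 3)
up5 = layers.UpSampling2D(2, name='upsampling_5_to_4')(c5)    # (12, 12, 3)
c4 = layers.Conv2D(3, 1, name='block4_class')(pool4)          # (12, 12, 3)
cat4 = layers.Concatenate(name='concat_5_to_4')([up5, c4])    # (12, 12, 6)
up4 = layers.UpSampling2D(2, name='upsampling_4_to_3')(cat4)  # (24, 24, 6)
c3 = layers.Conv2D(3, 1, name='block3_class')(pool3)          # (24, 24, 3)
cat3 = layers.Concatenate(name='concat_4_to_3')([up4, c3])    # (24, 24, 9)
up = layers.UpSampling2D(8, name='upsampling_final')(cat3)    # (192, 192, 9)
outputs = layers.Conv2D(3, 1, activation='softmax',
                        name='classifier_layer')(up)

model = tf.keras.Model(inputs, outputs)
model.summary()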