-
Notifications
You must be signed in to change notification settings - Fork 0
/
modules.py
46 lines (34 loc) · 1.88 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
'''
Utilities for DeepLab
Lei Mao
Department of Computer Science
University of Chicago
dukeleimao@gmail.com
'''
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def atrous_spatial_pyramid_pooling(inputs, filters=256, regularizer=None):
    '''
    Atrous Spatial Pyramid Pooling (ASPP) Block

    Builds five parallel branches on top of `inputs`, concatenates them along
    axis 3 and fuses the result with a final 1x1 convolution:
      * one 1x1 convolution,
      * three 3x3 atrous convolutions with dilation rates 6, 12 and 18,
      * one image-level branch: mean over axes 1 and 2, 1x1 convolution,
        bilinear resize back to the spatial size of `inputs`.

    Args:
        inputs: feature map tensor. The code reads axes 1 and 2 as the spatial
            size and concatenates on axis 3, so a 4-D channels-last tensor is
            presumably expected -- confirm against the caller.
        filters: number of output channels for every branch and for the fused
            output.
        regularizer: optional kernel regularizer applied to every convolution
            kernel created by this block (None disables regularization).

    Returns:
        The output tensor of the final 1x1 convolution ('aspp_outputs'), with
        `filters` channels on axis 3.
    '''
    # Use the dynamic shape so the resize target is valid even when the
    # spatial dimensions are unknown at graph-construction time.
    input_shape = tf.shape(inputs)
    resize_height = input_shape[1]
    resize_width = input_shape[2]

    # Atrous Spatial Pyramid Pooling
    # Atrous 1x1
    aspp1x1 = tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=(1, 1),
        padding='same',
        kernel_regularizer=regularizer,
        name='aspp1x1')

    # Atrous 3x3 branches, rates = 6, 12, 18 (named aspp3x3_1 .. aspp3x3_3).
    atrous_branches = []
    for index, rate in enumerate((6, 12, 18), start=1):
        branch = tf.layers.conv2d(
            inputs=inputs,
            filters=filters,
            kernel_size=(3, 3),
            padding='same',
            dilation_rate=(rate, rate),
            kernel_regularizer=regularizer,
            name='aspp3x3_{}'.format(index))
        atrous_branches.append(branch)

    # Image Level Pooling
    image_feature = tf.reduce_mean(inputs, [1, 2], keepdims=True)
    # The regularizer is applied here as well so that this kernel is not
    # silently excluded from the regularization losses. The layer is left
    # unnamed on purpose: giving it a name would rename its variables and
    # break loading of existing checkpoints.
    image_feature = tf.layers.conv2d(
        inputs=image_feature,
        filters=filters,
        kernel_size=(1, 1),
        padding='same',
        kernel_regularizer=regularizer)
    image_feature = tf.image.resize_bilinear(
        images=image_feature,
        size=[resize_height, resize_width],
        align_corners=True,
        name='image_pool_feature')

    # Merge Poolings
    outputs = tf.concat(
        values=[aspp1x1] + atrous_branches + [image_feature],
        axis=3,
        name='aspp_pools')
    outputs = tf.layers.conv2d(
        inputs=outputs,
        filters=filters,
        kernel_size=(1, 1),
        padding='same',
        kernel_regularizer=regularizer,
        name='aspp_outputs')

    return outputs