forked from GeorgeSeif/Semantic-Segmentation-Suite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPSPNet.py
108 lines (86 loc) · 4.42 KB
/
PSPNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as np
from builders import frontend_builder
import os, sys
def Upsampling(inputs, feature_map_shape):
    """Bilinearly resize `inputs` to the spatial size given by `feature_map_shape`."""
    resized = tf.image.resize_bilinear(inputs, size=feature_map_shape)
    return resized
def ConvUpscaleBlock(inputs, n_filters, kernel_size=[3, 3], scale=2):
    """
    Basic conv transpose block for Encoder-Decoder upsampling.

    Applies successively BatchNormalization, ReLU nonlinearity, then a
    Transposed Convolution that upsamples spatially by `scale`.

    Arguments:
      inputs: Input feature tensor
      n_filters: Number of output filters
      kernel_size: Kernel size of the transposed convolution
      scale: Spatial upsampling factor (used as the stride)

    Returns:
      The upsampled feature tensor (no activation on the output)
    """
    net = tf.nn.relu(slim.batch_norm(inputs, fused=True))
    # Bug fix: the original hard-coded kernel_size=[3, 3] here, silently
    # ignoring the kernel_size argument. Honoring it is backward-compatible
    # because [3, 3] remains the default.
    net = slim.conv2d_transpose(net, n_filters, kernel_size=kernel_size, stride=[scale, scale], activation_fn=None)
    return net
def ConvBlock(inputs, n_filters, kernel_size=[3, 3]):
    """
    Basic conv block for Encoder-Decoder.

    Applies successively BatchNormalization, ReLU nonlinearity, then a
    Convolution (pre-activation ordering; the conv output itself is linear).
    """
    activated = tf.nn.relu(slim.batch_norm(inputs, fused=True))
    return slim.conv2d(activated, n_filters, kernel_size, activation_fn=None, normalizer_fn=None)
def InterpBlock(net, level, feature_map_shape, pooling_type):
    """
    One branch of the pyramid pooling module: pool the feature map down to
    roughly (level x level), apply a 1x1 conv + BN + ReLU, and upsample back
    to the full feature-map size.

    Arguments:
      net: Input feature tensor
      level: Pyramid level (1, 2, 3 or 6); larger means a finer pooled grid
      feature_map_shape: Spatial size [H, W] to upsample back to
      pooling_type: Pooling mode passed to slim.pool (e.g. 'MAX' or 'AVG')

    Returns:
      The upsampled branch tensor with 512 channels
    """
    # Compute the kernel and stride sizes according to how large the final feature map will be
    # When the kernel size and strides are equal, then we can compute the final feature map size
    # by simply dividing the current size by the kernel or stride size
    # The final feature map sizes are 1x1, 2x2, 3x3, and 6x6. We round to the closest integer
    kernel_size = [int(np.round(float(feature_map_shape[0]) / float(level))), int(np.round(float(feature_map_shape[1]) / float(level)))]
    stride_size = kernel_size
    # Bug fix: pooling_type was accepted but ignored ('MAX' was hard-coded),
    # so "Average pooling" documented in build_pspnet never took effect.
    net = slim.pool(net, kernel_size, stride=stride_size, pooling_type=pooling_type)
    net = slim.conv2d(net, 512, [1, 1], activation_fn=None)
    net = slim.batch_norm(net, fused=True)
    net = tf.nn.relu(net)
    net = Upsampling(net, feature_map_shape)
    return net
def PyramidPoolingModule(inputs, feature_map_shape, pooling_type):
    """
    Build the Pyramid Pooling Module.

    Pools the input at pyramid levels 1, 2, 3 and 6, then concatenates the
    upsampled branches with the original input along the channel axis.
    """
    branches = [InterpBlock(inputs, lvl, feature_map_shape, pooling_type)
                for lvl in (1, 2, 3, 6)]
    # Concatenate the input first, then branches from coarsest grid (level 6)
    # down to the global branch (level 1), matching the original channel order.
    return tf.concat([inputs] + branches[::-1], axis=-1)
def build_pspnet(inputs, label_size, num_classes, preset_model='PSPNet', frontend="ResNet101", pooling_type="MAX",
                 weight_decay=1e-5, upscaling_method="conv", is_training=True, pretrained_dir="models"):
    """
    Builds the PSPNet model.

    Arguments:
      inputs: The input tensor
      label_size: Size of the final label tensor. We need to know this for proper upscaling
      num_classes: Number of classes
      preset_model: Which model you want to use (currently unused in this body)
      frontend: Which feature-extractor backbone to use, e.g. "ResNet101"
      pooling_type: Max or Average pooling for the pyramid pooling module
      weight_decay: Amount of weight decay (currently unused in this body)
      upscaling_method: "conv" for learned transposed-conv upsampling,
                        "bilinear" for a single fixed bilinear resize
      is_training: Whether we are training (affects the frontend's batch norm)
      pretrained_dir: Directory holding the pretrained frontend weights

    Returns:
      net: PSPNet logits tensor with num_classes channels
      init_fn: Function that initializes the frontend from pretrained weights
    """
    logits, end_points, frontend_scope, init_fn = frontend_builder.build_frontend(inputs, frontend, pretrained_dir=pretrained_dir, is_training=is_training)

    # The 'pool3' endpoint used below is assumed to be at 1/8 of the label
    # resolution, hence dividing by 8 here.
    feature_map_shape = [int(x / 8.0) for x in label_size]
    # Bug fix: removed a leftover debug print(feature_map_shape).

    psp = PyramidPoolingModule(end_points['pool3'], feature_map_shape=feature_map_shape, pooling_type=pooling_type)

    net = slim.conv2d(psp, 512, [3, 3], activation_fn=None)
    net = slim.batch_norm(net, fused=True)
    net = tf.nn.relu(net)

    if upscaling_method.lower() == "conv":
        # Three learned 2x upscalings: 1/8 -> 1/4 -> 1/2 -> full resolution.
        net = ConvUpscaleBlock(net, 256, kernel_size=[3, 3], scale=2)
        net = ConvBlock(net, 256)
        net = ConvUpscaleBlock(net, 128, kernel_size=[3, 3], scale=2)
        net = ConvBlock(net, 128)
        net = ConvUpscaleBlock(net, 64, kernel_size=[3, 3], scale=2)
        net = ConvBlock(net, 64)
    elif upscaling_method.lower() == "bilinear":
        net = Upsampling(net, label_size)

    net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, scope='logits')
    return net, init_fn
def mean_image_subtraction(inputs, means=[123.68, 116.78, 103.94]):
    """
    Subtract per-channel means from an image tensor.

    The defaults are the standard ImageNet RGB means. Raises ValueError when
    the number of means does not match the input's channel count.
    """
    inputs = tf.to_float(inputs)
    num_channels = inputs.get_shape().as_list()[-1]
    if num_channels != len(means):
        raise ValueError('len(means) must match the number of channels')
    # Split into per-channel tensors, shift each by its mean, and re-stack.
    per_channel = tf.split(axis=3, num_or_size_splits=num_channels, value=inputs)
    shifted = [channel - mean for channel, mean in zip(per_channel, means)]
    return tf.concat(axis=3, values=shifted)