residual_network.py
Go to the documentation of this file.
1 from keras import layers
2 from keras import models
3 
4 
#
# image dimensions
#

# height, width and channel count used as the `shape` of the network's input tensor
img_height = 500
img_width = 500
img_channels = 3

#
# network params
#

# number of groups used by each grouped convolution;
# with a value of 1 a single standard convolution is used instead (plain ResNet)
cardinality = 1
18 
19 
21  """
    ResNeXt when `cardinality` > 1. With `cardinality` = 1 (the value set above) this is a plain ResNet.
23 
24  """
25  def add_common_layers(y):
26  y = layers.BatchNormalization()(y)
27  y = layers.LeakyReLU()(y)
28 
29  return y
30 
31  def grouped_convolution(y, nb_channels, _strides):
32  # when `cardinality` == 1 this is just a standard convolution
33  if cardinality == 1:
34  return layers.Conv2D(nb_channels, kernel_size=(3, 3), strides=_strides, padding='same')(y)
35 
36  assert not nb_channels % cardinality
37  _d = nb_channels // cardinality
38 
39  # in a grouped convolution layer, input and output channels are divided into `cardinality` groups,
40  # and convolutions are separately performed within each group
41  groups = []
42  for j in range(cardinality):
43  group = layers.Lambda(lambda z: z[:, :, :, j * _d:j * _d + _d])(y)
44  groups.append(layers.Conv2D(_d, kernel_size=(3, 3), strides=_strides, padding='same')(group))
45 
46  # the grouped convolutional layer concatenates them as the outputs of the layer
47  y = layers.concatenate(groups)
48 
49  return y
50 
51  def residual_block(y, nb_channels_in, nb_channels_out, _strides=(1, 1), _project_shortcut=False):
52  """
53  Our network consists of a stack of residual blocks. These blocks have the same topology,
54  and are subject to two simple rules:
55  - If producing spatial maps of the same size, the blocks share the same hyper-parameters (width and filter sizes).
56  - Each time the spatial map is down-sampled by a factor of 2, the width of the blocks is multiplied by a factor of 2.
57  """
58  shortcut = y
59 
60  # we modify the residual building block as a bottleneck design to make the network more economical
61  y = layers.Conv2D(nb_channels_in, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
62  y = add_common_layers(y)
63 
64  # ResNeXt (identical to ResNet when `cardinality` == 1)
65  y = grouped_convolution(y, nb_channels_in, _strides=_strides)
66  y = add_common_layers(y)
67 
68  y = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
69  # batch normalization is employed after aggregating the transformations and before adding to the shortcut
70  y = layers.BatchNormalization()(y)
71 
72  # identity shortcuts used directly when the input and output are of the same dimensions
73  if _project_shortcut or _strides != (1, 1):
74  # when the dimensions increase projection shortcut is used to match dimensions (done by 1x1 convolutions)
75  # when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2
76  shortcut = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=_strides, padding='same')(shortcut)
77  shortcut = layers.BatchNormalization()(shortcut)
78 
79  y = layers.add([shortcut, y])
80 
81  # relu is performed right after each batch normalization,
82  # expect for the output of the block where relu is performed after the adding to the shortcut
83  y = layers.LeakyReLU()(y)
84 
85  return y
86 
87  # conv1
88  x = layers.Conv2D(64, kernel_size=(7, 7), strides=(2, 2), padding='same')(x)
89  x = add_common_layers(x)
90 
91  # conv2
92  x = layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
93  for i in range(3):
94  project_shortcut = True if i == 0 else False
95  x = residual_block(x, 128, 256, _project_shortcut=project_shortcut)
96 
97  # conv3
98  for i in range(4):
99  # down-sampling is performed by conv3_1, conv4_1, and conv5_1 with a stride of 2
100  strides = (2, 2) if i == 0 else (1, 1)
101  x = residual_block(x, 256, 512, _strides=strides)
102 
103  # conv4
104  for i in range(6):
105  strides = (2, 2) if i == 0 else (1, 1)
106  x = residual_block(x, 512, 1024, _strides=strides)
107 
108  # conv5
109  for i in range(3):
110  strides = (2, 2) if i == 0 else (1, 1)
111  x = residual_block(x, 1024, 2048, _strides=strides)
112 
113  x = layers.GlobalAveragePooling2D()(x)
114  x = layers.Dense(13)(x)
115 
116  return x
117 
119  image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
120  network_output = residual_network(image_tensor)
121 
122  model = models.Model(inputs=[image_tensor], outputs=[network_output])
123  return model