se_resnet.py
Go to the documentation of this file.
1 '''
2 Based on https://github.com/titu1994/keras-squeeze-excite-network/blob/master/se_resnet.py
3 
4 Squeeze-and-Excitation ResNets
5 
6 References:
7  - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
8  - [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
9 '''
10 from __future__ import print_function
11 from __future__ import absolute_import
12 from __future__ import division
13 
14 from keras.models import Model
15 from keras.layers import Input
16 from keras.layers import Dense
17 from keras.layers import Reshape
18 from keras.layers import Activation
19 from keras.layers import BatchNormalization
20 from keras.layers import MaxPooling2D
21 from keras.layers import GlobalAveragePooling2D
22 from keras.layers import GlobalMaxPooling2D
23 from keras.layers import Conv2D
24 from keras.layers import add
25 from keras.layers import concatenate
26 from keras.layers import multiply
27 from keras.regularizers import l2
28 from keras.utils import conv_utils
29 from keras.utils.data_utils import get_file
30 from keras.engine.topology import get_source_inputs
31 from keras.applications.imagenet_utils import _obtain_input_shape
32 from keras.applications.resnet50 import preprocess_input
33 from keras.applications.imagenet_utils import decode_predictions
34 from keras import backend as K
35 
36 from se import squeeze_excite_block
37 
# Public API of this module. Lists every SE-ResNet constructor defined in this
# file (including the 18- and 34-layer variants) plus the re-exported Keras
# preprocessing helpers.
__all__ = ['SEResNet', 'SEResNet18', 'SEResNet34', 'SEResNet50', 'SEResNet101', 'SEResNet154',
           'preprocess_input', 'decode_predictions']


# Placeholders for pretrained-weight locations. Both are empty strings and are
# not referenced by the model-building code in this file.
WEIGHTS_PATH = ""
WEIGHTS_PATH_NO_TOP = ""
43 
44 
def SEResNet(input_shape=None,
             initial_conv_filters=64,
             depth=(3, 4, 6, 3),
             filters=(64, 128, 256, 512),
             width=1,
             bottleneck=False,
             weight_decay=1e-4,
             include_top=True,
             weights=None,
             input_tensor=None,
             pooling=None,
             classes=1000):
    """ Instantiate the Squeeze and Excite ResNet architecture.

        The dimension ordering convention used by the model is the one
        returned by `K.image_data_format()`, i.e. the one specified in
        your Keras config file.

        # Arguments
            input_shape: optional shape tuple. It is validated by
                `_obtain_input_shape`, which is called with
                `default_size=224`, `min_size=32` and
                `require_flatten=False`.
                E.g. `(200, 200, 3)` would be one valid value.
            initial_conv_filters: number of features for the initial convolution
            depth: number of layers in each block, defined as a sequence
                (list or tuple).
                ResNet-50  = [3, 4, 6, 3]
                ResNet-101 = [3, 6, 23, 3]
                ResNet-152 = [3, 8, 36, 3]
            filters: number of filters per block, defined as a sequence
                (list or tuple) of the same length as `depth`.
                filters = [64, 128, 256, 512]
            width: width multiplier for the network (for Wide ResNets)
            bottleneck: adds a bottleneck conv to reduce computation
            weight_decay: weight decay (l2 norm)
            include_top: whether to include the fully-connected
                layer at the top of the network.
            weights: `None` (random initialization) or `imagenet`.
                NOTE: this function does not load any weights; passing
                `imagenet` is accepted for interface compatibility but
                only triggers a warning and the returned model is
                randomly initialized.
            input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
                to use as image input for the model.
            pooling: Optional pooling mode for feature extraction
                when `include_top` is `False`.
                - `None` means that the output of the model will be
                    the 4D tensor output of the
                    last convolutional layer.
                - `avg` means that global average pooling
                    will be applied to the output of the
                    last convolutional layer, and thus
                    the output of the model will be a 2D tensor.
                - `max` means that global max pooling will
                    be applied.
            classes: optional number of classes to classify images
                into, only to be specified if `include_top` is True, and
                if no `weights` argument is specified.

        # Returns
            A Keras model instance.

        # Raises
            ValueError: if `weights` is neither `None` nor `imagenet`, or if
                `weights='imagenet'` is combined with `include_top=True` and
                `classes != 1000`.
            AssertionError: if `depth` and `filters` differ in length.
        """

    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as imagenet with `include_top`'
                         ' as true, `classes` should be 1000')

    # Kept as an assert (not a ValueError) so callers that catch
    # AssertionError keep working.
    assert len(depth) == len(filters), "The length of filter increment list must match the length " \
                                       "of the depth list."

    if weights == 'imagenet':
        # No weight file is ever loaded below, so make that explicit instead of
        # silently handing back an untrained model.
        import warnings
        warnings.warn('No pre-trained ImageNet weights are loaded by SEResNet; '
                      'the returned model is randomly initialized.')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=K.image_data_format(),
                                      require_flatten=False)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            # Wrap a raw backend tensor so it can serve as a model input.
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    x = _create_se_resnet(classes, img_input, include_top, initial_conv_filters,
                          filters, depth, width, bottleneck, weight_decay, pooling)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    # NOTE(review): the name 'resnext' looks like a copy-paste leftover; it is
    # kept unchanged because callers may rely on `model.name`.
    model = Model(inputs=inputs, outputs=x, name='resnext')

    # load weights
    # (nothing is loaded here; see the warning above for weights='imagenet')

    return model
149 
150 
def SEResNet18(input_shape=None,
               width=1,
               bottleneck=False,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    """Build an SE-ResNet with two residual blocks in each of four stages.

    Thin wrapper around `SEResNet` that fixes `depth=[2, 2, 2, 2]` and
    forwards every other argument unchanged. Non-bottleneck blocks are used
    by default.

    # Returns
        Whatever `SEResNet` returns for these settings (a Keras model).
    """
    stage_depths = [2, 2, 2, 2]
    model = SEResNet(input_shape,
                     depth=stage_depths,
                     width=width, bottleneck=bottleneck, weight_decay=weight_decay,
                     include_top=include_top, weights=weights,
                     input_tensor=input_tensor, pooling=pooling, classes=classes)
    return model
170 
171 
def SEResNet34(input_shape=None,
               width=1,
               bottleneck=False,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    """Build an SE-ResNet with stage depths 3, 4, 6 and 3.

    Thin wrapper around `SEResNet` that fixes `depth=[3, 4, 6, 3]` and
    forwards every other argument unchanged. Non-bottleneck blocks are used
    by default.

    # Returns
        Whatever `SEResNet` returns for these settings (a Keras model).
    """
    stage_depths = [3, 4, 6, 3]
    model = SEResNet(input_shape,
                     depth=stage_depths,
                     width=width, bottleneck=bottleneck, weight_decay=weight_decay,
                     include_top=include_top, weights=weights,
                     input_tensor=input_tensor, pooling=pooling, classes=classes)
    return model
191 
192 
def SEResNet50(input_shape=None,
               width=1,
               bottleneck=True,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    """Build an SE-ResNet using `SEResNet`'s default depth and filters.

    Thin wrapper around `SEResNet`: no `depth` is passed, so the default
    declared on `SEResNet` applies. Bottleneck blocks are used by default.
    Every argument is forwarded unchanged.

    # Returns
        Whatever `SEResNet` returns for these settings (a Keras model).
    """
    model = SEResNet(input_shape,
                     width=width, bottleneck=bottleneck, weight_decay=weight_decay,
                     include_top=include_top, weights=weights,
                     input_tensor=input_tensor, pooling=pooling, classes=classes)
    return model
211 
212 
def SEResNet101(input_shape=None,
                width=1,
                bottleneck=True,
                weight_decay=1e-4,
                include_top=True,
                weights=None,
                input_tensor=None,
                pooling=None,
                classes=1000):
    """Build an SE-ResNet with stage depths 3, 6, 23 and 3.

    Thin wrapper around `SEResNet` that fixes `depth=[3, 6, 23, 3]` and
    forwards every other argument unchanged. Bottleneck blocks are used by
    default.

    # Returns
        Whatever `SEResNet` returns for these settings (a Keras model).
    """
    stage_depths = [3, 6, 23, 3]
    model = SEResNet(input_shape,
                     depth=stage_depths,
                     width=width, bottleneck=bottleneck, weight_decay=weight_decay,
                     include_top=include_top, weights=weights,
                     input_tensor=input_tensor, pooling=pooling, classes=classes)
    return model
232 
233 
def SEResNet154(input_shape=None,
                width=1,
                bottleneck=True,
                weight_decay=1e-4,
                include_top=True,
                weights=None,
                input_tensor=None,
                pooling=None,
                classes=1000):
    """Build an SE-ResNet with stage depths 3, 8, 36 and 3.

    Thin wrapper around `SEResNet` that fixes `depth=[3, 8, 36, 3]` and
    forwards every other argument unchanged. Bottleneck blocks are used by
    default.

    # Returns
        Whatever `SEResNet` returns for these settings (a Keras model).
    """
    stage_depths = [3, 8, 36, 3]
    model = SEResNet(input_shape,
                     depth=stage_depths,
                     width=width, bottleneck=bottleneck, weight_decay=weight_decay,
                     include_top=include_top, weights=weights,
                     input_tensor=input_tensor, pooling=pooling, classes=classes)
    return model
253 
254 
def _resnet_block(input, filters, k=1, strides=(1, 1)):
    ''' Adds a pre-activation resnet block without bottleneck layers

    The input goes through BatchNorm -> ReLU, then two 3x3 convolutions
    (with BatchNorm -> ReLU in between), then `squeeze_excite_block`, and is
    finally summed with the shortcut branch.

    Args:
        input: input tensor
        filters: number of output filters
        k: width factor; every convolution uses `filters * k` output channels
        strides: strides of the first 3x3 convolution (and of the shortcut
            projection, when one is needed)

    Returns: a keras tensor
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    x = BatchNormalization(axis=channel_axis)(input)
    x = Activation('relu')(x)

    # Use the public backend helper instead of the private `_keras_shape`
    # attribute, which is not guaranteed to exist on every tensor.
    input_channels = K.int_shape(init)[channel_axis]

    # The identity shortcut only works when spatial size and channel count are
    # unchanged; otherwise project the pre-activated tensor with a 1x1 conv.
    if strides != (1, 1) or input_channels != filters * k:
        init = Conv2D(filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
                      use_bias=False, strides=strides)(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False, strides=strides)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)

    # squeeze and excite block
    x = squeeze_excite_block(x)

    m = add([x, init])
    return m
289 
290 
def _resnet_bottleneck_block(input, filters, k=1, strides=(1, 1)):
    ''' Adds a pre-activation resnet block with bottleneck layers

    The input goes through BatchNorm -> ReLU, then a 1x1 convolution, a 3x3
    convolution and a final 1x1 convolution that expands the channels by a
    factor of 4 (with BatchNorm -> ReLU between the convolutions), then
    `squeeze_excite_block`, and is finally summed with the shortcut branch.

    Args:
        input: input tensor
        filters: number of filters of the inner convolutions; the block
            outputs `4 * filters * k` channels
        k: width factor
        strides: strides of the 3x3 convolution (and of the shortcut
            projection, when one is needed)

    Returns: a keras tensor
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    bottleneck_expand = 4

    x = BatchNormalization(axis=channel_axis)(input)
    x = Activation('relu')(x)

    # Use the public backend helper instead of the private `_keras_shape`
    # attribute, which is not guaranteed to exist on every tensor.
    input_channels = K.int_shape(init)[channel_axis]

    # The identity shortcut only works when spatial size and channel count are
    # unchanged; otherwise project the pre-activated tensor with a 1x1 conv.
    if strides != (1, 1) or input_channels != bottleneck_expand * filters * k:
        init = Conv2D(bottleneck_expand * filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
                      use_bias=False, strides=strides)(x)

    x = Conv2D(filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False, strides=strides)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(bottleneck_expand * filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)

    # squeeze and excite block
    x = squeeze_excite_block(x)

    m = add([x, init])
    return m
331 
332 
def _create_se_resnet(classes, img_input, include_top, initial_conv_filters, filters,
                      depth, width, bottleneck, weight_decay, pooling):
    '''Wires up the SE ResNet graph on top of `img_input`

    Args:
        classes: number of units of the final dense layer (used only when
            `include_top` is true)
        img_input: tensor the network is built on
        include_top: Flag to include the last dense layer
        initial_conv_filters: number of features for the initial convolution
        filters: number of filters per stage, indexable by stage number,
            e.g. [64, 128, 256, 512]
        depth: number of residual blocks in each stage, e.g.
            ResNet-50  = [3, 4, 6, 3]
            ResNet-101 = [3, 6, 23, 3]
            ResNet-152 = [3, 8, 36, 3]
        width: width multiplier handed to every residual block
        bottleneck: selects `_resnet_bottleneck_block` over `_resnet_block`
        weight_decay: l2 factor applied to the initial convolution and to the
            final dense layer
        pooling: only consulted when `include_top` is false:
            - `None`: the 4D output of the last activation is returned
            - `avg`: global average pooling is applied
            - `max`: global max pooling is applied

    Returns: the output tensor of the network (not a Model instance)
    '''
    bn_axis = 1 if K.image_data_format() == 'channels_first' else -1
    stage_depths = list(depth)

    # Pick the residual block builder once; both share the same call signature.
    make_block = _resnet_bottleneck_block if bottleneck else _resnet_block

    # Stem: strided 7x7 convolution followed by strided max pooling.
    stem_conv = Conv2D(initial_conv_filters, (7, 7), padding='same', use_bias=False, strides=(2, 2),
                       kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))
    x = stem_conv(img_input)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # First stage: every block keeps the default (1, 1) strides.
    for _ in range(stage_depths[0]):
        x = make_block(x, filters[0], width)

    # Remaining stages: one strided block, then the rest at unit stride.
    for stage, n_blocks in enumerate(stage_depths[1:], 1):
        x = make_block(x, filters[stage], width, strides=(2, 2))
        for _ in range(n_blocks - 1):
            x = make_block(x, filters[stage], width)

    # Final normalisation and activation after the last residual block.
    x = BatchNormalization(axis=bn_axis)(x)
    x = Activation('relu')(x)

    if include_top:
        x = GlobalAveragePooling2D()(x)
        x = Dense(classes, use_bias=False, kernel_regularizer=l2(weight_decay),
                  activation='sigmoid', name='neutrino')(x)
    elif pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)

    return x
def SEResNet(input_shape=None, initial_conv_filters=64, depth=[3, 4, 6, 3], filters=[64, 128, 256, 512], width=1, bottleneck=False, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:56
def SEResNet18(input_shape=None, width=1, bottleneck=False, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:159
Coord add(Coord c1, Coord c2)
Definition: restypedef.cpp:23
def squeeze_excite_block(input, ratio=16)
Definition: se.py:5
def _create_se_resnet(classes, img_input, include_top, initial_conv_filters, filters, depth, width, bottleneck, weight_decay, pooling)
Definition: se_resnet.py:334
def _resnet_block(input, filters, k=1, strides=(1, 1))
Definition: se_resnet.py:255
def SEResNet101(input_shape=None, width=1, bottleneck=True, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:221
def SEResNet34(input_shape=None, width=1, bottleneck=False, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:180
def _resnet_bottleneck_block(input, filters, k=1, strides=(1, 1))
Definition: se_resnet.py:291
def _obtain_input_shape(input_shape, default_size, min_size, data_format, require_flatten, weights=None)
def SEResNet50(input_shape=None, width=1, bottleneck=True, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:201
def SEResNet154(input_shape=None, width=1, bottleneck=True, weight_decay=1e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnet.py:242