se_resnet_saul.py
'''
Based on https://github.com/titu1994/keras-squeeze-excite-network/blob/master/se_resnet.py

Squeeze-and-Excitation ResNets

References:
 - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
 - [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
'''
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Activation
from keras.layers import BatchNormalization
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.layers import Conv2D
from keras.layers import add
from keras.layers import concatenate
from keras.layers import multiply
from keras.regularizers import l2
from keras.utils import conv_utils
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.applications.resnet50 import preprocess_input
from keras.applications.imagenet_utils import decode_predictions
from keras import backend as K

from se import squeeze_excite_block

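# For reference, `squeeze_excite_block` (imported from the local se module) is
# expected to implement the channel-wise recalibration of the SE paper. A minimal
# sketch of such a block, assuming channels_last data and the old Keras
# `_keras_shape` attribute used elsewhere in this file, would look roughly like:
#
#   def squeeze_excite_block(input, ratio=16):
#       filters = input._keras_shape[-1]
#       se = GlobalAveragePooling2D()(input)               # squeeze: (B, C)
#       se = Reshape((1, 1, filters))(se)
#       se = Dense(filters // ratio, activation='relu',
#                  kernel_initializer='he_normal', use_bias=False)(se)
#       se = Dense(filters, activation='sigmoid',
#                  kernel_initializer='he_normal', use_bias=False)(se)
#       return multiply([input, se])                       # excite: rescale channels
#
# The actual implementation lives in se.py and may differ in details.
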
__all__ = ['SEResNet', 'SEResNet50', 'SEResNet101', 'SEResNet154', 'preprocess_input', 'decode_predictions']


WEIGHTS_PATH = ""
WEIGHTS_PATH_NO_TOP = ""

def SEResNet(input_shape=None,
             initial_conv_filters=64,
             depth=[3, 4, 6, 3],
             filters=[64, 128, 256, 512],
             width=1,
             bottleneck=False,
             weight_decay=1e-4,
             include_top=True,
             weights=None,
             input_tensor=None,
             pooling=None,
             classes=1000):
    """ Instantiate the Squeeze and Excite ResNet architecture. Note that,
        when using TensorFlow, for best performance you should set
        `image_data_format="channels_last"` in your Keras config
        at ~/.keras/keras.json.
        The model is compatible with both
        TensorFlow and Theano. The dimension ordering
        convention used by the model is the one
        specified in your Keras config file.
        # Arguments
            initial_conv_filters: number of features for the initial convolution
            depth: number of layers in each block, defined as a list.
                ResNet-50  = [3, 4, 6, 3]
                ResNet-101 = [3, 6, 23, 3]
                ResNet-152 = [3, 8, 36, 3]
            filters: number of filters per block, defined as a list.
                filters = [64, 128, 256, 512]
            width: width multiplier for the network (for Wide ResNets)
            bottleneck: adds a bottleneck conv to reduce computation
            weight_decay: weight decay (l2 norm)
            include_top: whether to include the fully-connected
                layer at the top of the network.
            weights: `None` (random initialization) or `imagenet` (trained
                on ImageNet)
            input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
                to use as image input for the model.
            input_shape: optional shape tuple, only to be specified
                if `include_top` is False (otherwise the input shape
                has to be `(224, 224, 3)` (with `tf` dim ordering)
                or `(3, 224, 224)` (with `th` dim ordering)).
                It should have exactly 3 input channels,
                and width and height should be no smaller than 8.
                E.g. `(200, 200, 3)` would be one valid value.
            pooling: Optional pooling mode for feature extraction
                when `include_top` is `False`.
                - `None` means that the output of the model will be
                    the 4D tensor output of the
                    last convolutional layer.
                - `avg` means that global average pooling
                    will be applied to the output of the
                    last convolutional layer, and thus
                    the output of the model will be a 2D tensor.
                - `max` means that global max pooling will
                    be applied.
            classes: optional number of classes to classify images
                into, only to be specified if `include_top` is True, and
                if no `weights` argument is specified.
        # Returns
            A Keras model instance.
    """

    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as imagenet with `include_top`'
                         ' as true, `classes` should be 1000')

    assert len(depth) == len(filters), "The length of the filters list must match the length " \
                                       "of the depth list."

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=K.image_data_format(),
                                      require_flatten=False)

    if input_tensor is None:

        img_input1 = Input(shape=input_shape, name='view0')
        img_input2 = Input(shape=input_shape, name='view1')
        img_input3 = Input(shape=input_shape, name='view2')

        img_input = [img_input1, img_input2, img_input3]

    x = _create_se_resnet(classes, img_input, include_top, initial_conv_filters,
                          filters, depth, width, bottleneck, weight_decay, pooling)

    inputs = img_input

    # Create model.
    model = Model(inputs=inputs, outputs=x, name='se_resnet')

    # load weights

    return model

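# Example (sketch): the constructor always creates three image inputs named
# 'view0', 'view1' and 'view2', so the resulting model expects a list of three
# equally shaped arrays. The shape below is an illustrative assumption, not a
# requirement of this file.
#
#   model = SEResNet(input_shape=(224, 224, 3), depth=[3, 4, 6, 3],
#                    filters=[64, 128, 256, 512], bottleneck=True, classes=1000)
#   model.summary()
#   # preds = model.predict([view0_batch, view1_batch, view2_batch])
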

def SEResNet18(input_shape=None,
               width=1,
               bottleneck=False,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    return SEResNet(input_shape,
                    depth=[2, 2, 2, 2],
                    width=width,
                    bottleneck=bottleneck,
                    weight_decay=weight_decay,
                    include_top=include_top,
                    weights=weights,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classes=classes)


def SEResNet34(input_shape=None,
               width=1,
               bottleneck=False,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    return SEResNet(input_shape,
                    depth=[3, 4, 6, 3],
                    width=width,
                    bottleneck=bottleneck,
                    weight_decay=weight_decay,
                    include_top=include_top,
                    weights=weights,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classes=classes)


def SEResNet50(input_shape=None,
               width=1,
               bottleneck=True,
               weight_decay=1e-4,
               include_top=True,
               weights=None,
               input_tensor=None,
               pooling=None,
               classes=1000):
    return SEResNet(input_shape,
                    width=width,
                    bottleneck=bottleneck,
                    weight_decay=weight_decay,
                    include_top=include_top,
                    weights=weights,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classes=classes)


def SEResNet101(input_shape=None,
                width=1,
                bottleneck=True,
                weight_decay=1e-4,
                include_top=True,
                weights=None,
                input_tensor=None,
                pooling=None,
                classes=1000):
    return SEResNet(input_shape,
                    depth=[3, 6, 23, 3],
                    width=width,
                    bottleneck=bottleneck,
                    weight_decay=weight_decay,
                    include_top=include_top,
                    weights=weights,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classes=classes)


def SEResNet154(input_shape=None,
                width=1,
                bottleneck=True,
                weight_decay=1e-4,
                include_top=True,
                weights=None,
                input_tensor=None,
                pooling=None,
                classes=1000):
    return SEResNet(input_shape,
                    depth=[3, 8, 36, 3],
                    width=width,
                    bottleneck=bottleneck,
                    weight_decay=weight_decay,
                    include_top=include_top,
                    weights=weights,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classes=classes)

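# Example (sketch): the named constructors differ only in the depth list and the
# bottleneck flag (SEResNet18/34 use basic blocks, SEResNet50/101/154 use
# bottleneck blocks). A feature extractor without the classifier head could be
# built as:
#
#   extractor = SEResNet50(input_shape=(224, 224, 3), include_top=False, pooling='avg')
#   # features = extractor.predict([view0_batch, view1_batch, view2_batch])
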

def _resnet_block(input, filters, k=1, strides=(1, 1)):
    ''' Adds a pre-activation resnet block without bottleneck layers

    Args:
        input: input tensor
        filters: number of output filters
        k: width factor
        strides: strides of the convolution layer

    Returns: a keras tensor
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    x = BatchNormalization(axis=channel_axis)(input)
    x = Activation('relu')(x)

    if strides != (1, 1) or init._keras_shape[channel_axis] != filters * k:
        init = Conv2D(filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
                      use_bias=False, strides=strides)(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False, strides=strides)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)

    # squeeze and excite block
    x = squeeze_excite_block(x)

    m = add([x, init])
    return m

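# Shape sketch for the basic block above (channels_last, strides=(2, 2)):
#
#   input     (B, H,   W,   C_in)
#   conv 3x3  (B, H/2, W/2, filters*k)   # strided
#   conv 3x3  (B, H/2, W/2, filters*k)
#   SE        (B, H/2, W/2, filters*k)   # channel-wise rescaling
#   add with the 1x1-projected shortcut -> same shape as the SE output
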

def _resnet_bottleneck_block(input, filters, k=1, strides=(1, 1)):
    ''' Adds a pre-activation resnet block with bottleneck layers

    Args:
        input: input tensor
        filters: number of output filters
        k: width factor
        strides: strides of the convolution layer

    Returns: a keras tensor
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    bottleneck_expand = 4

    x = BatchNormalization(axis=channel_axis)(input)
    x = Activation('relu')(x)

    if strides != (1, 1) or init._keras_shape[channel_axis] != bottleneck_expand * filters * k:
        init = Conv2D(bottleneck_expand * filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
                      use_bias=False, strides=strides)(x)

    x = Conv2D(filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters * k, (3, 3), padding='same', kernel_initializer='he_normal',
               use_bias=False, strides=strides)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    x = Conv2D(bottleneck_expand * filters * k, (1, 1), padding='same', kernel_initializer='he_normal',
               use_bias=False)(x)

    # squeeze and excite block
    x = squeeze_excite_block(x)

    m = add([x, init])
    return m

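# The bottleneck block reduces to `filters * k` channels with a 1x1 conv, runs the
# 3x3 conv at that reduced width, then expands back by `bottleneck_expand` (4x),
# so the block always outputs 4 * filters * k channels; e.g. filters=64, k=1 gives
# a 64-channel residual branch and a 256-channel block output.
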

def _create_se_resnet(classes, img_input, include_top, initial_conv_filters, filters,
                      depth, width, bottleneck, weight_decay, pooling):
    '''Creates a SE ResNet model with specified parameters
    Args:
        initial_conv_filters: number of features for the initial convolution
        include_top: Flag to include the last dense layer
        filters: number of filters per block, defined as a list.
            filters = [64, 128, 256, 512]
        depth: number of layers in each block, defined as a list.
            ResNet-50  = [3, 4, 6, 3]
            ResNet-101 = [3, 6, 23, 3]
            ResNet-152 = [3, 8, 36, 3]
        width: width multiplier for network (for Wide ResNet)
        bottleneck: adds a bottleneck conv to reduce computation
        weight_decay: weight decay (l2 norm)
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
    Returns: a Keras tensor (the network output)
    '''

    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    N = list(depth)

    # branches
    branches = []

    for i in range(len(img_input)):

        # block 1 (initial conv block)
        branch = Conv2D(initial_conv_filters, (7, 7), padding='same', use_bias=False, strides=(2, 2),
                        kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input[i])

        branch = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(branch)

        # block 2 (projection block); use a separate loop variable so the
        # outer per-view index `i` is not shadowed
        for j in range(N[0]):
            if bottleneck:
                branch = _resnet_bottleneck_block(branch, filters[0], width)
            else:
                branch = _resnet_block(branch, filters[0], width)

        branches.append(branch)

    x = concatenate(branches)

    '''
    # block 1 (initial conv block)
    x = Conv2D(initial_conv_filters, (7, 7), padding='same', use_bias=False, strides=(2, 2),
               kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(img_input)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # block 2 (projection block)
    for i in range(N[0]):
        if bottleneck:
            x = _resnet_bottleneck_block(x, filters[0], width)
        else:
            x = _resnet_block(x, filters[0], width)
    '''

    # block 3 - N
    for k in range(1, len(N)):
        if bottleneck:
            x = _resnet_bottleneck_block(x, filters[k], width, strides=(2, 2))
        else:
            x = _resnet_block(x, filters[k], width, strides=(2, 2))

        for i in range(N[k] - 1):
            if bottleneck:
                x = _resnet_bottleneck_block(x, filters[k], width)
            else:
                x = _resnet_block(x, filters[k], width)

    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)

    if include_top:
        x = GlobalAveragePooling2D()(x)
        x = Dense(classes, use_bias=False, kernel_regularizer=l2(weight_decay),
                  activation='sigmoid', name='output')(x)
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)

    return x
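
# Putting it together: each of the three views gets its own stem (7x7/2 conv and
# 3x3/2 max-pool) and its own stage-1 residual blocks, the branches are then
# concatenated along the channel axis, and the remaining stages plus the final
# BN/ReLU and classifier are shared. A rough sketch with 224x224 inputs and the
# default filters=[64, 128, 256, 512]:
#
#   view0/1/2: (B, 224, 224, 3) -> stem + stage 1 -> (B, 56, 56, c) each
#   concatenate                 -> (B, 56, 56, 3*c)
#   stages 2-4 (stride 2 each)  -> (B, 7, 7, c_out)
#   GlobalAveragePooling2D + Dense(classes, sigmoid) when include_top=True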