se_resnext.py
Go to the documentation of this file.
1 '''ResNeXt models for Keras.
2 # Reference
3 - [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf)
4 '''
5 from __future__ import print_function
6 from __future__ import absolute_import
7 from __future__ import division
8 
9 import warnings
10 
11 from keras.models import Model
12 from keras.layers.core import Dense, Lambda
13 from keras.layers.advanced_activations import LeakyReLU
14 from keras.layers.convolutional import Conv2D
15 from keras.layers.pooling import GlobalAveragePooling2D, GlobalMaxPooling2D, MaxPooling2D
16 from keras.layers import Input
17 from keras.layers.merge import concatenate, add
18 from keras.layers.normalization import BatchNormalization
19 from keras.regularizers import l2
20 from keras.utils.layer_utils import convert_all_kernels_in_model
21 from keras.utils.data_utils import get_file
22 from keras.engine.topology import get_source_inputs
23 from keras.applications.imagenet_utils import _obtain_input_shape
24 import keras.backend as K
25 
26 from se import squeeze_excite_block
27 
# Locations of pretrained weight files (Theano / TensorFlow, with and without
# the classification top). Every entry is currently an empty-string
# placeholder, i.e. no pretrained weights are configured for these models.
CIFAR_TH_WEIGHTS_PATH = CIFAR_TF_WEIGHTS_PATH = ''
CIFAR_TH_WEIGHTS_PATH_NO_TOP = CIFAR_TF_WEIGHTS_PATH_NO_TOP = ''

IMAGENET_TH_WEIGHTS_PATH = IMAGENET_TF_WEIGHTS_PATH = ''
IMAGENET_TH_WEIGHTS_PATH_NO_TOP = IMAGENET_TF_WEIGHTS_PATH_NO_TOP = ''
37 
38 
def SEResNext(input_shape=None,
              depth=29,
              cardinality=8,
              width=64,
              weight_decay=5e-4,
              include_top=True,
              weights=None,
              input_tensor=None,
              pooling=None,
              classes=10):
    """Instantiate the SE-ResNeXt architecture for small (CIFAR-sized) images.

    Note that, when using TensorFlow, for best performance you should set
    `image_data_format="channels_last"` in your Keras config
    at ~/.keras/keras.json.
    The model is compatible with both TensorFlow and Theano. The dimension
    ordering convention used by the model is the one specified in your
    Keras config file.

    # Arguments
        depth: number of layers in the ResNeXt model. Either an integer for
            which `(depth - 2)` is divisible by 9 (the network then has 3
            stages of `(depth - 2) // 9` bottleneck blocks each), or a
            list/tuple giving the number of bottleneck blocks per stage.
        cardinality: the size of the set of transformations
            (number of groups in each grouped convolution).
        width: multiplier to the ResNeXt width (number of filters).
        weight_decay: weight decay (l2 norm).
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: `None` (random initialization) or `cifar10`. No weight file
            is loaded by this function, so `cifar10` only emits a warning
            and the model stays randomly initialized.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(32, 32, 3)` (with `tf` dim ordering)
            or `(3, 32, 32)` (with `th` dim ordering).
            It should have exactly 3 input channels,
            and width and height should be no smaller than 8.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: if `weights` is not `None` or `cifar10`; if `weights` is
            `cifar10` with `include_top=True` and `classes != 10`; or if
            `depth` is an integer with `(depth - 2)` not divisible by 9.
    """

    if weights not in {'cifar10', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `cifar10` '
                         '(pre-training on CIFAR-10).')

    if weights == 'cifar10' and include_top and classes != 10:
        raise ValueError('If using `weights` as CIFAR 10 with `include_top`'
                         ' as true, `classes` should be 10')

    if isinstance(depth, int) and (depth - 2) % 9 != 0:
        raise ValueError('Depth of the network must be such that (depth - 2) '
                         'should be divisible by 9.')

    if weights is not None:
        # Nothing below loads pretrained weights (the weight paths are empty
        # placeholders), so say so instead of silently returning a random model.
        warnings.warn('Pretrained `cifar10` weights are not available for '
                      'SEResNext; the model is randomly initialized.')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=32,
                                      min_size=8,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    elif not K.is_keras_tensor(input_tensor):
        img_input = Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    x = __create_res_next(classes, img_input, include_top, depth, cardinality, width,
                          weight_decay, pooling)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='se-resnext')

    return model
135 
136 
def SEResNextImageNet(input_shape=None,
                      depth=(3, 4, 6, 3),
                      cardinality=32,
                      width=4,
                      weight_decay=5e-4,
                      include_top=True,
                      weights=None,
                      input_tensor=None,
                      pooling=None,
                      classes=1000):
    """Instantiate the SE-ResNeXt architecture for the ImageNet dataset.

    Note that, when using TensorFlow, for best performance you should set
    `image_data_format="channels_last"` in your Keras config
    at ~/.keras/keras.json.
    The model is compatible with both TensorFlow and Theano. The dimension
    ordering convention used by the model is the one specified in your
    Keras config file.

    # Arguments
        depth: number of bottleneck blocks in each stage, given as a
            list/tuple. ResNeXt-50 is `[3, 4, 6, 3]` (the default),
            ResNeXt-101 is `[3, 4, 23, 3]`. An integer is also accepted if
            `(depth - 2)` is divisible by 9; it yields 3 stages of
            `(depth - 2) // 9` blocks each.
        cardinality: the size of the set of transformations
            (number of groups in each grouped convolution).
        width: multiplier to the ResNeXt width (number of filters).
        weight_decay: weight decay (l2 norm).
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: `None` (random initialization) or `imagenet`. No weight file
            is loaded by this function, so `imagenet` only emits a warning
            and the model stays randomly initialized.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)` (with `tf` dim ordering)
            or `(3, 224, 224)` (with `th` dim ordering).
            It should have exactly 3 input channels,
            and width and height should be no smaller than 112.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.

    # Returns
        A Keras model instance.

    # Raises
        ValueError: if `weights` is not `None` or `imagenet`; if `weights` is
            `imagenet` with `include_top=True` and `classes != 1000`; or if
            `depth` is an integer with `(depth - 2)` not divisible by 9.
    """

    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as imagenet with `include_top`'
                         ' as true, `classes` should be 1000')

    if isinstance(depth, int) and (depth - 2) % 9 != 0:
        raise ValueError('Depth of the network must be such that (depth - 2) '
                         'should be divisible by 9.')

    if weights is not None:
        # Nothing below loads pretrained weights (the weight paths are empty
        # placeholders), so say so instead of silently returning a random model.
        warnings.warn('Pretrained `imagenet` weights are not available for '
                      'SEResNextImageNet; the model is randomly initialized.')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=112,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    elif not K.is_keras_tensor(input_tensor):
        img_input = Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    x = __create_res_next_imagenet(classes, img_input, include_top, depth, cardinality, width,
                                   weight_decay, pooling)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    # NOTE(review): named 'resnext' while SEResNext uses 'se-resnext'; kept
    # unchanged so existing code relying on the model name keeps working.
    model = Model(inputs, x, name='resnext')

    return model
234 
235 
def __initial_conv_block(input, weight_decay=5e-4):
    ''' Adds the initial (stem) block used by `__create_res_next`: a 64-filter
    3x3 convolution (stride 1, 'same' padding, no bias), batch normalization
    and a LeakyReLU activation.

    Args:
        input: input tensor
        weight_decay: L2 regularization factor for the convolution kernel
    Returns: a keras tensor with 64 channels and the same spatial size as
        `input`
    '''
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal',
               kernel_regularizer=l2(weight_decay))(input)
    x = BatchNormalization(axis=channel_axis)(x)
    x = LeakyReLU()(x)

    return x
251 
252 
def __initial_conv_block_inception(input, weight_decay=5e-4):
    ''' Adds the ImageNet-style stem block used by `__create_res_next_imagenet`:
    a 64-filter 7x7 convolution with stride 2 (no bias), batch normalization,
    a LeakyReLU activation and a 3x3 max-pooling with stride 2.

    With 'same' padding both strided ops halve the spatial size (rounding up),
    so the output is downsampled by a factor of 4 overall.

    Args:
        input: input tensor
        weight_decay: L2 regularization factor for the convolution kernel
    Returns: a keras tensor with 64 channels
    '''
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal',
               kernel_regularizer=l2(weight_decay), strides=(2, 2))(input)
    x = BatchNormalization(axis=channel_axis)(x)
    x = LeakyReLU()(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    return x
270 
271 
def __grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay=5e-4):
    ''' Adds a grouped convolution block (the aggregated transformation from
    the ResNeXt paper, implemented as `cardinality` parallel 3x3 convolutions).

    The input channels are split into `cardinality` contiguous slices of
    `grouped_channels` channels each. Each slice is convolved independently,
    and the results are concatenated along the channel axis, batch-normalized
    and passed through a LeakyReLU.

    Args:
        input: input tensor
        grouped_channels: number of channels per group (both the size of each
            input slice and the number of filters of each group's convolution)
        cardinality: cardinality factor describing the number of groups
        strides: stride of every 3x3 convolution; downscales if > 1
        weight_decay: L2 regularization factor for the convolution kernels
    Returns: a keras tensor with `grouped_channels * cardinality` channels
        (`grouped_channels` when `cardinality == 1`)
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    if cardinality == 1:
        # with cardinality 1, it is a standard convolution
        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
        x = BatchNormalization(axis=channel_axis)(x)
        x = LeakyReLU()(x)
        return x

    group_list = []
    for c in range(cardinality):
        start = c * grouped_channels
        stop = start + grouped_channels

        # The slice bounds are bound as default arguments: a plain closure over
        # the loop variable is late-binding, so any call made after the loop
        # ends would slice the last group for every Lambda. The data-format
        # branch is chosen here, outside the lambda, so each Lambda always
        # returns a tensor.
        if channel_axis == -1:
            x = Lambda(lambda z, start=start, stop=stop: z[:, :, :, start:stop])(init)
        else:
            x = Lambda(lambda z, start=start, stop=stop: z[:, start:stop, :, :])(init)

        x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
                   kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)

        group_list.append(x)

    group_merge = concatenate(group_list, axis=channel_axis)
    x = BatchNormalization(axis=channel_axis)(group_merge)
    x = LeakyReLU()(x)

    return x
310 
311 
def __bottleneck_block(input, filters=64, cardinality=8, strides=1, weight_decay=5e-4):
    ''' Adds an SE-ResNeXt bottleneck block.

    Residual branch: 1x1 conv to `filters` channels -> BN -> LeakyReLU ->
    grouped 3x3 conv (with `strides`) -> 1x1 conv to `2 * filters` channels ->
    BN -> squeeze-and-excite. The branch is added to the shortcut and passed
    through a LeakyReLU.

    Args:
        input: input tensor
        filters: bottleneck width; the block outputs `2 * filters` channels
        cardinality: cardinality factor describing the number of
            grouped convolutions
        strides: stride of the grouped convolution (and of the shortcut
            projection); downsamples if > 1
        weight_decay: L2 regularization factor for the convolution kernels
    Returns: a keras tensor with `2 * filters` channels
    '''
    init = input

    grouped_channels = filters // cardinality
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    # Project the shortcut with a (strided) 1x1 convolution whenever it cannot
    # be added to the residual branch as-is: the channel count differs from
    # the block's `2 * filters` outputs, or the residual branch is downsampled.
    if init._keras_shape[channel_axis] != 2 * filters or strides != 1:
        init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides),
                      use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
        init = BatchNormalization(axis=channel_axis)(init)

    x = Conv2D(filters, (1, 1), padding='same', use_bias=False,
               kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(input)
    x = BatchNormalization(axis=channel_axis)(x)
    x = LeakyReLU()(x)

    x = __grouped_convolution_block(x, grouped_channels, cardinality, strides, weight_decay)

    x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal',
               kernel_regularizer=l2(weight_decay))(x)
    x = BatchNormalization(axis=channel_axis)(x)

    # squeeze and excite block
    x = squeeze_excite_block(x)

    x = add([init, x])
    x = LeakyReLU()(x)

    return x
358 
359 
def __create_res_next(nb_classes, img_input, include_top, depth=29, cardinality=8, width=4,
                      weight_decay=5e-4, pooling=None):
    ''' Builds the SE-ResNeXt graph (CIFAR-style stem) on top of `img_input`.

    Args:
        nb_classes: Number of output classes
        img_input: Input tensor or layer
        include_top: Flag to include the last dense layer
        depth: Depth of the network. Either a positive integer or a
            list/tuple with the number of bottleneck blocks per stage.
            For an integer n, each of 3 stages gets N = (n - 2) // 9 blocks.
            For a depth of 56, n = 56, N = (56 - 2) / 9 = 6
            For a depth of 101, n = 101, N = (101 - 2) / 9 = 11
        cardinality: the size of the set of transformations.
            Increasing cardinality improves classification accuracy.
        width: Width of the network. Stage i uses
            `cardinality * width * 2 ** i` bottleneck filters.
        weight_decay: weight_decay (l2 norm)
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
    Returns: the output Keras tensor (the caller wraps it in a Model)
    Raises:
        ValueError: if `depth` is an empty list/tuple.
    '''
    if isinstance(depth, (list, tuple)):
        # If a list is provided, defer to user how many blocks are present
        blocks_per_stage = list(depth)
    else:
        # Otherwise, default to 3 stages with the same number of blocks each
        blocks_per_stage = [(depth - 2) // 9] * 3

    if not blocks_per_stage:
        raise ValueError('`depth` must describe at least one stage.')

    # The number of filters doubles from one stage to the next.
    base_filters = cardinality * width
    filters_list = [base_filters * (2 ** i) for i in range(len(blocks_per_stage))]

    x = __initial_conv_block(img_input, weight_decay)

    for stage_idx, (n_blocks, stage_filters) in enumerate(zip(blocks_per_stage, filters_list)):
        for block_idx in range(n_blocks):
            # Stage 1 keeps the resolution; every later stage downsamples
            # in its first block.
            strides = 2 if stage_idx > 0 and block_idx == 0 else 1
            x = __bottleneck_block(x, stage_filters, cardinality, strides=strides,
                                   weight_decay=weight_decay)

    if include_top:
        x = GlobalAveragePooling2D()(x)
        x = Dense(nb_classes, use_bias=False, kernel_regularizer=l2(weight_decay),
                  kernel_initializer='he_normal', activation='softmax')(x)
    elif pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)

    return x
433 
434 
def __create_res_next_imagenet(nb_classes, img_input, include_top, depth, cardinality=32, width=4,
                               weight_decay=5e-4, pooling=None):
    ''' Builds the SE-ResNeXt graph (ImageNet-style 7x7 stem) on top of `img_input`.

    Args:
        nb_classes: Number of output classes
        img_input: Input tensor or layer
        include_top: Flag to include the last dense layer
        depth: Depth of the network. A list/tuple with the number of
            bottleneck blocks per stage (e.g. [3, 4, 6, 3]). An integer n is
            also accepted and gives 3 stages of (n - 2) // 9 blocks each.
        cardinality: the size of the set of transformations.
            Increasing cardinality improves classification accuracy.
        width: Width of the network. Stage i uses
            `cardinality * width * 2 ** i` bottleneck filters.
        weight_decay: weight_decay (l2 norm)
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
    Returns: the output Keras tensor (the caller wraps it in a Model)
    Raises:
        ValueError: if `depth` is an empty list/tuple.
    '''
    if isinstance(depth, (list, tuple)):
        # If a list is provided, defer to user how many blocks are present
        blocks_per_stage = list(depth)
    else:
        # Otherwise, default to 3 stages with the same number of blocks each
        blocks_per_stage = [(depth - 2) // 9] * 3

    if not blocks_per_stage:
        raise ValueError('`depth` must describe at least one stage.')

    # The number of filters doubles from one stage to the next.
    base_filters = cardinality * width
    filters_list = [base_filters * (2 ** i) for i in range(len(blocks_per_stage))]

    x = __initial_conv_block_inception(img_input, weight_decay)

    for stage_idx, (n_blocks, stage_filters) in enumerate(zip(blocks_per_stage, filters_list)):
        for block_idx in range(n_blocks):
            # Stage 1 keeps the resolution; every later stage downsamples
            # in its first block.
            strides = 2 if stage_idx > 0 and block_idx == 0 else 1
            x = __bottleneck_block(x, stage_filters, cardinality, strides=strides,
                                   weight_decay=weight_decay)

    if include_top:
        x = GlobalAveragePooling2D()(x)
        x = Dense(nb_classes, use_bias=False, kernel_regularizer=l2(weight_decay),
                  kernel_initializer='he_normal', activation='softmax')(x)
    elif pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)

    return x
Coord add(Coord c1, Coord c2)
Definition: restypedef.cpp:23
std::string concatenate(H const &h, T const &...t)
Definition: select.h:138
def SEResNext(input_shape=None, depth=29, cardinality=8, width=64, weight_decay=5e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=10)
Definition: se_resnext.py:48
def __initial_conv_block(input, weight_decay=5e-4)
Definition: se_resnext.py:236
def __grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay=5e-4)
Definition: se_resnext.py:272
def squeeze_excite_block(input, ratio=16)
Definition: se.py:5
def __create_res_next_imagenet(nb_classes, img_input, include_top, depth, cardinality=32, width=4, weight_decay=5e-4, pooling=None)
Definition: se_resnext.py:436
def SEResNextImageNet(input_shape=None, depth=[3, 4, 6, 3], cardinality=32, width=4, weight_decay=5e-4, include_top=True, weights=None, input_tensor=None, pooling=None, classes=1000)
Definition: se_resnext.py:146
def __bottleneck_block(input, filters=64, cardinality=8, strides=1, weight_decay=5e-4)
Definition: se_resnext.py:312
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def __initial_conv_block_inception(input, weight_decay=5e-4)
Definition: se_resnext.py:253
def _obtain_input_shape(input_shape, default_size, min_size, data_format, require_flatten, weights=None)
def __create_res_next(nb_classes, img_input, include_top, depth=29, cardinality=8, width=4, weight_decay=5e-4, pooling=None)
Definition: se_resnext.py:361