resnet.py
Go to the documentation of this file.
'''
Based on https://github.com/raghakot/keras-resnet/blob/master/resnet.py
'''
4 
5 from __future__ import division
6 
7 import six
8 from keras.models import Model
9 from keras.layers import (
10  Input,
11  Activation,
12  Dense,
13  Flatten,
14  Add,
15  Subtract,
16  Multiply,
17  Average,
18  Maximum,
19  Concatenate,
20  Dot
21 )
22 from keras.layers.convolutional import (
23  Conv2D,
24  MaxPooling2D,
25  AveragePooling2D
26 )
27 from keras.layers.merge import add
28 from keras.layers.normalization import BatchNormalization
29 from keras.regularizers import l2
30 from keras import backend as K
31 
32 
def _bn_relu(input):
    """Apply batch normalization followed by a ReLU activation."""
    return Activation("relu")(BatchNormalization(axis=CHANNEL_AXIS)(input))
38 
39 
def _conv_bn_relu(**conv_params):
    """Build a conv -> BN -> relu block.

    Required keys: ``filters`` and ``kernel_size``. Optional keys
    ``strides``, ``kernel_initializer``, ``padding`` and
    ``kernel_regularizer`` fall back to sensible defaults.
    Returns a function that maps a tensor to the block's output tensor.
    """
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    strides = conv_params.get("strides", (1, 1))
    kernel_initializer = conv_params.get("kernel_initializer", "he_normal")
    padding = conv_params.get("padding", "same")
    kernel_regularizer = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(input):
        convolved = Conv2D(filters=filters,
                           kernel_size=kernel_size,
                           strides=strides,
                           padding=padding,
                           kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer)(input)
        return _bn_relu(convolved)

    return f
58 
59 
def _bn_relu_conv(**conv_params):
    """Build a BN -> relu -> conv block.

    This is the improved (pre-activation) scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    Required keys: ``filters`` and ``kernel_size``; the remaining
    convolution options default as in :func:`_conv_bn_relu`.
    Returns a function that maps a tensor to the block's output tensor.
    """
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    strides = conv_params.get("strides", (1, 1))
    kernel_initializer = conv_params.get("kernel_initializer", "he_normal")
    padding = conv_params.get("padding", "same")
    kernel_regularizer = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(input):
        preactivated = _bn_relu(input)
        return Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding=padding,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=kernel_regularizer)(preactivated)

    return f
79 
80 
def _shortcut(input, residual):
    """Add a shortcut between ``input`` and ``residual`` and merge with "sum".

    When the spatial size or channel count differ, the shortcut is
    projected with a strided 1x1 convolution first.
    """
    in_shape = K.int_shape(input)
    res_shape = K.int_shape(residual)
    # Strides needed so the shortcut's spatial dims match the residual's.
    # These should be exact integers if the architecture is configured
    # correctly; round() guards against float division artifacts.
    stride_w = int(round(in_shape[ROW_AXIS] / res_shape[ROW_AXIS]))
    stride_h = int(round(in_shape[COL_AXIS] / res_shape[COL_AXIS]))
    same_channels = in_shape[CHANNEL_AXIS] == res_shape[CHANNEL_AXIS]

    shortcut = input
    # 1x1 projection if shapes differ; identity otherwise.
    if stride_w > 1 or stride_h > 1 or not same_channels:
        shortcut = Conv2D(filters=res_shape[CHANNEL_AXIS],
                          kernel_size=(1, 1),
                          strides=(stride_w, stride_h),
                          padding="valid",
                          kernel_initializer="he_normal",
                          kernel_regularizer=l2(0.0001))(input)

    return add([shortcut, residual])
104 
105 
def _residual_block(block_function, filters, repetitions, is_first_layer=False):
    """Build a stage of ``repetitions`` repeated residual blocks.

    The first block of every stage except the first downsamples with
    stride (2, 2).
    """
    def f(input):
        x = input
        for index in range(repetitions):
            downsample = (index == 0 and not is_first_layer)
            x = block_function(
                filters=filters,
                init_strides=(2, 2) if downsample else (1, 1),
                is_first_block_of_first_layer=(is_first_layer and index == 0)
            )(x)
        return x

    return f
119 
120 
def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """Basic 3x3 convolution block for ResNets with layers <= 34.

    Follows the improved scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf
    """
    def f(input):
        if is_first_block_of_first_layer:
            # The stem just did bn -> relu -> maxpool, so start with a plain
            # convolution instead of repeating bn -> relu.
            first = Conv2D(filters=filters, kernel_size=(3, 3),
                           strides=init_strides,
                           padding="same",
                           kernel_initializer="he_normal",
                           kernel_regularizer=l2(1e-4))(input)
        else:
            first = _bn_relu_conv(filters=filters, kernel_size=(3, 3),
                                  strides=init_strides)(input)

        residual = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(first)
        return _shortcut(input, residual)

    return f
142 
143 
def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """Bottleneck block (1x1 -> 3x3 -> 1x1) for ResNets with > 34 layers.

    Follows the improved scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    Returns:
        A function whose final conv layer has ``filters * 4`` filters.
    """
    def f(input):
        if is_first_block_of_first_layer:
            # The stem just did bn -> relu -> maxpool, so start with a plain
            # convolution instead of repeating bn -> relu.
            reduce = Conv2D(filters=filters, kernel_size=(1, 1),
                            strides=init_strides,
                            padding="same",
                            kernel_initializer="he_normal",
                            kernel_regularizer=l2(1e-4))(input)
        else:
            reduce = _bn_relu_conv(filters=filters, kernel_size=(1, 1),
                                   strides=init_strides)(input)

        middle = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(reduce)
        residual = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1))(middle)
        return _shortcut(input, residual)

    return f
169 
170 
def _handle_dim_ordering():
    """Set the global axis indices according to the backend's dim ordering.

    BUG FIX: the ``def`` line itself was missing from this file, leaving
    bare ``global`` statements at module level (a SyntaxError/NameError
    hazard); the header is restored here.

    With TensorFlow ordering ('tf', channels-last) the tensor layout is
    (batch, rows, cols, channels); otherwise (Theano-style, channels-first)
    it is (batch, channels, rows, cols).
    """
    global ROW_AXIS
    global COL_AXIS
    global CHANNEL_AXIS
    if K.image_dim_ordering() == 'tf':
        ROW_AXIS = 1
        COL_AXIS = 2
        CHANNEL_AXIS = 3
    else:
        CHANNEL_AXIS = 1
        ROW_AXIS = 2
        COL_AXIS = 3
183 
184 
def _get_block(identifier):
    """Resolve a block function from its name; pass callables through as-is.

    Raises:
        ValueError: if ``identifier`` is a string that names nothing in
            this module's globals.
    """
    if not isinstance(identifier, six.string_types):
        return identifier
    resolved = globals().get(identifier)
    if not resolved:
        raise ValueError('Invalid {}'.format(identifier))
    return resolved
192 
193 
class ResnetBuilder(object):
    """Factory for ResNet-style Keras models of standard depths."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions, branches=False):
        """Builds a custom ResNet like architecture.

        Args:
            input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols).
                NOTE(review): the dimension-permute step below is commented
                out, so the shape is passed to Input() as given — confirm it
                matches the backend's image ordering.
            num_outputs: The number of outputs at final softmax layer.
            block_fn: The block function to use. This is either `basic_block`
                or `bottleneck` (or its name as a string).
                The original paper used basic_block for layers < 50.
            repetitions: Number of repetitions of various block units.
                At each block unit, the number of filters are doubled and
                the input size is halved.
            branches: When True, omit the final Dense layer and return the
                flattened features (used as one branch of a merged model).

        Returns:
            The keras `Model`.

        Raises:
            Exception: if `input_shape` does not have exactly 3 entries.
        """
        if len(input_shape) != 3:
            raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")

        '''
        # Permute dimension order if necessary
        if K.image_dim_ordering() == 'tf':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])
        '''

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        # Stem: 7x7/2 conv -> BN -> relu -> 3x3/2 maxpool.
        input = Input(shape=input_shape)
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(input)
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block(block_fn, filters=filters, repetitions=r,
                                    is_first_layer=(i == 0))(block)
            filters *= 2

        # Last activation
        block = _bn_relu(block)

        # Classifier block: average-pool over all remaining spatial positions.
        block_shape = K.int_shape(block)
        pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
                                 strides=(1, 1))(block)
        flatten1 = Flatten()(pool2)

        if branches:
            # Don't include the dense layer when the network is one branch
            # of a multi-branch model.
            model = Model(inputs=input, outputs=flatten1)
        else:
            # Last layer of the network should be a dense softmax layer.
            dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                          activation="softmax")(flatten1)
            model = Model(inputs=input, outputs=dense)

        return model

    @staticmethod
    def build_resnet_18_merged(input_shape, num_outputs, merge_type='concat'):
        """Build one ResNet-18 branch per view and merge their features.

        Args:
            input_shape: Sequence (PLANES, CELLS, VIEWS); each branch sees
                an input of shape (PLANES, CELLS, 1).
            num_outputs: The number of outputs at the final softmax layer.
            merge_type: One of 'add', 'sub', 'mul', 'avg', 'max', 'dot';
                any other value concatenates the branch outputs.

        Returns:
            The keras `Model` with one input per view.
        """
        branches = input_shape[2]  # number of branches == number of views
        # BUG FIX: previously this mutated the caller's list in place via
        # ``input_shape[2] = 1``; build a local copy instead.
        branch_shape = list(input_shape[:2]) + [1]

        branches_inputs = []   # inputs of the network branches
        branches_outputs = []  # outputs of the network branches

        for branch in range(branches):
            # Generate a branch and save its input and output.
            branch_model = ResnetBuilder.build(branch_shape, num_outputs,
                                               basic_block, [2, 2, 2, 2],
                                               branches=True)
            branches_inputs.append(branch_model.input)
            branches_outputs.append(branch_model.output)

        # Merge the branches.
        if merge_type == 'add':
            merged = Add()(branches_outputs)
        elif merge_type == 'sub':
            # BUG FIX: was the undefined name `Substract`; the layer
            # imported at the top of this file is `Subtract`.
            merged = Subtract()(branches_outputs)
        elif merge_type == 'mul':
            merged = Multiply()(branches_outputs)
        elif merge_type == 'avg':
            merged = Average()(branches_outputs)
        elif merge_type == 'max':
            merged = Maximum()(branches_outputs)
        elif merge_type == 'dot':
            # BUG FIX: Dot requires the `axes` argument; `Dot()` raised
            # TypeError. Dot over the feature axis of the flattened branch
            # outputs. NOTE(review): Dot merges exactly two inputs — confirm
            # this mode is only used with two views.
            merged = Dot(axes=1)(branches_outputs)
        else:
            merged = Concatenate()(branches_outputs)

        # Dense output layer.
        dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                      activation="softmax")(merged)

        # Generate the final multi-input model.
        model = Model(branches_inputs, dense)
        #model = Model(branches_inputs, merged)

        return model

    @staticmethod
    def build_resnet_18(input_shape, num_outputs):
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])

    @staticmethod
    def build_resnet_34(input_shape, num_outputs):
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_50(input_shape, num_outputs):
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_101(input_shape, num_outputs):
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])

    @staticmethod
    def build_resnet_152(input_shape, num_outputs):
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])
def _bn_relu(input)
Definition: resnet.py:33
def _bn_relu_conv(conv_params)
Definition: resnet.py:60
static bool format(QChar::Decomposition tag, QString &str, int index, int len)
Definition: qstring.cpp:11496
Coord add(Coord c1, Coord c2)
Definition: restypedef.cpp:23
def build_resnet_50(input_shape, num_outputs)
Definition: resnet.py:325
def _conv_bn_relu(conv_params)
Definition: resnet.py:40
def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False)
Definition: resnet.py:121
def _get_block(identifier)
Definition: resnet.py:185
def _shortcut(input, residual)
Definition: resnet.py:81
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def build_resnet_18(input_shape, num_outputs)
Definition: resnet.py:317
def _residual_block(block_function, filters, repetitions, is_first_layer=False)
Definition: resnet.py:106
def build_resnet_152(input_shape, num_outputs)
Definition: resnet.py:333
Definition: input.h:9
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
def build_resnet_34(input_shape, num_outputs)
Definition: resnet.py:321
def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False)
Definition: resnet.py:144
auto const & get(AssnsNode< L, R, D > const &r)
Definition: AssnsNode.h:115
def build_resnet_18_merged(input_shape, num_outputs, merge_type='concat')
Definition: resnet.py:259
def build(input_shape, num_outputs, block_fn, repetitions, branches=False)
Definition: resnet.py:196
def build_resnet_101(input_shape, num_outputs)
Definition: resnet.py:329
def _handle_dim_ordering()
Definition: resnet.py:171