Static Public Member Functions | List of all members
resnet.ResnetBuilder Class Reference
Inheritance diagram for resnet.ResnetBuilder:

Static Public Member Functions

def build (input_shape, num_outputs, block_fn, repetitions, branches=False)
 
def build_resnet_18_merged (input_shape, num_outputs, merge_type='concat')
 
def build_resnet_18 (input_shape, num_outputs)
 
def build_resnet_34 (input_shape, num_outputs)
 
def build_resnet_50 (input_shape, num_outputs)
 
def build_resnet_101 (input_shape, num_outputs)
 
def build_resnet_152 (input_shape, num_outputs)
 

Detailed Description

Definition at line 194 of file resnet.py.

Member Function Documentation

def resnet.ResnetBuilder.build (   input_shape,
  num_outputs,
  block_fn,
  repetitions,
  branches = False 
)
static
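Builds a custom ResNet like architecture.

Args:
    input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols). The channel-permutation step is commented out, so this shape is passed to `Input` unchanged.
    num_outputs: The number of outputs at the final softmax layer (unused when `branches` is True).
    block_fn: The block function to use. This is either `basic_block` or `bottleneck`.
        The original paper used basic_block for networks with fewer than 50 layers.
    repetitions: Number of repetitions of the various block units.
        At each block unit, the number of filters is doubled and the input size is halved.
    branches: If True, the final Dense softmax layer is omitted and the model outputs the flattened pooled features, so that several such branches can be merged into one network.

Returns:
    The keras `Model`.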

Definition at line 196 of file resnet.py.

def build(input_shape, num_outputs, block_fn, repetitions, branches=False):
    """Builds a custom ResNet like architecture.

    Args:
        input_shape: The input shape as a 3-element sequence. It is passed to
            ``Input`` unchanged (the former channels-first to channels-last
            permutation step is disabled), so presumably it must already match
            the backend's image data format -- TODO confirm.
        num_outputs: The number of outputs at the final softmax layer. Unused
            when ``branches`` is True.
        block_fn: The block function to use, either `basic_block` or
            `bottleneck` (a string name is resolved via `_get_block`). The
            original paper used basic_block for networks with fewer than
            50 layers.
        repetitions: Number of repetitions of each residual stage. At each
            stage the number of filters is doubled and the input size is
            halved.
        branches: If True, omit the final Dense softmax layer and return a
            model whose output is the flattened pooled features, so several
            such models can be merged into one multi-branch network.

    Returns:
        The keras `Model`.

    Raises:
        ValueError: If ``input_shape`` does not have exactly 3 elements.
    """
    if len(input_shape) != 3:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")

    # Load function from str if needed.
    block_fn = _get_block(block_fn)

    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = Input(shape=input_shape)
    conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(inputs)
    pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

    # Stack the residual stages, doubling the filter count after each one.
    block = pool1
    filters = 64
    for i, r in enumerate(repetitions):
        block = _residual_block(block_fn, filters=filters, repetitions=r,
                                is_first_layer=(i == 0))(block)
        filters *= 2

    # Last activation
    block = _bn_relu(block)

    # Classifier block: pool over the whole remaining spatial extent.
    block_shape = K.int_shape(block)
    pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
                             strides=(1, 1))(block)
    flatten1 = Flatten()(pool2)

    if branches:
        # Don't include the dense layer when the network is one of several
        # branches; the caller merges branch features and adds its own head.
        return Model(inputs=inputs, outputs=flatten1)

    # A stand-alone network ends with a dense softmax layer.
    dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                  activation="softmax")(flatten1)
    return Model(inputs=inputs, outputs=dense)
257 
def _bn_relu(input)
Definition: resnet.py:33
def _conv_bn_relu(conv_params)
Definition: resnet.py:40
def _get_block(identifier)
Definition: resnet.py:185
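auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iterations.
Definition: enumerate.h:69
(Doxygen matched this cross-reference by name to an unrelated C++ helper; the Python code above uses the built-in `enumerate`.)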
def _residual_block(block_function, filters, repetitions, is_first_layer=False)
Definition: resnet.py:106
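Definition: input.h:9
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
(Doxygen matched the two cross-references above by name to unrelated C++ symbols; they do not describe anything in resnet.py.)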
def build(input_shape, num_outputs, block_fn, repetitions, branches=False)
Definition: resnet.py:196
def _handle_dim_ordering()
Definition: resnet.py:171
def resnet.ResnetBuilder.build_resnet_101 (   input_shape,
  num_outputs 
)
static

Definition at line 329 of file resnet.py.

def build_resnet_101(input_shape, num_outputs):
    """Build a ResNet-101: bottleneck blocks in stages of 3, 4, 23 and 3."""
    stage_repetitions = [3, 4, 23, 3]
    return ResnetBuilder.build(input_shape, num_outputs,
                               block_fn=bottleneck,
                               repetitions=stage_repetitions)
331 
def build_resnet_101(input_shape, num_outputs)
Definition: resnet.py:329
def resnet.ResnetBuilder.build_resnet_152 (   input_shape,
  num_outputs 
)
static

Definition at line 333 of file resnet.py.

def build_resnet_152(input_shape, num_outputs):
    """Build a ResNet-152: bottleneck blocks in stages of 3, 8, 36 and 3."""
    stage_repetitions = [3, 8, 36, 3]
    return ResnetBuilder.build(input_shape, num_outputs,
                               block_fn=bottleneck,
                               repetitions=stage_repetitions)
335 
def build_resnet_152(input_shape, num_outputs)
Definition: resnet.py:333
def resnet.ResnetBuilder.build_resnet_18 (   input_shape,
  num_outputs 
)
static

Definition at line 317 of file resnet.py.

def build_resnet_18(input_shape, num_outputs):
    """Build a ResNet-18: basic blocks in four stages of 2 repetitions each."""
    stage_repetitions = [2, 2, 2, 2]
    return ResnetBuilder.build(input_shape, num_outputs,
                               block_fn=basic_block,
                               repetitions=stage_repetitions)
319 
def build_resnet_18(input_shape, num_outputs)
Definition: resnet.py:317
def resnet.ResnetBuilder.build_resnet_18_merged (   input_shape,
  num_outputs,
  merge_type = 'concat' 
)
static

Definition at line 259 of file resnet.py.

def build_resnet_18_merged(input_shape, num_outputs, merge_type='concat'):
    """Build a multi-view network from one ResNet-18 branch per view.

    The third element of ``input_shape`` is taken as the number of views.
    One ResNet-18 feature extractor (no classifier head) is built per view,
    with that dimension set to 1. The branch features are merged, and a
    single Dense softmax layer produces the output.

    Args:
        input_shape: (PLANES, CELLS, VIEWS) shape of the full input. It is
            not modified; lists and tuples are both accepted.
        num_outputs: Number of units of the final softmax layer.
        merge_type: How the branch features are combined: 'add', 'sub',
            'mul', 'avg', 'max', 'dot' or 'concat'. Any other value falls
            back to concatenation. 'sub' and 'dot' need exactly two views.

    Returns:
        A keras `Model` taking one input per view.

    Raises:
        ValueError: If there are no views, or if 'sub'/'dot' is requested
            with a number of views other than two.
    """
    # Work on a copy so the caller's shape is not mutated (and tuples work).
    branch_shape = list(input_shape)
    n_branches = branch_shape[2]  # number of branches == number of views
    branch_shape[2] = 1  # (PLANES, CELLS, VIEWS) -> (PLANES, CELLS, 1)
    branch_shape = tuple(branch_shape)

    if n_branches < 1:
        raise ValueError("build_resnet_18_merged needs at least one view, got %r"
                         % (n_branches,))
    if merge_type in ('sub', 'dot') and n_branches != 2:
        # Keras' Subtract and Dot layers accept exactly two inputs.
        raise ValueError("merge_type %r requires exactly 2 views, got %r"
                         % (merge_type, n_branches))

    branches_inputs = []  # inputs of the network branches
    branches_outputs = []  # outputs of the network branches (flattened features)
    for _ in range(n_branches):
        # Generate a branch and save its input and output.
        branch_model = ResnetBuilder.build(branch_shape, num_outputs, basic_block,
                                           [2, 2, 2, 2], branches=True)
        branches_inputs.append(branch_model.input)
        branches_outputs.append(branch_model.output)

    # Merge the branches.
    if n_branches == 1:
        # Keras merge layers reject a single input; there is nothing to merge.
        merged = branches_outputs[0]
    elif merge_type == 'add':
        merged = Add()(branches_outputs)
    elif merge_type == 'sub':
        # The Keras layer is `Subtract`; `Substract` does not exist.
        # NOTE(review): assumes the file uses standalone `keras` -- adjust the
        # import if it uses `tensorflow.keras`.
        from keras.layers import Subtract
        merged = Subtract()(branches_outputs)
    elif merge_type == 'mul':
        merged = Multiply()(branches_outputs)
    elif merge_type == 'avg':
        merged = Average()(branches_outputs)
    elif merge_type == 'max':
        merged = Maximum()(branches_outputs)
    elif merge_type == 'dot':
        # `axes` is a required argument of Dot. Branch outputs come from
        # Flatten(), i.e. (batch, features), so axis 1 is the feature axis.
        merged = Dot(axes=1)(branches_outputs)
    else:
        merged = Concatenate()(branches_outputs)

    # Dense output layer.
    dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                  activation="softmax")(merged)

    # Generate the final model.
    return Model(branches_inputs, dense)
315 
def build_resnet_18_merged(input_shape, num_outputs, merge_type='concat')
Definition: resnet.py:259
def resnet.ResnetBuilder.build_resnet_34 (   input_shape,
  num_outputs 
)
static

Definition at line 321 of file resnet.py.

def build_resnet_34(input_shape, num_outputs):
    """Build a ResNet-34: basic blocks in stages of 3, 4, 6 and 3."""
    stage_repetitions = [3, 4, 6, 3]
    return ResnetBuilder.build(input_shape, num_outputs,
                               block_fn=basic_block,
                               repetitions=stage_repetitions)
323 
def build_resnet_34(input_shape, num_outputs)
Definition: resnet.py:321
def resnet.ResnetBuilder.build_resnet_50 (   input_shape,
  num_outputs 
)
static

Definition at line 325 of file resnet.py.

def build_resnet_50(input_shape, num_outputs):
    """Build a ResNet-50: bottleneck blocks in stages of 3, 4, 6 and 3."""
    stage_repetitions = [3, 4, 6, 3]
    return ResnetBuilder.build(input_shape, num_outputs,
                               block_fn=bottleneck,
                               repetitions=stage_repetitions)
327 
def build_resnet_50(input_shape, num_outputs)
Definition: resnet.py:325

The documentation for this class was generated from the following file: resnet.py