1 from keras
import layers
2 from keras
import models
# ResNeXt by default. For ResNet set `cardinality` = 1 above.


def add_common_layers(y):
    """Apply BatchNormalization followed by a LeakyReLU activation.

    Args:
        y: Keras tensor to normalize and activate.

    Returns:
        The transformed tensor. Callers rebind it (``y = add_common_layers(y)``),
        so it must be returned rather than discarded.
    """
    y = layers.BatchNormalization()(y)
    y = layers.LeakyReLU()(y)
    return y
def grouped_convolution(y, nb_channels, _strides):
    """Apply a 3x3 (grouped) convolution producing ``nb_channels`` channels.

    When the module-level ``cardinality`` is 1 this is a single ordinary 3x3
    ``Conv2D`` (plain ResNet). Otherwise the channels of ``y`` are split into
    ``cardinality`` equal slices of width ``nb_channels // cardinality``. Each
    slice is convolved independently, and the per-group outputs are concatenated
    (ResNeXt).

    Args:
        y: Keras tensor. The slicing below indexes the 4th axis of a 4-D tensor,
            so this presumably expects channels-last input that already has
            ``nb_channels`` channels. TODO confirm for channels_first setups.
        nb_channels: total number of output channels. It must be divisible by
            ``cardinality``.
        _strides: strides forwarded to every 3x3 ``Conv2D``.

    Returns:
        The convolved tensor with ``nb_channels`` channels.

    Raises:
        ValueError: if ``nb_channels`` is not divisible by ``cardinality``.
    """
    # With a single group this is just a standard convolution. It also avoids
    # concatenating a one-element list in the grouped path below.
    if cardinality == 1:
        return layers.Conv2D(nb_channels, kernel_size=(3, 3), strides=_strides, padding='same')(y)

    # Raise instead of assert so the check survives `python -O`.
    if nb_channels % cardinality:
        raise ValueError(
            "nb_channels ({}) must be divisible by cardinality ({})".format(nb_channels, cardinality))
    _d = nb_channels // cardinality

    groups = []
    for j in range(cardinality):
        # Pass the slice bounds through `arguments` instead of closing over `j`.
        # A closure reads `j` lazily, so if the Lambda's function is invoked again
        # later, every group would slice the last channel range.
        group = layers.Lambda(
            lambda z, start, stop: z[:, :, :, start:stop],
            arguments={'start': j * _d, 'stop': (j + 1) * _d},
        )(y)
        groups.append(layers.Conv2D(_d, kernel_size=(3, 3), strides=_strides, padding='same')(group))

    # The grouped layer's output is the channel-wise concatenation of all groups.
    return layers.concatenate(groups)
def residual_block(y, nb_channels_in, nb_channels_out, _strides=(1, 1), _project_shortcut=False):
    """Build one bottleneck residual block (ResNeXt when ``cardinality`` > 1).

    Our network consists of a stack of residual blocks. These blocks have the same topology,
    and are subject to two simple rules:
    - If producing spatial maps of the same size, the blocks share the same hyper-parameters (width and filter sizes).
    - Each time the spatial map is down-sampled by a factor of 2, the width of the blocks is multiplied by a factor of 2.

    Args:
        y: input Keras tensor.
        nb_channels_in: width of the 1x1 reduction and of the grouped 3x3 convolution.
        nb_channels_out: number of channels produced by the block.
        _strides: strides of the grouped 3x3 convolution. The projection shortcut
            uses the same strides so both branches keep matching spatial sizes.
        _project_shortcut: force a 1x1 projection on the shortcut even when
            ``_strides`` is ``(1, 1)``.

    Returns:
        LeakyReLU(shortcut + residual branch).
    """
    # The shortcut starts as the block input. Without this it is undefined on both paths.
    shortcut = y

    # 1x1 reduction to nb_channels_in, then BN + LeakyReLU.
    y = layers.Conv2D(nb_channels_in, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
    y = add_common_layers(y)

    # Grouped 3x3 convolution carries the block's stride, then BN + LeakyReLU.
    y = grouped_convolution(y, nb_channels_in, _strides=_strides)
    y = add_common_layers(y)

    # 1x1 expansion to nb_channels_out. It gets BN only: the activation is applied after the addition.
    y = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
    y = layers.BatchNormalization()(y)

    # Project the shortcut (strided 1x1 conv + BN) when requested or when the residual
    # branch changes spatial size. Otherwise the identity is used. That path assumes the
    # input already has nb_channels_out channels, which is not checked here.
    if _project_shortcut or _strides != (1, 1):
        shortcut = layers.Conv2D(nb_channels_out, kernel_size=(1, 1), strides=_strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    y = layers.add([shortcut, y])

    # The only activation applied after the residual addition.
    y = layers.LeakyReLU()(y)
    return y
# NOTE(review): this span appears to be the body of an enclosing network-building
# function whose `def` line and trailing `return` are not visible here. `x` is
# presumably that function's input tensor, so confirm against the full file.
# Initial 7x7 stride-2 convolution with 64 filters, then BatchNormalization +
# LeakyReLU via add_common_layers.
88 x = layers.Conv2D(64, kernel_size=(7, 7), strides=(2, 2), padding=
'same')(x)
89 x = add_common_layers(x)
# 3x3 stride-2 max pooling before the first stage of residual blocks.
92 x = layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding=
'same')(x)
# NOTE(review): `i` is read in each stage below, but no enclosing
# `for i in range(...)` loop is visible. Each stage presumably repeats
# residual_block with the first iteration (i == 0) special-cased. The repeat
# counts cannot be recovered from this view. TODO confirm.
# Stage with nb_channels_in=128, nb_channels_out=256. Only the first block
# (i == 0) forces a projection shortcut; no downsampling stride is passed here.
94 project_shortcut =
True if i == 0
else False 95 x = residual_block(x, 128, 256, _project_shortcut=project_shortcut)
# Stage 256 -> 512 channels. The first block (i == 0) downsamples with stride (2, 2),
# which also triggers a projection shortcut inside residual_block.
100 strides = (2, 2)
if i == 0
else (1, 1)
101 x = residual_block(x, 256, 512, _strides=strides)
# Stage 512 -> 1024 channels. The first block (i == 0) downsamples with stride (2, 2).
105 strides = (2, 2)
if i == 0
else (1, 1)
106 x = residual_block(x, 512, 1024, _strides=strides)
# Stage 1024 -> 2048 channels. The first block (i == 0) downsamples with stride (2, 2).
110 strides = (2, 2)
if i == 0
else (1, 1)
111 x = residual_block(x, 1024, 2048, _strides=strides)
# Head: global average pooling over the spatial axes, then a 13-unit Dense layer.
# NOTE(review): no `activation` argument is passed, so the 13 outputs are presumably
# raw logits or regression targets. Confirm this matches the loss used in training.
# The enclosing function's `return x` is not visible here.
113 x = layers.GlobalAveragePooling2D()(x)
114 x = layers.Dense(13)(x)
# Assemble the Keras functional Model from an Input of shape
# (img_height, img_width, img_channels) to `network_output`.
# NOTE(review): img_height / img_width / img_channels and `network_output` are not
# defined anywhere in the visible source. Presumably `network_output` is the result
# of applying the network-building function to `image_tensor`. Confirm in the full file.
119 image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
122 model = models.Model(inputs=[image_tensor], outputs=[network_output])