se.py
Go to the documentation of this file.
1 from keras.layers import GlobalAveragePooling2D, Reshape, Dense, multiply, Permute
2 from keras import backend as K
3 
4 
def squeeze_excite_block(input, ratio=16):
    """Create a squeeze-and-excitation (SE) block.

    The block squeezes the input to one value per channel using global
    average pooling. Those values pass through a bottleneck of two Dense
    layers (ReLU, then sigmoid). The resulting per-channel weights rescale
    the input channel-wise.

    Args:
        input: 4D Keras image tensor (``GlobalAveragePooling2D`` requires
            rank 4). Its layout follows ``K.image_data_format()``. The
            channel dimension must be statically known.
        ratio: reduction factor for the bottleneck. The hidden Dense layer
            has ``max(1, channels // ratio)`` units.

    Returns:
        A Keras tensor: ``input`` multiplied channel-wise by the learned
        excitation weights.

    Raises:
        ValueError: if the channel dimension of ``input`` is not known
            statically.
    """
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    # Use the public backend API instead of the private `_keras_shape`
    # attribute, which newer Keras / tf.keras versions no longer set.
    filters = K.int_shape(init)[channel_axis]
    if filters is None:
        raise ValueError(
            "squeeze_excite_block requires a statically known channel dimension"
        )

    # Clamp to at least one hidden unit. Otherwise a block with fewer
    # channels than `ratio` would get a zero-width bottleneck.
    bottleneck_units = max(1, filters // ratio)

    # The pooled vector is reshaped to (1, 1, C) so the Dense layers act on
    # the channel axis.
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(bottleneck_units, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    # For channels_first inputs, move the channel axis to the front:
    # (1, 1, C) -> (C, 1, 1), so the multiply broadcasts over H and W.
    if K.image_data_format() == 'channels_first':
        se = Permute((3, 1, 2))(se)

    x = multiply([init, se])
    return x
def squeeze_excite_block(input, ratio=16)
Definition: se.py:5