import argparse
import json
import sys

import numpy as np
from keras.models import Sequential, model_from_json
# Dump a Keras model (JSON architecture + HDF5 weights) into a plain-text
# format intended for porting into a pure C++ model.
#
# Output layout:
#   layers <N>
#   layer <index> <class_name>      (one header per layer, followed by an
#   ...per-layer payload...          optional payload; see write_model)


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the required ``architecture``,
        ``weights`` and ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description=(
            "This is a simple script to dump Keras model into simple format "
            "suitable for porting into pure C++ model"
        )
    )
    parser.add_argument(
        "-a", "--architecture", help="JSON with model architecture", required=True
    )
    parser.add_argument(
        "-w", "--weights", help="Model weights in HDF5 format", required=True
    )
    parser.add_argument("-o", "--output", help="Output file name", required=True)
    return parser.parse_args(argv)


def array_to_text(values):
    """Format an (n >= 1)-D numpy array like ``str()``, but never elide values.

    ``str(ndarray)`` abbreviates arrays larger than numpy's print
    ``threshold`` as ``[a b c ... x y z]``, which would silently drop
    weights from the dump. Passing ``threshold=sys.maxsize`` keeps every
    element. The bracketed, whitespace-separated layout that ``str()``
    produces is unchanged.
    """
    return np.array2string(np.asarray(values), threshold=sys.maxsize)


def write_conv2d(fout, config, weights):
    """Write the payload of a ``Convolution2D`` layer.

    Args:
        fout: Writable text stream.
        config: The layer's ``config`` dict from the architecture JSON; only
            ``border_mode`` is read.
        weights: Result of ``layer.get_weights()``. ``weights[0]`` must be a
            4-D kernel. ``weights[1]`` is written verbatim (presumably the
            bias vector).

    Writes three things, in order:
    - the four kernel dimensions followed by ``border_mode``, on one line;
    - one line per ``kernel[i, j, k]`` slice along the last axis, visited
      in C (row-major) order;
    - ``weights[1]`` on a final line.
    """
    kernel, bias = weights[0], weights[1]
    # Indexing axes 0..3 explicitly keeps the original IndexError for a
    # kernel that is not 4-D instead of silently writing fewer dimensions.
    dims = " ".join(str(kernel.shape[axis]) for axis in range(4))
    fout.write(dims + " " + config["border_mode"] + "\n")
    for i, j, k in np.ndindex(*kernel.shape[:3]):
        fout.write(array_to_text(kernel[i, j, k]) + "\n")
    fout.write(array_to_text(bias) + "\n")


def write_dense(fout, weights):
    """Write the payload of a ``Dense`` layer.

    Args:
        fout: Writable text stream.
        weights: Result of ``layer.get_weights()``. ``weights[0]`` must be a
            2-D matrix. ``weights[1]`` is written verbatim (presumably the
            bias vector).

    Writes three things, in order:
    - the two matrix dimensions;
    - one line per matrix row;
    - ``weights[1]`` on a final line.
    """
    kernel, bias = weights[0], weights[1]
    fout.write(str(kernel.shape[0]) + " " + str(kernel.shape[1]) + "\n")
    for row in kernel:
        fout.write(array_to_text(row) + "\n")
    fout.write(array_to_text(bias) + "\n")


def write_model(fout, model, arch):
    """Write the layer count and every layer of ``model`` to ``fout``.

    Args:
        fout: Writable text stream.
        model: Keras model. ``model.layers[i].get_weights()`` supplies the
            weights for the i-th entry of ``arch["config"]``.
        arch: Parsed architecture JSON. ``arch["config"]`` is iterated as a
            list of layer dicts with ``class_name`` and ``config`` keys.
            NOTE(review): ``Convolution2D``/``border_mode`` are Keras 1.x
            names, so this presumably targets the Keras 1 Sequential JSON
            layout. Confirm before using it with newer Keras.

    Layer classes other than Convolution2D, Activation, MaxPooling2D and
    Dense get only their ``layer <index> <class_name>`` header line.
    """
    fout.write("layers " + str(len(model.layers)) + "\n")
    for ind, layer in enumerate(arch["config"]):
        class_name = layer["class_name"]
        fout.write("layer " + str(ind) + " " + class_name + "\n")
        print(ind, class_name)
        # A layer has exactly one class name, so at most one branch applies.
        if class_name == "Convolution2D":
            write_conv2d(fout, layer["config"], model.layers[ind].get_weights())
        elif class_name == "Activation":
            fout.write(layer["config"]["activation"] + "\n")
        elif class_name == "MaxPooling2D":
            pool_size = layer["config"]["pool_size"]
            fout.write(str(pool_size[0]) + " " + str(pool_size[1]) + "\n")
        elif class_name == "Dense":
            write_dense(fout, model.layers[ind].get_weights())


def main(argv=None):
    """Command-line entry point: load the Keras model and write the dump."""
    args = parse_args(argv)
    print("Read architecture from", args.architecture)
    print("Read weights from", args.weights)
    print("Writing to", args.output)

    # model_from_json needs the raw JSON text, while write_model walks the
    # parsed structure, so read the file once and use it both ways.
    with open(args.architecture) as fin:
        arch_text = fin.read()
    model = model_from_json(arch_text)
    model.load_weights(args.weights)
    # NOTE(review): compiling is presumably not required just to read the
    # weights; it is kept to preserve the original script's behaviour.
    model.compile(loss="categorical_crossentropy", optimizer="adadelta")
    arch = json.loads(arch_text)

    with open(args.output, "w") as fout:
        write_model(fout, model, arch)


if __name__ == "__main__":
    main()
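# Symbol annotations captured together with this listing (C/C++ declarations,
# apparently code-browser tooltips; not part of this Python script):
#   int open(const char *, int) -- Opens a file descriptor.
#   auto enumerate(Iterables &&...iterables) -- Range-for loop helper tracking the number of iterations.
#   int read(int, char *, size_t) -- Reads bytes from a file descriptor.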