dump_to_simple_cpp.py
Go to the documentation of this file.
import argparse
import json
import sys

import numpy as np
np.random.seed(1337)  # seeded before the Keras import, preserving the original ordering
from keras.models import Sequential, model_from_json
6 
# Command-line interface: all three paths are required.
#   -a/--architecture  model architecture as JSON (read as text, then parsed)
#   -w/--weights       weights file passed to model.load_weights (HDF5)
#   -o/--output        destination of the plain-text dump
parser = argparse.ArgumentParser(description='This is a simple script to dump Keras model into simple format suitable for porting into pure C++ model')

parser.add_argument('-a', '--architecture', help="JSON with model architecture", required=True)
parser.add_argument('-w', '--weights', help="Model weights in HDF5 format", required=True)
parser.add_argument('-o', '--output', help="Output file name", required=True)

args = parser.parse_args()

# A single-argument print(...) prints the same text under Python 2 (print
# statement with a parenthesised expression) and Python 3 (print function),
# unlike the original `print a, b` statements, which do not parse on Python 3.
print('Read architecture from %s' % args.architecture)
print('Read weights from %s' % args.weights)
print('Writing to %s' % args.output)
18 
19 
# Rebuild the model from its JSON architecture and load the trained weights.
# The same JSON is also kept as a dict (`arch`): the dump step reads each
# layer's 'class_name' and 'config' entries from it.
# The `with` block closes the architecture file (the original never closed it).
with open(args.architecture) as arch_file:
    arch_json = arch_file.read()
model = model_from_json(arch_json)
model.load_weights(args.weights)
# NOTE(review): compiling is presumably not needed just to read weights back
# out; kept so the script behaves exactly as before.
model.compile(loss='categorical_crossentropy', optimizer='adadelta')
arch = json.loads(arch_json)
25 
# Write the model as a plain-text file for the C++ loader.
#
# Layout:
#   layers <len(model.layers)>
#   then, for every entry of arch["config"]:
#     layer <index> <class_name>
#     Convolution2D: "<W.shape[0]> <W.shape[1]> <W.shape[2]> <W.shape[3]> <border_mode>",
#                    one bracketed row per W[i, j, k], then the bias row
#     Activation:    the activation name
#     MaxPooling2D:  "<pool_size[0]> <pool_size[1]>"
#     Dense:         "<W.shape[0]> <W.shape[1]>", one bracketed row per row of W,
#                    then the bias row
#     any other class: only the "layer ..." line
#
# Arrays are serialised with str(ndarray). NumPy summarises any array larger
# than its print threshold (1000 elements by default) as "[a b c ..., x y z]",
# which silently drops weights and breaks the dump. Lift the threshold so every
# value is written.
# NOTE(review): str() may still wrap a long row over several lines (NumPy's
# linewidth); presumably the C++ reader treats newlines as plain whitespace.
# Confirm.
np.set_printoptions(threshold=sys.maxsize)

with open(args.output, 'w') as fout:
    fout.write('layers ' + str(len(model.layers)) + '\n')

    # assumes arch["config"] is the layer list of a Sequential model, in the
    # same order as model.layers -- TODO confirm for other model types
    for ind, l in enumerate(arch["config"]):
        class_name = l['class_name']
        print('%s %s' % (ind, l))
        fout.write('layer ' + str(ind) + ' ' + class_name + '\n')
        print('%s %s' % (ind, class_name))

        if class_name == 'Convolution2D':
            weights = model.layers[ind].get_weights()  # [kernel, bias]
            W = weights[0]
            print(W.shape)
            fout.write(str(W.shape[0]) + ' ' + str(W.shape[1]) + ' ' + str(W.shape[2]) + ' ' + str(W.shape[3]) + ' ' + l['config']['border_mode'] + '\n')
            # One bracketed line per innermost row of the 4-D kernel.
            for i in range(W.shape[0]):
                for j in range(W.shape[1]):
                    for k in range(W.shape[2]):
                        fout.write(str(W[i, j, k]) + '\n')
            fout.write(str(weights[1]) + '\n')
        elif class_name == 'Activation':
            fout.write(l['config']['activation'] + '\n')
        elif class_name == 'MaxPooling2D':
            pool_size = l['config']['pool_size']
            fout.write(str(pool_size[0]) + ' ' + str(pool_size[1]) + '\n')
        elif class_name == 'Dense':
            weights = model.layers[ind].get_weights()  # [kernel, bias]
            W = weights[0]
            print(W.shape)
            fout.write(str(W.shape[0]) + ' ' + str(W.shape[1]) + '\n')
            for w in W:
                fout.write(str(w) + '\n')
            fout.write(str(weights[1]) + '\n')
int open(const char *, int)
Opens a file descriptor.
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
int read(int, char *, size_t)
Read bytes from a file descriptor.
static QCString str