generate_datataset.py
1 """
2 This is the dataset generator module.
3 """
4 
5 __version__ = '1.0'
6 __author__ = 'Saul Alonso-Monsalve'
7 __email__ = "saul.alonso.monsalve@cern.ch"
8 
9 import numpy as np
10 import glob
11 import ast
12 import ntpath
13 import pickle
14 import configparser
15 import logging
16 import sys
17 import time
18 import random
19 import zlib
20 
21 from sklearn.utils import class_weight
22 from collections import Counter
23 
24 '''
25 ****************************************
26 ************** PARAMETERS **************
27 ****************************************
28 '''
29 
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

config = configparser.ConfigParser()
config.read('config/config.ini')

# random

SEED = int(config['random']['seed'])

if SEED == -1:
    SEED = int(time.time())

np.random.seed(SEED)
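# random.shuffle() below uses Python's built-in RNG, which np.random.seed()
# does not touch; seed it as well so the file shuffling is reproducible.
random.seed(SEED)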

# images

IMAGES_PATH = config['images']['path']
VIEWS = int(config['images']['views'])
PLANES = int(config['images']['planes'])
CELLS = int(config['images']['cells'])

# dataset

DATASET_PATH = config['dataset']['path']
PARTITION_PREFIX = config['dataset']['partition_prefix']
LABELS_PREFIX = config['dataset']['labels_prefix']
UNIFORM = ast.literal_eval(config['dataset']['uniform']) # parse the INI string into a bool

# model

OUTPUTS = int(config['model']['outputs'])

# train

TRAIN_FRACTION = float(config['train']['fraction'])
WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']

# validation

VALIDATION_FRACTION = float(config['validation']['fraction'])

# test

TEST_FRACTION = float(config['test']['fraction'])

if (TRAIN_FRACTION + VALIDATION_FRACTION + TEST_FRACTION) > 1:
    logging.error('(TRAIN_FRACTION + VALIDATION_FRACTION + TEST_FRACTION) must be <= 1')
    sys.exit(-1)

# Cap a count at 3, i.e. return 3 for any value > 2 ("3 or more")
def normalize(value):
    if value > 2:
        return 3
    return value

# Return 1 if value < 0 else 0 (PDG code: neutrino -> 0, antineutrino -> 1)
def normalize2(value):
    if value < 0:
        return 1
    return 0

count_flavour = [0]*4   # events per flavour class (training set only)
count_category = [0]*14 # events per interaction category, fInt (training set only)

'''
****************************************
*************** DATASETS ***************
****************************************
'''

partition = {'train': [], 'validation': [], 'test': []} # Train, validation, and test IDs
labels = {} # ID : label
y_train = []
y1_class_weights = []
y2_class_weights = []

only_train = ['nutau2', 'nutau3'] # referenced only by the commented-out filter in the loop below

if UNIFORM:
    pass # placeholder: UNIFORM currently has no effect

# Iterate through label folders

logging.info('Filling datasets...')

count_neutrinos = 0
count_antineutrinos = 0
count_empty_views = 0
count_empty_events = 0
count_less_10nonzero_views = 0
count_less_10nonzero_events = 0

for images_path in glob.iglob(IMAGES_PATH + '/*'):

    count_train, count_val, count_test = (0, 0, 0)

    print(images_path)

    if 'nutau2' in images_path or 'nutau3' in images_path:
        continue

    files = list(glob.iglob(images_path + "/images/*"))
    random.shuffle(files)

    for imagefile in files:
        #print(imagefile)
        ID = imagefile.split("/")[-1][:-3]
        infofile = images_path + '/info/' + ID + '.info'

        #print(infofile)

        with open(infofile, 'r') as info_file:
            info = info_file.readlines()
        fInt = int(info[0].strip())
        flavour = fInt // 4
        interaction = fInt % 4
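        # fInt packs the event category: the flavour index is fInt // 4 and
        # the interaction mode fInt % 4; fInt == 13 (NC) is remapped below.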

        fNuEnergy = float(info[1].strip())
        fLepEnergy = float(info[2].strip())
        fRecoNueEnergy = float(info[3].strip())
        fRecoNumuEnergy = float(info[4].strip())
        fEventWeight = float(info[5].strip())

        fNuPDG = normalize2(int(info[6].strip()))
        fNProton = normalize(int(info[7].strip()))
        fNPion = normalize(int(info[8].strip()))
        fNPizero = normalize(int(info[9].strip()))
        fNNeutron = normalize(int(info[10].strip()))

        # special case: NC
        if fInt == 13:
            fNuPDG = -1
            flavour = 3
            interaction = -1

        if fNuPDG == 0:
            count_neutrinos += 1
        elif fNuPDG == 1:
            count_antineutrinos += 1

        '''
        print("ID:", ID)
        print("Info:", info)
        print("fInt:", fInt)
        print("flavour:", flavour)
        print("interaction:", interaction)
        print("fNuEnergy:", fNuEnergy)
        print("fLepEnergy:", fLepEnergy)
        print("fRecoNueEnergy:", fRecoNueEnergy)
        print("fRecoNumuEnergy:", fRecoNumuEnergy)
        print("fEventWeight:", fEventWeight)
        print("fNuPDG:", fNuPDG)
        print("fNProton:", fNProton)
        print("fNPion:", fNPion)
        print("fNPizero:", fNPizero)
        print("fNNeutron:", fNNeutron)
        '''

        random_value = np.random.uniform(0, 1)

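        # Each image file holds a zlib-compressed buffer of
        # VIEWS*PLANES*CELLS uint8 pixel intensities.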
        with open(imagefile, 'rb') as image_file:
            pixels = np.frombuffer(zlib.decompress(image_file.read()), dtype=np.uint8).reshape(VIEWS, PLANES, CELLS)

        #pixels = np.load(self.images_path + '/' + labels[ID] + '/' + ID + '.npy')

        views = [None]*VIEWS
        empty_view = [0, 0, 0]
        non_empty_view = [0, 0, 0]

        # Events where any view has fewer than 10 non-zero pixels are filtered out.
        # This removes empty and almost-empty images from the set.
        count_empty = 0
        count_less_10nonzero = 0
        for i in range(len(views)):
            views[i] = pixels[i, :, :].reshape(PLANES, CELLS, 1)
            maxi = np.max(views[i]) # max pixel value (normally 255)
            mini = np.min(views[i]) # min pixel value (normally 0)
            nonzero = np.count_nonzero(views[i]) # number of non-zero pixels
            total = np.sum(views[i])
            avg = np.mean(views[i])
            if nonzero == 0:
                count_empty += 1
                count_empty_views += 1
            if nonzero < 10:
                count_less_10nonzero += 1
                count_less_10nonzero_views += 1
        if count_empty == len(views):
            count_empty_events += 1
        if count_less_10nonzero > 0:
            count_less_10nonzero_events += 1
            continue

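        # random_value assigns the event to train/validation/test according to
        # the configured fractions; if the fractions sum to less than 1, the
        # remaining events are discarded.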
        # Fill training set
        if random_value < TRAIN_FRACTION: # or True in [k in ID for k in only_train]
            #if True in [k in ID for k in only_train] and (flavour != 2 or fNuEnergy > 15):
            #    continue
            count_flavour[flavour] += 1
            count_category[fInt] += 1
            partition['train'].append(ID)
            count_train += 1
        # Fill validation set
        elif random_value < (TRAIN_FRACTION + VALIDATION_FRACTION):
            partition['validation'].append(ID)
            count_val += 1
        # Fill test set
        elif random_value < (TRAIN_FRACTION + VALIDATION_FRACTION + TEST_FRACTION):
            partition['test'].append(ID)
            count_test += 1

        # Label layout depends on OUTPUTS: 1 -> single category (fInt);
        # 5 -> flavour plus particle counts; otherwise the full 7-component label.
        if OUTPUTS == 1:
            labels[ID] = fInt
        elif OUTPUTS == 5:
            labels[ID] = [flavour, fNProton, fNPion, fNPizero, fNNeutron]
        else:
            labels[ID] = [fNuPDG, flavour, interaction, fNProton, fNPion, fNPizero, fNNeutron]

    logging.debug('%d train images', count_train)
    logging.debug('%d val images', count_val)
    logging.debug('%d test images', count_test)
    logging.debug('%d total images', count_train + count_val + count_test)

# Print dataset statistics

print(count_flavour)
print(count_category)

logging.info('Number of neutrino events: %d', count_neutrinos)
logging.info('Number of antineutrino events: %d', count_antineutrinos)

logging.info('Number of empty views: %d', count_empty_views)
logging.info('Number of views with <10 non-zero pixels: %d', count_less_10nonzero_views)
logging.info('Number of empty events: %d', count_empty_events)
logging.info('Number of events with at least one view with <10 non-zero pixels: %d', count_less_10nonzero_events)

logging.info('Number of training examples: %d', len(partition['train']))
logging.info('Number of validation examples: %d', len(partition['validation']))
logging.info('Number of test examples: %d', len(partition['test']))

# Serialize partition and labels

logging.info('Serializing datasets...')

# pickle requires binary mode in Python 3
with open(DATASET_PATH + PARTITION_PREFIX + '.p', 'wb') as partition_file:
    pickle.dump(partition, partition_file)

with open(DATASET_PATH + LABELS_PREFIX + '.p', 'wb') as labels_file:
    pickle.dump(labels, labels_file)
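# A minimal sketch (not executed here) of how a downstream training script
# could restore these artifacts, assuming the same config values:
#
#     with open(DATASET_PATH + PARTITION_PREFIX + '.p', 'rb') as f:
#         partition = pickle.load(f)
#     with open(DATASET_PATH + LABELS_PREFIX + '.p', 'rb') as f:
#         labels = pickle.load(f)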