ONE - On-device Neural Engine
Loading...
Searching...
No Matches
gen_h5_random_inputs_all.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2import h5py as h5
3import numpy as np
4import tensorflow as tf
5import argparse
6import os
7
#
# This script generates a pack of random input data (.h5) expected by the
# input tflite models.
#
# For every name given via --model it reads <artifact_dir>/<model>.tflite,
# inspects the model's input tensors and writes
# <output_dir>/<model>.tflite.input.h5 with the layout
#
#   /value                        attrs['desc'] = "Input data for <model path>"
#   /value/<i>                    one group per sample, i in [0, num_data)
#   /value/<i>/<j>                random array for the model's j-th input
#
# Supported input dtypes: bool (values in {False, True}) and float32
# (values in [-5, 5)). Any other dtype is rejected with a ValueError.
#


def _parse_args(argv=None):
    """Parse the command-line options (``sys.argv[1:]`` when argv is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_data', type=int, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--artifact_dir', type=str, required=True)
    parser.add_argument('--model', type=str, required=True, nargs='+')
    return parser.parse_args(argv)


def generate_random_input(shape, dtype, rng=np.random):
    """Return a random array of ``shape`` and ``dtype`` for one model input.

    Args:
        shape: Array shape (any sequence of ints accepted by NumPy ``size=``).
        dtype: Input dtype as reported by the TFLite interpreter.
        rng: Object providing ``randint`` and ``random_sample``. Defaults to
            NumPy's global (unseeded) random state, which the original script
            also used.

    Returns:
        For bool inputs, uniform values in {False, True}. For float32 inputs,
        uniform values in [-5, 5). Rounding to float32 may rarely yield 5.0.

    Raises:
        ValueError: If ``dtype`` is neither bool nor float32. Without this
            check an unsupported input would either crash with NameError or
            silently reuse the previous input's array.
    """
    if dtype == np.bool_:
        # The upper bound is exclusive, so this draws from {0, 1}. This is the
        # replacement for the deprecated np.random.random_integers(0, 1, shape).
        return np.array(rng.randint(0, 2, shape), dtype)
    if dtype == np.float32:
        return np.array(10 * rng.random_sample(shape) - 5, dtype)
    raise ValueError(f"Unsupported input dtype: {dtype}")


def _write_model_inputs(model_name, num_data, output_dir, artifact_dir):
    """Write ``num_data`` random input samples for one model to an h5 file."""
    model_path = os.path.join(artifact_dir, model_name + '.tflite')
    h5_path = os.path.join(output_dir, model_name + '.tflite.input.h5')

    # Build TFLite interpreter (only to get the information of model input).
    interpreter = tf.lite.Interpreter(model_path)
    input_details = interpreter.get_input_details()
    for input_detail in input_details:
        print(input_detail["dtype"])

    # The context manager closes the file even if generation fails midway.
    with h5.File(h5_path, 'w') as h5_file:
        group = h5_file.create_group("value")
        group.attrs['desc'] = "Input data for " + model_path

        for i in range(num_data):
            sample = group.create_group(str(i))
            sample.attrs['desc'] = "Input data " + str(i)

            for j, input_detail in enumerate(input_details):
                input_data = generate_random_input(input_detail["shape"],
                                                   input_detail["dtype"])
                sample.create_dataset(str(j), data=input_data)


def main(argv=None):
    """Generate a random-input h5 file for every model named on the CLI."""
    args = _parse_args(argv)
    for model_name in args.model:
        _write_model_inputs(model_name, args.num_data, args.output_dir,
                            args.artifact_dir)


if __name__ == "__main__":
    main()