ONE - On-device Neural Engine
Loading...
Searching...
No Matches
gen_h5_random_inputs.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2import h5py as h5
3import numpy as np
4import tensorflow as tf
5import argparse
6
#
# This script generates a pack of random input data (.h5) in the form expected by the given tflite model
#
# Basic usage:
#   gen_h5_random_inputs.py --model <path/to/tflite/model> --num_data <number/of/data> --output <path/to/output/data>
#   ex: gen_h5_random_inputs.py --model add.tflite --num_data 3 --output add.tflite.input.h5
#   (This will create add.tflite.input.h5, composed of three random inputs, at the given output path)
# Input dtypes for which random data can be generated.
SUPPORTED_DTYPES = (np.bool_, np.float32)


def generate_random_input(input_detail):
    """Return one random array matching a single model-input description.

    Args:
        input_detail: Mapping with a "dtype" entry (a NumPy scalar type) and
            a "shape" entry (sequence of dimension sizes), i.e. one element
            of the list returned by ``Interpreter.get_input_details()``.

    Returns:
        A NumPy array with the requested shape and dtype. Bool inputs are
        filled with random 0/1 values; float32 inputs are filled with values
        drawn uniformly from [-5, 5).

    Raises:
        ValueError: If the dtype is neither bool nor float32.
    """
    dtype = input_detail["dtype"]
    # Normalize to a tuple of plain ints so any sequence of sizes is accepted.
    shape = tuple(int(dim) for dim in input_detail["shape"])

    if dtype == np.bool_:
        # randint's upper bound is exclusive, so this draws from {0, 1}.
        # It replaces the deprecated np.random.random_integers(0, 1, ...).
        return np.array(np.random.randint(0, 2, size=shape), dtype)
    if dtype == np.float32:
        # Scale [0, 1) samples to [-5, 5).
        return np.array(10 * np.random.random_sample(shape) - 5, dtype)

    # Previously an unsupported dtype fell through both branches and either
    # raised NameError or silently reused the previous input's data.
    raise ValueError("Unsupported input dtype: " + str(dtype))


def main():
    """Parse the command line and write the random-input HDF5 file.

    The output file holds a "value" group with one sub-group per sample
    ("0", "1", ...); each sample holds one dataset per model input, named
    by the input's position ("0", "1", ...).

    Raises:
        ValueError: If any model input has a dtype other than bool/float32.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--num_data', type=int, required=True)
    parser.add_argument('--output', type=str, required=True)
    args = parser.parse_args()

    model = args.model
    num_data = args.num_data
    output_path = args.output

    # Build TFLite interpreter. (to get the information of model input)
    interpreter = tf.lite.Interpreter(model)
    input_details = interpreter.get_input_details()

    # Reject unsupported inputs before the output file is created, so a
    # failure does not leave a truncated .h5 file behind.
    for input_detail in input_details:
        if input_detail["dtype"] not in SUPPORTED_DTYPES:
            raise ValueError("Unsupported input dtype: " +
                             str(input_detail["dtype"]))

    # Create h5 file; the context manager closes it even if generation fails.
    with h5.File(output_path, 'w') as h5_file:
        group = h5_file.create_group("value")
        group.attrs['desc'] = "Input data for " + model

        # Generate random data
        for i in range(num_data):
            sample = group.create_group(str(i))
            sample.attrs['desc'] = "Input data " + str(i)

            for j, input_detail in enumerate(input_details):
                print(input_detail["dtype"])
                sample.create_dataset(str(j),
                                      data=generate_random_input(input_detail))


if __name__ == "__main__":
    main()