# Command-line interface. Every option is mandatory; for each name passed to
# --model the script reads <artifact_dir>/<name>.tflite and writes random
# inputs to <output_dir>/<name>.tflite.input.h5.
parser = argparse.ArgumentParser(
    description="Generate random input data for TFLite models and store it "
                "in HDF5 files.")
parser.add_argument(
    '--num_data', type=int, required=True,
    help="Number of random input samples to generate per model.")
parser.add_argument(
    '--output_dir', type=str, required=True,
    help="Directory in which <model>.tflite.input.h5 files are written.")
parser.add_argument(
    '--artifact_dir', type=str, required=True,
    help="Directory containing the <model>.tflite files.")
parser.add_argument(
    '--model', type=str, required=True, nargs='+',
    help="One or more model names (without the .tflite extension).")
args = parser.parse_args()
# Copy the parsed options into plain module-level names used by the loop below.
num_data, output_dir, artifact_dir, model_list = (
    args.num_data,
    args.output_dir,
    args.artifact_dir,
    args.model,
)
# For every requested model, write `num_data` random samples for each of the
# model's inputs into an HDF5 file laid out as:
#   /value            attrs['desc'] = "Input data for <model_path>"
#   /value/<i>        attrs['desc'] = "Input data <i>"
#   /value/<i>/<j>    dataset holding input j of sample i
for model_name in model_list:
    model_path = os.path.join(artifact_dir, model_name + '.tflite')
    h5_path = os.path.join(output_dir, model_name + '.tflite.input.h5')

    interpreter = tf.lite.Interpreter(model_path)
    input_details = interpreter.get_input_details()

    # The context manager guarantees the HDF5 file is flushed and closed,
    # including when generation fails part-way through.
    with h5.File(h5_path, 'w') as h5_file:
        group = h5_file.create_group("value")
        group.attrs['desc'] = "Input data for " + model_path

        for i in range(num_data):
            sample = group.create_group(str(i))
            sample.attrs['desc'] = "Input data " + str(i)

            for j, input_detail in enumerate(input_details):
                print(input_detail["dtype"])
                if input_detail["dtype"] == np.bool_:
                    # randint's upper bound is exclusive, so (0, 2) draws from
                    # {0, 1}, replacing the deprecated random_integers(0, 1).
                    input_data = np.array(
                        np.random.randint(0, 2, input_detail["shape"]),
                        input_detail["dtype"])
                elif input_detail["dtype"] == np.float32:
                    # Uniform values in [-5, 5).
                    input_data = np.array(
                        10 * np.random.random_sample(input_detail["shape"]) - 5,
                        input_detail["dtype"])
                else:
                    # Fail loudly: otherwise input_data would be unbound or
                    # would still hold the previous input's array (wrong
                    # shape/dtype), silently corrupting the output file.
                    raise ValueError(
                        f"Unsupported input dtype {input_detail['dtype']} "
                        f"for input {j} of {model_path}")
                sample.create_dataset(str(j), data=input_data)