58 assert(_inputs ==
nullptr);
60 size_t num_input_nodes;
61 status = OrtSessionGetInputCount(_session, &num_input_nodes);
64 _inputs = std::make_unique<TensorSet>(_allocator.get(), num_input_nodes);
66 for (
size_t i = 0; i < num_input_nodes; ++i)
69 status = OrtSessionGetInputName(_session, i, _allocator.get(), &input_name);
72 assert(input_name !=
nullptr);
74 std::string name{input_name};
75 _allocator->Free(input_name);
77 OrtTypeInfo *typeinfo;
78 status = OrtSessionGetInputTypeInfo(_session, i, &typeinfo);
81 const OrtTensorTypeAndShapeInfo *tensor_info = OrtCastTypeInfoToTensorInfo(typeinfo);
82 ONNXTensorElementDataType type = OrtGetTensorElementType(tensor_info);
84 uint32_t num_dims = OrtGetNumOfDimensions(tensor_info);
85 std::vector<size_t> dims(num_dims);
86 OrtGetDimensions(tensor_info, (int64_t *)dims.data(), num_dims);
93 for (uint32_t j = 0; j < num_dims; ++j)
100 OrtReleaseTypeInfo(typeinfo);
102 _inputs->set(i, name, type, dims);
110 assert(_outputs ==
nullptr);
112 size_t num_output_nodes;
113 status = OrtSessionGetOutputCount(_session, &num_output_nodes);
116 _outputs = std::make_unique<TensorSet>(_allocator.get(), num_output_nodes);
118 for (
size_t i = 0; i < num_output_nodes; ++i)
121 status = OrtSessionGetOutputName(_session, i, _allocator.get(), &output_name);
124 assert(output_name !=
nullptr);
126 std::string name{output_name};
127 _allocator->Free(output_name);
129 OrtTypeInfo *typeinfo;
130 status = OrtSessionGetOutputTypeInfo(_session, i, &typeinfo);
133 const OrtTensorTypeAndShapeInfo *tensor_info = OrtCastTypeInfoToTensorInfo(typeinfo);
134 ONNXTensorElementDataType type = OrtGetTensorElementType(tensor_info);
136 uint32_t num_dims = OrtGetNumOfDimensions(tensor_info);
137 std::vector<size_t> dims(num_dims);
138 OrtGetDimensions(tensor_info, (int64_t *)dims.data(), num_dims);
145 for (uint32_t j = 0; j < num_dims; ++j)
152 OrtReleaseTypeInfo(typeinfo);
154 _outputs->set(i, name, type, dims);
162 auto pinput_names = _inputs->names();
163 std::vector<const char *> input_names(pinput_names.size());
164 for (
size_t i = 0; i < pinput_names.size(); ++i)
166 input_names[i] = pinput_names[i].c_str();
169 auto poutput_names = _outputs->names();
170 std::vector<const char *> output_names(poutput_names.size());
171 for (
size_t i = 0; i < poutput_names.size(); ++i)
173 output_names[i] = poutput_names[i].c_str();
176 status = OrtRun(_session, NULL, input_names.data(), _inputs->tensors().data(), _inputs->size(),
177 output_names.data(), _outputs->size(), _outputs->mutable_tensors().data());