ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnkit::support::onnx::TensorContext Class Reference [final]

#include <TensorContext.h>

Collaboration diagram for nnkit::support::onnx::TensorContext:

Public Member Functions

 TensorContext (TensorSet &tensors)
 
uint32_t size (void) const override
 
std::string name (uint32_t n) const override
 
nncc::core::ADT::tensor::Shape shape (uint32_t n) const override
 
bool isFloatTensor (uint32_t n) const override
 
void getMutableFloatTensor (uint32_t n, const TensorContext::TypedAccessor< float > &f) override
 
void getConstFloatTensor (uint32_t n, const TensorContext::TypedReader< float > &f) const override
 
- Public Member Functions inherited from nnkit::TensorContext
virtual ~TensorContext ()=default
 
virtual bool isS32Tensor (uint32_t n) const
 
virtual void getMutableS32Tensor (uint32_t n, const TypedAccessor< int32_t > &cb)
 
virtual void getConstS32Tensor (uint32_t n, const TypedReader< int32_t > &cb) const
 

Additional Inherited Members

- Public Types inherited from nnkit::TensorContext
template<typename T >
using TypedReader = std::function< void(const TensorContext &, uint32_t n, const nncc::core::ADT::tensor::Reader< T > &)>
 
template<typename T >
using TypedAccessor = std::function< void(const TensorContext &, uint32_t n, nncc::core::ADT::tensor::Accessor< T > &)>
 

Detailed Description

Definition at line 34 of file TensorContext.h.

Constructor & Destructor Documentation

◆ TensorContext()

nnkit::support::onnx::TensorContext::TensorContext ( TensorSet & tensors)
inline

Definition at line 37 of file TensorContext.h.

37 : _tensors(tensors)
38 {
39 // DO NOTHING
40 }

Member Function Documentation

◆ getConstFloatTensor()

void nnkit::support::onnx::TensorContext::getConstFloatTensor ( uint32_t  n,
const TensorContext::TypedReader< float > &  f 
) const
inline override virtual

Reimplemented from nnkit::TensorContext.

Definition at line 87 of file TensorContext.h.

88 {
89 if (_tensors.type(n) != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
90 {
91 throw std::runtime_error{"type mismatch"};
92 }
93
94 using nncc::core::ADT::tensor::LexicalLayout;
95 using nncc::core::ADT::tensor::make_overlay;
96
97 Status status;
98
99 OrtValue *base = _tensors.mutable_tensor(n);
100 float *data;
101
102 status = OrtGetTensorMutableData(base, (void **)&data);
103 status.throwOnError();
104
105 auto overlay = make_overlay<float, LexicalLayout>(shape(n), data);
106
107 f(*this, n, overlay);
108 }
nncc::core::ADT::tensor::Shape shape(uint32_t n) const override
ONNXTensorElementDataType type(size_t index)
Definition TensorSet.h:74
OrtValue * mutable_tensor(size_t index)
Definition TensorSet.h:81
Overlay< T > make_overlay(const Shape &shape, T *base)
Definition Overlay.h:48

References nncc::core::ADT::tensor::make_overlay(), nnkit::support::onnx::TensorSet::mutable_tensor(), shape(), nnkit::support::onnx::Status::throwOnError(), and nnkit::support::onnx::TensorSet::type().

◆ getMutableFloatTensor()

void nnkit::support::onnx::TensorContext::getMutableFloatTensor ( uint32_t  n,
const TensorContext::TypedAccessor< float > &  f 
)
inline override virtual

Reimplemented from nnkit::TensorContext.

Definition at line 64 of file TensorContext.h.

65 {
66 if (_tensors.type(n) != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
67 {
68 throw std::runtime_error{"type mismatch"};
69 }
70
71 using nncc::core::ADT::tensor::LexicalLayout;
72 using nncc::core::ADT::tensor::make_overlay;
73
74 Status status;
75
76 OrtValue *base = _tensors.mutable_tensor(n);
77 float *data;
78
79 status = OrtGetTensorMutableData(base, (void **)&data);
80 status.throwOnError();
81
82 auto overlay = make_overlay<float, LexicalLayout>(shape(n), data);
83
84 f(*this, n, overlay);
85 }

References nncc::core::ADT::tensor::make_overlay(), nnkit::support::onnx::TensorSet::mutable_tensor(), shape(), nnkit::support::onnx::Status::throwOnError(), and nnkit::support::onnx::TensorSet::type().

◆ isFloatTensor()

bool nnkit::support::onnx::TensorContext::isFloatTensor ( uint32_t  n) const
inline override virtual

Reimplemented from nnkit::TensorContext.

Definition at line 59 of file TensorContext.h.

60 {
61 return (_tensors.type(n) == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
62 }

References nnkit::support::onnx::TensorSet::type().

◆ name()

std::string nnkit::support::onnx::TensorContext::name ( uint32_t  n) const
inline override virtual

Implements nnkit::TensorContext.

Definition at line 44 of file TensorContext.h.

44 { return std::string{_tensors.name(n)}; }
const char * name(size_t index)
Definition TensorSet.h:71

References nnkit::support::onnx::TensorSet::name().

◆ shape()

nncc::core::ADT::tensor::Shape nnkit::support::onnx::TensorContext::shape ( uint32_t  n) const
inline override virtual

Implements nnkit::TensorContext.

Definition at line 46 of file TensorContext.h.

47 {
48 const std::vector<size_t> &dims = _tensors.dim(n);
49
50 nncc::core::ADT::tensor::Shape shape;
51 shape.resize(dims.size());
52 for (size_t i = 0; i < dims.size(); ++i)
53 {
54 shape.dim(i) = dims[i];
55 }
56 return shape;
57 }
uint32_t & dim(uint32_t axis)
Definition Shape.cpp:42
Shape & resize(uint32_t size)
Definition Shape.cpp:36
const std::vector< size_t > & dim(size_t index)
Definition TensorSet.h:76

References nnkit::support::onnx::TensorSet::dim(), nncc::core::ADT::tensor::Shape::dim(), nncc::core::ADT::tensor::Shape::resize(), and shape().

Referenced by RandomDataGenerator.RandomDataGenerator::_gen_float32(), RandomDataGenerator.RandomDataGenerator::_gen_int16(), RandomDataGenerator.RandomDataGenerator::_gen_uint8(), getConstFloatTensor(), getMutableFloatTensor(), and shape().

◆ size()

uint32_t nnkit::support::onnx::TensorContext::size ( void  ) const
inline override virtual

Implements nnkit::TensorContext.

Definition at line 42 of file TensorContext.h.

42 { return _tensors.size(); }

References nnkit::support::onnx::TensorSet::size().


The documentation for this class was generated from the following file: