ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir::TensorVariant Class Reference

#include <TensorVariant.h>

Public Member Functions

 TensorVariant (const TensorType &type)
 
 TensorVariant (const TensorType &type, const void *data)
 
 TensorVariant (DataType element_type, const Shape &shape)
 
 TensorVariant (DataType element_type, const Shape &shape, const void *data)
 
 TensorVariant (const TensorVariant &t_old, const Shape &shape)
 Construct a TensorVariant that shares t_old's data buffer but has stride 0 along every dimension where t_old has size 1 (and along any leading dimensions t_old lacks). Used for broadcasting.
 
virtual ~TensorVariant ()=default
 
char * at (const Index &idx) const
 
char * atOffset (int32_t offset) const
 
size_t getOffset (const Index &idx) const
 
const TensorType & getType () const
 
DataType getElementType () const
 
const Shape & getShape () const
 
DataType getDataType () const
 
size_t getElementSize () const
 

Detailed Description

Definition at line 32 of file TensorVariant.h.

Constructor & Destructor Documentation

◆ TensorVariant() [1/5]

mir::TensorVariant::TensorVariant ( const TensorType & type)
explicit

Definition at line 23 of file TensorVariant.cpp.

23 : _type(type), _strides(type.getShape().rank())
24{
25 _element_size = getDataTypeSize(getElementType());
26 std::size_t data_size = getShape().numElements() * _element_size;
27 _data.reset(new char[data_size], std::default_delete<char[]>());
28
29 int stride = 1;
30 for (int d = getShape().rank() - 1; d >= 0; --d)
31 {
32 _strides[d] = stride;
33 stride *= getShape().dim(d);
34 }
35}
int32_t & dim(int32_t axis) noexcept
Definition Shape.h:47
int32_t numElements() const
Definition Shape.cpp:30
const Shape & getShape() const
DataType getElementType() const
type
Definition infer.py:18
std::size_t getDataTypeSize(DataType type)
Definition DataType.h:36

References mir::Shape::dim(), mir::getDataTypeSize(), getElementType(), getShape(), and mir::Shape::numElements().

◆ TensorVariant() [2/5]

mir::TensorVariant::TensorVariant ( const TensorType & type,
const void *  data 
)

Definition at line 42 of file TensorVariant.cpp.

42 : TensorVariant(type)
43{
44 std::size_t data_size = getShape().numElements() * _element_size;
45 std::memcpy(_data.get(), data, data_size);
46}
TensorVariant(const TensorType &type)

References getShape(), and mir::Shape::numElements().

◆ TensorVariant() [3/5]

mir::TensorVariant::TensorVariant ( DataType  element_type,
const Shape & shape 
)

Definition at line 37 of file TensorVariant.cpp.

38 : TensorVariant(TensorType(element_type, shape))
39{
40}

◆ TensorVariant() [4/5]

mir::TensorVariant::TensorVariant ( DataType  element_type,
const Shape & shape,
const void *  data 
)

Definition at line 48 of file TensorVariant.cpp.

49 : TensorVariant(TensorType(element_type, shape), data)
50{
51}

◆ TensorVariant() [5/5]

mir::TensorVariant::TensorVariant ( const TensorVariant & t_old,
const Shape & shape 
)

Construct a TensorVariant that shares t_old's data buffer but has stride 0 along every dimension where t_old has size 1 (and along any leading dimensions t_old lacks). Used for broadcasting.

Parameters
t_old	TensorVariant to use as base (its data buffer is shared, not copied)
shape	Shape to broadcast to

Definition at line 59 of file TensorVariant.cpp.

60 : _type(t_old.getType().getElementType(), shape), _data(t_old._data),
61 _strides(static_cast<size_t>(shape.rank())), _element_size(t_old._element_size)
62{
63 int axis_old = t_old.getShape().rank() - 1;
64 for (int d = shape.rank() - 1; d >= 0; d--)
65 {
66 if (axis_old == -1)
67 break;
68 if (t_old.getShape().dim(axis_old) != 1)
69 _strides[d] = t_old._strides[axis_old];
70 axis_old--;
71 }
72}

References mir::Shape::dim(), getShape(), and mir::Shape::rank().

◆ ~TensorVariant()

virtual mir::TensorVariant::~TensorVariant ( )
virtual, default

Member Function Documentation

◆ at()

char * mir::TensorVariant::at ( const Index & idx) const
inline

Definition at line 49 of file TensorVariant.h.

49{ return _data.get() + getOffset(idx) * _element_size; }
size_t getOffset(const Index &idx) const

References getOffset().

Referenced by mir::Tensor< T >::at(), mir::Tensor< T >::at(), mir::Tensor< T >::getRegion(), mir_interpreter::FullyConnectedImpl< T >::run(), TEST(), mir::transposeTensor(), and mir2loco::Transformer::visit().

◆ atOffset()

char * mir::TensorVariant::atOffset ( int32_t  offset) const
inline

Definition at line 51 of file TensorVariant.h.

52 {
53 assert(offset >= 0 && offset < getShape().numElements());
54 return _data.get() + offset * _element_size;
55 }
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540
uint32_t numElements(const luci::CircleNode *node)
Definition Utils.cpp:41

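References getShape(). (The offset() cross-reference that Doxygen generated here is spurious: `offset` in this function is the int32_t parameter, not a call to the helpers.h function.)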

Referenced by mir::Tensor< T >::atOffset(), mir::Tensor< T >::atOffset(), mir_interpreter::erase(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), mir_interpreter::Conv2DImpl< T >::run(), and mir_interpreter::DeConv2DImpl< T >::run().

◆ getDataType()

DataType mir::TensorVariant::getDataType ( ) const
inline

Definition at line 72 of file TensorVariant.h.

72{ return _type.getElementType(); }
DataType getElementType() const
Definition TensorType.h:41

References mir::TensorType::getElementType().

◆ getElementSize()

size_t mir::TensorVariant::getElementSize ( ) const
inline

Definition at line 74 of file TensorVariant.h.

74{ return _element_size; }

Referenced by TEST().

◆ getElementType()

◆ getOffset()

size_t mir::TensorVariant::getOffset ( const Index & idx) const
inline

Definition at line 57 of file TensorVariant.h.

58 {
59 assert(idx.rank() == getShape().rank());
60 std::size_t offset = 0;
61 for (int i = 0; i < getShape().rank(); ++i)
62 offset += idx.at(i) * _strides[i];
63 return offset;
64 }
int32_t rank() const
Definition Shape.h:43

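References mir::Index::at(), getShape(), mir::Index::rank(), and mir::Shape::rank(). (The offset() cross-reference that Doxygen generated here is spurious: `offset` in this function is a local std::size_t accumulator, not a function call.)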

Referenced by at(), and TEST().

◆ getShape()

◆ getType()


The documentation for this class was generated from the following files: TensorVariant.h, TensorVariant.cpp