ONE - On-device Neural Engine
|
#include <TensorVariant.h>
Public Member Functions | |
TensorVariant (const TensorType &type) | |
TensorVariant (const TensorType &type, const void *data) | |
TensorVariant (DataType element_type, const Shape &shape) | |
TensorVariant (DataType element_type, const Shape &shape, const void *data) | |
TensorVariant (const TensorVariant &t_old, const Shape &shape) | |
Construct a TensorVariant from t_old that has strides of 0 where dim = 1. Used for broadcasting. | |
virtual | ~TensorVariant ()=default |
char * | at (const Index &idx) const |
char * | atOffset (int32_t offset) const |
size_t | getOffset (const Index &idx) const |
const TensorType & | getType () const |
DataType | getElementType () const |
const Shape & | getShape () const |
DataType | getDataType () const |
size_t | getElementSize () const |
Definition at line 32 of file TensorVariant.h.
|
explicit |
Definition at line 23 of file TensorVariant.cpp.
References mir::Shape::dim(), mir::getDataTypeSize(), getElementType(), getShape(), and mir::Shape::numElements().
mir::TensorVariant::TensorVariant | ( | const TensorType & | type, |
const void * | data | ||
) |
Definition at line 42 of file TensorVariant.cpp.
References getShape(), and mir::Shape::numElements().
Definition at line 37 of file TensorVariant.cpp.
Definition at line 48 of file TensorVariant.cpp.
mir::TensorVariant::TensorVariant | ( | const TensorVariant & | t_old, |
const Shape & | shape | ||
) |
Construct a TensorVariant from t_old that has strides of 0 where dim = 1. Used for broadcasting.
t_old | TensorVariant to use as base |
shape | shape to broadcast to |
Definition at line 59 of file TensorVariant.cpp.
References mir::Shape::dim(), getShape(), and mir::Shape::rank().
|
virtual default |
|
inline |
Definition at line 49 of file TensorVariant.h.
References getOffset().
Referenced by mir::Tensor< T >::at(), mir::Tensor< T >::at(), mir::Tensor< T >::getRegion(), mir_interpreter::FullyConnectedImpl< T >::run(), TEST(), mir::transposeTensor(), and mir2loco::Transformer::visit().
|
inline |
Definition at line 51 of file TensorVariant.h.
References getShape(), and offset().
Referenced by mir::Tensor< T >::atOffset(), mir::Tensor< T >::atOffset(), mir_interpreter::erase(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), mir_interpreter::Conv2DImpl< T >::run(), and mir_interpreter::DeConv2DImpl< T >::run().
|
inline |
|
inline |
|
inline |
Definition at line 68 of file TensorVariant.h.
References mir::TensorType::getElementType().
Referenced by mir_interpreter::Add(), mir_onnx::constantToShape(), mir_interpreter::Div(), mir_interpreter::Equal(), mir_interpreter::Fill(), mir_interpreter::FullyConnected(), mir_interpreter::Greater(), mir_interpreter::Less(), mir_interpreter::Max(), mir_interpreter::Mul(), mir_interpreter::GatherByT< T >::run(), mir_interpreter::Sub(), and TensorVariant().
|
inline |
Definition at line 57 of file TensorVariant.h.
References mir::Index::at(), getShape(), offset(), mir::Index::rank(), and mir::Shape::rank().
|
inline |
Definition at line 69 of file TensorVariant.h.
References mir::TensorType::getShape().
Referenced by atOffset(), mir_interpreter::erase(), getOffset(), mir::Tensor< T >::getShape(), mir_interpreter::AvgPool2DImpl< T >::run(), mir_interpreter::DepthwiseConv2DImpl< uint8_t >::run(), mir_interpreter::SliceImpl< T >::run(), mir_interpreter::TransposeImpl< T >::run(), mir_interpreter::FullyConnectedImpl< T >::run(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), mir_interpreter::GatherImpl< T, IndicesT >::run(), mir_interpreter::Conv2DImpl< T >::run(), mir_interpreter::Conv2DImpl< uint8_t >::run(), mir_interpreter::DeConv2DImpl< T >::run(), mir_interpreter::AddImpl< T >::run(), mir_interpreter::AddImpl< uint8_t >::run(), mir_interpreter::DivImpl< T >::run(), mir_interpreter::EqualImpl< T >::run(), mir_interpreter::GreaterImpl< T >::run(), mir_interpreter::LessImpl< T >::run(), mir_interpreter::MaxImpl< T >::run(), mir_interpreter::MulImpl< T >::run(), mir_interpreter::SubImpl< T >::run(), mir_interpreter::FillImpl< T >::run(), TensorVariant(), TensorVariant(), TensorVariant(), TEST(), and nnc::AclCppOpGenerator::visit().
|
inline |
Definition at line 66 of file TensorVariant.h.
Referenced by mir_interpreter::DepthwiseConv2DImpl< uint8_t >::run(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), mir_interpreter::Conv2DImpl< uint8_t >::run(), and mir_interpreter::AddImpl< uint8_t >::run().