ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TensorVariant.cpp
Go to the documentation of this file.
1
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17
#include "
mir/TensorVariant.h
"
18
#include <cstring>
19
20
namespace
mir
21
{
22
23
TensorVariant::TensorVariant
(
const
TensorType
&type) : _type(type), _strides(type.getShape().rank())
24
{
25
_element_size =
getDataTypeSize
(
getElementType
());
26
std::size_t data_size =
getShape
().
numElements
() * _element_size;
27
_data.reset(
new
char
[data_size], std::default_delete<
char
[]>());
28
29
int
stride = 1;
30
for
(
int
d =
getShape
().rank() - 1; d >= 0; --d)
31
{
32
_strides[d] = stride;
33
stride *=
getShape
().
dim
(d);
34
}
35
}
36
37
TensorVariant::TensorVariant
(
DataType
element_type,
const
Shape
&shape)
38
:
TensorVariant
(
TensorType
(element_type, shape))
39
{
40
}
41
42
TensorVariant::TensorVariant
(
const
TensorType
&type,
const
void
*data) :
TensorVariant
(type)
43
{
44
std::size_t data_size =
getShape
().
numElements
() * _element_size;
45
std::memcpy(_data.get(), data, data_size);
46
}
47
48
TensorVariant::TensorVariant
(
DataType
element_type,
const
Shape
&shape,
const
void
*data)
49
:
TensorVariant
(
TensorType
(element_type, shape), data)
50
{
51
}
52
59
TensorVariant::TensorVariant
(
const
TensorVariant
&t_old,
const
Shape
&shape)
60
: _type(t_old.
getType
().getElementType(), shape), _data(t_old._data),
61
_strides(static_cast<size_t>(shape.rank())), _element_size(t_old._element_size)
62
{
63
int
axis_old = t_old.
getShape
().
rank
() - 1;
64
for
(
int
d = shape.
rank
() - 1; d >= 0; d--)
65
{
66
if
(axis_old == -1)
67
break
;
68
if
(t_old.
getShape
().
dim
(axis_old) != 1)
69
_strides[d] = t_old._strides[axis_old];
70
axis_old--;
71
}
72
}
73
74
}
// namespace mir
TensorVariant.h
mir::Shape
Definition
Shape.h:31
mir::Shape::dim
int32_t & dim(int32_t axis) noexcept
Definition
Shape.h:47
mir::Shape::numElements
int32_t numElements() const
Definition
Shape.cpp:30
mir::Shape::rank
int32_t rank() const
Definition
Shape.h:43
mir::TensorType
Definition
TensorType.h:28
mir::TensorVariant
Definition
TensorVariant.h:33
mir::TensorVariant::getShape
const Shape & getShape() const
Definition
TensorVariant.h:69
mir::TensorVariant::getElementType
DataType getElementType() const
Definition
TensorVariant.h:68
mir::TensorVariant::TensorVariant
TensorVariant(const TensorType &type)
Definition
TensorVariant.cpp:23
mir
Definition
Attributes.h:25
mir::DataType
DataType
Definition
DataType.h:27
mir::getDataTypeSize
std::size_t getDataTypeSize(DataType type)
Definition
DataType.h:36
getType
NNFW_TYPE getType(const char *type="")
Definition
nnfw_api_wrapper.cc:69
compiler
mir
src
TensorVariant.cpp
Generated by
1.9.8