ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TensorVariant.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "mir/TensorVariant.h"
18#include <cstring>
19
20namespace mir
21{
22
23TensorVariant::TensorVariant(const TensorType &type) : _type(type), _strides(type.getShape().rank())
24{
25 _element_size = getDataTypeSize(getElementType());
26 std::size_t data_size = getShape().numElements() * _element_size;
27 _data.reset(new char[data_size], std::default_delete<char[]>());
28
29 int stride = 1;
30 for (int d = getShape().rank() - 1; d >= 0; --d)
31 {
32 _strides[d] = stride;
33 stride *= getShape().dim(d);
34 }
35}
36
37TensorVariant::TensorVariant(DataType element_type, const Shape &shape)
38 : TensorVariant(TensorType(element_type, shape))
39{
40}
41
42TensorVariant::TensorVariant(const TensorType &type, const void *data) : TensorVariant(type)
43{
44 std::size_t data_size = getShape().numElements() * _element_size;
45 std::memcpy(_data.get(), data, data_size);
46}
47
48TensorVariant::TensorVariant(DataType element_type, const Shape &shape, const void *data)
49 : TensorVariant(TensorType(element_type, shape), data)
50{
51}
52
60 : _type(t_old.getType().getElementType(), shape), _data(t_old._data),
61 _strides(static_cast<size_t>(shape.rank())), _element_size(t_old._element_size)
62{
63 int axis_old = t_old.getShape().rank() - 1;
64 for (int d = shape.rank() - 1; d >= 0; d--)
65 {
66 if (axis_old == -1)
67 break;
68 if (t_old.getShape().dim(axis_old) != 1)
69 _strides[d] = t_old._strides[axis_old];
70 axis_old--;
71 }
72}
73
74} // namespace mir
int32_t & dim(int32_t axis) noexcept
Definition Shape.h:47
int32_t numElements() const
Definition Shape.cpp:30
int32_t rank() const
Definition Shape.h:43
const Shape & getShape() const
DataType getElementType() const
TensorVariant(const TensorType &type)
DataType
Definition DataType.h:27
std::size_t getDataTypeSize(DataType type)
Definition DataType.h:36
NNFW_TYPE getType(const char *type="")