ONE - On-device Neural Engine
Loading...
Searching...
No Matches
TensorContext.h
Go to the documentation of this file.
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef __NNKIT_TENSOR_CONTEXT_H__
18#define __NNKIT_TENSOR_CONTEXT_H__
19
23
24#include <string>
25#include <functional>
26#include <stdexcept>
27#include <cstdint>
28
29namespace nnkit
30{
31
32// NOTE This interface is subject to change.
34{
35 template <typename T>
36 using TypedReader = std::function<void(const TensorContext &, uint32_t n,
38
39 template <typename T>
41 std::function<void(const TensorContext &, uint32_t n, nncc::core::ADT::tensor::Accessor<T> &)>;
42
43 virtual ~TensorContext() = default;
44
45 // The number of tensors that this context provides
46 virtual uint32_t size(void) const = 0;
47
48 // Query on properties of each tensor
49 virtual std::string name(uint32_t n) const = 0;
50 virtual nncc::core::ADT::tensor::Shape shape(uint32_t n) const = 0;
51
52 // TODO Support generic byte tensor
53 // TODO Support typed tensor for primitive types such as half(fp16), double(fp64), int8(s8),
54 // uint8(u8), uint(u32)
55
56 // Float (fp32) tensor support
57 virtual bool isFloatTensor(uint32_t n) const
58 {
59 throw std::runtime_error("This method should be overriden");
60 }
61
62 virtual void getMutableFloatTensor(uint32_t n, const TypedAccessor<float> &cb)
63 {
64 throw std::runtime_error("This method should be overriden");
65 }
66
67 virtual void getConstFloatTensor(uint32_t n, const TypedReader<float> &cb) const
68 {
69 throw std::runtime_error("This method should be overriden");
70 }
71
72 // S32
73 virtual bool isS32Tensor(uint32_t n) const
74 {
75 throw std::runtime_error("This method should be overriden");
76 }
77
78 virtual void getMutableS32Tensor(uint32_t n, const TypedAccessor<int32_t> &cb)
79 {
80 throw std::runtime_error("This method should be overriden");
81 }
82
83 virtual void getConstS32Tensor(uint32_t n, const TypedReader<int32_t> &cb) const
84 {
85 throw std::runtime_error("This method should be overriden");
86 }
87};
88
89} // namespace nnkit
90
91#endif // __NNKIT_TENSOR_CONTEXT_H__
virtual void getMutableFloatTensor(uint32_t n, const TypedAccessor< float > &cb)
virtual void getConstS32Tensor(uint32_t n, const TypedReader< int32_t > &cb) const
virtual nncc::core::ADT::tensor::Shape shape(uint32_t n) const =0
std::function< void(const TensorContext &, uint32_t n, const nncc::core::ADT::tensor::Reader< T > &)> TypedReader
virtual std::string name(uint32_t n) const =0
virtual void getConstFloatTensor(uint32_t n, const TypedReader< float > &cb) const
virtual uint32_t size(void) const =0
std::function< void(const TensorContext &, uint32_t n, nncc::core::ADT::tensor::Accessor< T > &)> TypedAccessor
virtual bool isS32Tensor(uint32_t n) const
virtual bool isFloatTensor(uint32_t n) const
virtual void getMutableS32Tensor(uint32_t n, const TypedAccessor< int32_t > &cb)
virtual ~TensorContext()=default