ONE - On-device Neural Engine
Loading...
Searching...
No Matches
HardSwish.cpp
Go to the documentation of this file.
1
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17
#include "HardSwish.h"
#include "Common.h"

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <type_traits>
19
20
namespace
mir_interpreter
21
{
22
23
template
<
typename
T>
struct
HardSwishImpl
24
{
25
static
void
run
(
const
mir::TensorVariant
&input,
mir::TensorVariant
&result);
26
};
27
28
template
<
typename
T>
29
void
HardSwishImpl<T>::run
(
const
mir::TensorVariant
&input,
mir::TensorVariant
&result)
30
{
31
auto
output_data =
reinterpret_cast<
T *
>
(result.atOffset(0));
32
auto
input_data =
reinterpret_cast<
T *
>
(input.atOffset(0));
33
auto
in_end = input_data + input.getShape().numElements();
34
for
(; input_data < in_end; input_data++, output_data++)
35
{
36
const
auto
in = *input_data;
37
*output_data = in * std::min<T>(6.f, std::max<T>(0.f, in + 3.f)) / 6.f;
38
}
39
}
40
41
template
<>
struct
HardSwishImpl
<uint8_t>
42
{
43
static
void
run
(
const
mir::TensorVariant
&input,
mir::TensorVariant
&result)
44
{
45
throw
std::runtime_error{
"NYI"
};
46
}
47
};
48
49
void
HardSwish
(
const
mir::TensorVariant
&input,
mir::TensorVariant
&result)
50
{
51
dispatch<HardSwishImpl>(input.getElementType(), input, result);
52
}
53
54
}
// namespace mir_interpreter
mir::TensorVariant
Definition
TensorVariant.h:33
mir_interpreter
Definition
MirInterpreter.h:27
mir_interpreter::HardSwish
void HardSwish(const mir::TensorVariant &input, mir::TensorVariant &result)
Definition
HardSwish.cpp:49
Common.h
mir_interpreter::HardSwishImpl< uint8_t >::run
static void run(const mir::TensorVariant &input, mir::TensorVariant &result)
Definition
HardSwish.cpp:43
mir_interpreter::HardSwishImpl
Definition
HardSwish.cpp:24
mir_interpreter::HardSwishImpl::run
static void run(const mir::TensorVariant &input, mir::TensorVariant &result)
Definition
HardSwish.cpp:29
HardSwish.h
compiler
mir-interpreter
src
ops
HardSwish.cpp
Generated by
1.9.8