ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Nodes.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __LOCO_IR_NODES_H__
18#define __LOCO_IR_NODES_H__
19
20#include "loco/IR/Node.h"
21#include "loco/IR/Use.h"
22#include "loco/IR/Domain.h"
23#include "loco/IR/DataType.h"
25#include "loco/IR/Dimension.h"
26#include "loco/IR/Window.h"
27#include "loco/IR/Stride.h"
28#include "loco/IR/Padding2D.h"
29#include "loco/IR/PaddingND.h"
30#include "loco/IR/TensorAxis.h"
33#include "loco/IR/FilterCodec.h"
35#include "loco/IR/MatrixCodec.h"
36#include "loco/IR/NodeMixins.h"
40
41namespace loco
42{
43
// Forward declarations: these types are only referred to through pointers in
// this header, so their complete definitions are not required here.
class Graph;
class GraphInput;
class GraphOutput;
47
51class Push /* to user */ final
52 : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin>
53{
54public:
55 Push() = default;
56
57public:
58 Node *from(void) const { return at(0)->node(); }
59 void from(Node *node) { at(0)->node(node); }
60
61public:
62 void index(const GraphOutputIndex &index);
63
72 GraphOutputIndex index(void) const;
73
79 bool indexed(void) const { return _index != -1; }
80
81private:
82 int64_t _index = -1; // Uninitialized
83};
84
85void link(GraphOutput *, Push *push);
86
88Push *push_node(Graph *g, const GraphOutputIndex &index);
89
93class Pull /* from user */ final
94 : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
95 With<NodeTrait::TensorShape>::Mixin>
96{
97public:
98 Pull() = default;
99
100public:
101 void index(const GraphInputIndex &index);
102
111 GraphInputIndex index(void) const;
112
118 bool indexed(void) const { return _index != -1; }
119
120public:
121 void dtype(const DataType &d);
122 DataType dtype(void) const;
123
124private:
125 int64_t _index = -1; // Uninitialized
126
132 DataType _dtype = DataType::Unknown;
133};
134
135void link(GraphInput *, Pull *pull);
136
138Pull *pull_node(Graph *g, const GraphInputIndex &index);
139
145class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
146{
147public:
148 Forward() = default;
149
150public:
151 Node *input(void) const { return at(0)->node(); }
152 void input(Node *node) { at(0)->node(node); }
153};
154
158class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
159{
160public:
161 ReLU() = default;
162
163public:
164 Node *input(void) const { return at(0)->node(); }
165 void input(Node *node) { at(0)->node(node); }
166};
167
171class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
172{
173public:
174 ReLU6() = default;
175
176public:
177 Node *input(void) const { return at(0)->node(); }
178 void input(Node *node) { at(0)->node(node); }
179};
180
184class Tanh final : public CanonicalNodeDef<CanonicalOpcode::Tanh, FixedArity<1>::Mixin>
185{
186public:
187 Tanh() = default;
188
189public:
190 Node *input(void) const { return at(0)->node(); }
191 void input(Node *node) { at(0)->node(node); }
192};
193
215class ConstGen final
216 : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
217 With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
218{
219public:
220 ConstGen() = default;
221
222public:
227 template <DataType DT> uint32_t size(void) const;
228
232 template <DataType DT> void size(uint32_t size);
233
238 template <DataType DT> const typename DataTypeImpl<DT>::Type &at(uint32_t n) const;
239
244 template <DataType DT> typename DataTypeImpl<DT>::Type &at(uint32_t n);
245
246private:
248 std::vector<uint8_t> _data;
249};
250
304class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
305{
306public:
307 Node *ifm(void) const { return at(0)->node(); }
308 void ifm(Node *node) { at(0)->node(node); }
309
310public:
311 const Padding2D *pad(void) const { return &_pad; }
312 Padding2D *pad(void) { return &_pad; }
313
314public:
315 const Window<2> *window(void) const { return &_window; }
316 Window<2> *window(void) { return &_window; }
317
318public:
319 const Stride<2> *stride(void) const { return &_stride; }
320 Stride<2> *stride(void) { return &_stride; }
321
322private:
323 // Pad
324 Padding2D _pad;
325 // Window
326 Window<2> _window;
327 // Stride
328 Stride<2> _stride;
329};
330
336class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
337{
338public:
339 enum class Convention
340 {
341 Unknown,
342 // Use the number of elements in each receptive field as a divisor
343 Full,
344 // Use the number of valid (non-padding) elements in each receptive field as a divisor
345 Valid
346 };
347
348public:
349 Node *ifm(void) const { return at(0)->node(); }
350 void ifm(Node *node) { at(0)->node(node); }
351
352public:
353 Convention convention(void) const { return _convention; }
354 void convention(const Convention &convention) { _convention = convention; }
355
356public:
357 const Padding2D *pad(void) const { return &_pad; }
358 Padding2D *pad(void) { return &_pad; }
359
360public:
361 const Window<2> *window(void) const { return &_window; }
362 Window<2> *window(void) { return &_window; }
363
364public:
365 const Stride<2> *stride(void) const { return &_stride; }
366 Stride<2> *stride(void) { return &_stride; }
367
368private:
369 Convention _convention = Convention::Unknown;
370 Padding2D _pad;
371 Window<2> _window;
372 Stride<2> _stride;
373};
374
378class FeatureEncode final
379 : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
380{
381public:
382 Node *input(void) const { return at(0)->node(); }
383 void input(Node *node) { at(0)->node(node); }
384
385public:
386 FeatureEncoder *encoder(void) const { return _enc.get(); }
387 void encoder(std::unique_ptr<FeatureEncoder> &&enc) { _enc = std::move(enc); }
388
389private:
391 std::unique_ptr<FeatureEncoder> _enc{nullptr};
392};
393
397class FeatureDecode final
398 : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
399{
400public:
401 Node *input(void) const { return at(0)->node(); }
402 void input(Node *node) { at(0)->node(node); }
403
404public:
405 FeatureDecoder *decoder(void) const { return _dec.get(); }
406 void decoder(std::unique_ptr<FeatureDecoder> &&dec) { _dec = std::move(dec); }
407
408private:
410 std::unique_ptr<FeatureDecoder> _dec{nullptr};
411};
412
416class FilterEncode final
417 : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
418{
419public:
420 Node *input(void) const { return at(0)->node(); }
421 void input(Node *node) { at(0)->node(node); }
422
423public:
424 FilterEncoder *encoder(void) const { return _enc.get(); }
425 void encoder(std::unique_ptr<FilterEncoder> &&enc) { _enc = std::move(enc); }
426
427private:
429 std::unique_ptr<FilterEncoder> _enc{nullptr};
430};
431
435class FilterDecode final
436 : public CanonicalNodeDef<CanonicalOpcode::FilterDecode, FixedArity<1>::Mixin>
437{
438public:
439 Node *input(void) const { return at(0)->node(); }
440 void input(Node *node) { at(0)->node(node); }
441
442public:
443 FilterDecoder *decoder(void) const { return _dec.get(); }
444 void decoder(std::unique_ptr<FilterDecoder> &&dec) { _dec = std::move(dec); }
445
446private:
448 std::unique_ptr<FilterDecoder> _dec{nullptr};
449};
450
455 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
456{
457public:
458 Node *input(void) const { return at(0)->node(); }
459 void input(Node *node) { at(0)->node(node); }
460
461public:
462 DepthwiseFilterEncoder *encoder(void) const { return _enc.get(); }
463 void encoder(std::unique_ptr<DepthwiseFilterEncoder> &&enc) { _enc = std::move(enc); }
464
465private:
467 std::unique_ptr<DepthwiseFilterEncoder> _enc{nullptr};
468};
469
474 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterDecode, FixedArity<1>::Mixin>
475{
476public:
477 Node *input(void) const { return at(0)->node(); }
478 void input(Node *node) { at(0)->node(node); }
479
480public:
481 DepthwiseFilterDecoder *decoder(void) const { return _dec.get(); }
482 void decoder(std::unique_ptr<DepthwiseFilterDecoder> &&dec) { _dec = std::move(dec); }
483
484private:
486 std::unique_ptr<DepthwiseFilterDecoder> _dec{nullptr};
487};
488
489enum class ReshapeType
490{
491 Fixed, // shape is known at compile time
492 // Add another type for a case when shape is not known at compile time
493};
494
495template <ReshapeType RT> class Reshape;
496
513template <>
515 : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
516 With<NodeTrait::TensorShape>::Mixin>
517{
518public:
519 Node *input(void) const { return at(0)->node(); }
520 void input(Node *node) { at(0)->node(node); }
521};
522
524
531class TensorConcat final
532 : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
533{
534public:
535 Node *lhs(void) const { return at(0)->node(); }
536 void lhs(Node *node) { at(0)->node(node); }
537
538 Node *rhs(void) const { return at(1)->node(); }
539 void rhs(Node *node) { at(1)->node(node); }
540
541public:
542 uint32_t axis(void) const { return _axis; }
543 void axis(uint32_t val) { _axis = val; }
544
545private:
546 // Axis
547 uint32_t _axis{0};
548};
549
553class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
554{
555public:
556 Node *ifm(void) const { return at(0)->node(); }
557 void ifm(Node *node) { at(0)->node(node); }
558
559 Node *ker(void) const { return at(1)->node(); }
560 void ker(Node *node) { at(1)->node(node); }
561
562public:
563 const Padding2D *pad(void) const { return &_pad; }
564 Padding2D *pad(void) { return &_pad; }
565
566public:
567 const Stride<2> *stride(void) const { return &_stride; }
568 Stride<2> *stride(void) { return &_stride; }
569
570private:
571 Padding2D _pad;
572 Stride<2> _stride;
573
574 // TODO Support "Dilation"
575};
576
581 : public CanonicalNodeDef<CanonicalOpcode::DepthwiseConv2D, FixedArity<2>::Mixin>
582{
583public:
584 Node *ifm(void) const { return at(0)->node(); }
585 void ifm(Node *node) { at(0)->node(node); }
586
587 Node *ker(void) const { return at(1)->node(); }
588 void ker(Node *node) { at(1)->node(node); }
589
590public:
591 const Padding2D *pad(void) const { return &_pad; }
592 Padding2D *pad(void) { return &_pad; }
593
594public:
595 const Stride<2> *stride(void) const { return &_stride; }
596 Stride<2> *stride(void) { return &_stride; }
597
598private:
599 Padding2D _pad;
600 Stride<2> _stride;
601
602 // TODO Support "Dilation"
603};
604
608enum class ReduceFunc
609{
610 Mean, // ReduceMean
611 // TODO Support other reduce operations
612};
613
618class TensorReduce final
619 : public CanonicalNodeDef<CanonicalOpcode::TensorReduce, FixedArity<1>::Mixin>
620{
621public:
622 Node *input(void) const { return at(0)->node(); }
623 void input(Node *node) { at(0)->node(node); }
624
625public:
626 const TensorAxisSet *axes(void) const { return &_axes; }
627 TensorAxisSet *axes(void) { return &_axes; }
628
629public:
630 ReduceFunc func(void) const { return _func; }
631 void func(ReduceFunc func) { _func = func; }
632
633private:
634 TensorAxisSet _axes;
636};
637
687 : public CanonicalNodeDef<CanonicalOpcode::TransposedConv2D, FixedArity<2>::Mixin>
688{
689public:
690 Node *ifm(void) const { return at(0)->node(); }
691 void ifm(Node *node) { at(0)->node(node); }
692
693 Node *ker(void) const { return at(1)->node(); }
694 void ker(Node *node) { at(1)->node(node); }
695
696public:
697 const Padding2D *pad(void) const { return &_pad; }
698 Padding2D *pad(void) { return &_pad; }
699
700public:
701 const Stride<2> *stride(void) const { return &_stride; }
702 Stride<2> *stride(void) { return &_stride; }
703
704private:
705 Padding2D _pad;
706 Stride<2> _stride;
707
708 // TODO Support "Dilation"
709};
710
714template <Domain D> class Softmax;
715
719template <>
720class Softmax<Domain::Tensor> final
721 : public CanonicalNodeDef<CanonicalOpcode::TensorSoftmax, FixedArity<1>::Mixin>
722{
723public:
724 Softmax() = default;
725
726public:
727 Node *input(void) const { return at(0)->node(); }
728 void input(Node *node) { return at(0)->node(node); }
729
730 uint32_t axis(void) const { return _axis; }
731 void axis(uint32_t axis) { _axis = axis; }
732
733private:
734 uint32_t _axis = 0;
735};
736
738
742class BiasDecode final : public CanonicalNodeDef<CanonicalOpcode::BiasDecode, FixedArity<1>::Mixin>
743{
744public:
745 BiasDecode() = default;
746
747public:
748 Node *input(void) const { return at(0)->node(); }
749 void input(Node *node) { at(0)->node(node); }
750};
751
757class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
758{
759public:
760 BiasEncode() = default;
761
762public:
763 Node *input(void) const { return at(0)->node(); }
764 void input(Node *node) { at(0)->node(node); }
765};
766
770template <Domain D> class BiasAdd;
771
778template <>
779class BiasAdd<Domain::Tensor> final
780 : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
781{
782public:
783 BiasAdd() = default;
784
785public:
786 Node *value(void) const { return at(0)->node(); }
787 void value(Node *node) { return at(0)->node(node); }
788
789 Node *bias(void) const { return at(1)->node(); }
790 void bias(Node *node) { return at(1)->node(node); }
791
792 uint32_t axis(void) const { return _axis; }
793 void axis(uint32_t axis) { _axis = axis; }
794
795private:
796 uint32_t _axis = 0;
797};
798
799//
800// Alias for external users
801//
802// loco::TensorBiasAdd
803// vs.
804// loco::BiasAdd<loco::Domain::Tensor>
805//
807
814template <>
815class BiasAdd<Domain::Feature> final
816 : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
817{
818public:
819 BiasAdd() = default;
820
821public:
822 Node *value(void) const { return at(0)->node(); }
823 void value(Node *node) { return at(0)->node(node); }
824
825 Node *bias(void) const { return at(1)->node(); }
826 void bias(Node *node) { return at(1)->node(node); }
827};
828
830
851 : public CanonicalNodeDef<CanonicalOpcode::TensorConstantPad, FixedArity<2>::Mixin>
852{
853public:
854 Node *input(void) const { return at(0)->node(); }
855 void input(Node *node) { at(0)->node(node); }
856
857 Node *constant(void) const { return at(1)->node(); }
858 void constant(Node *node) { at(1)->node(node); }
859
860public:
861 const PaddingND *padding(void) const { return &_padding; }
862 PaddingND *padding(void) { return &_padding; }
863
864private:
865 PaddingND _padding;
866};
867
871class EltwiseAdd final : public CanonicalNodeDef<CanonicalOpcode::EltwiseAdd, FixedArity<2>::Mixin>
872{
873public:
874 EltwiseAdd() = default;
875
876public:
877 Node *lhs(void) const { return at(0)->node(); }
878 void lhs(Node *node) { return at(0)->node(node); }
879
880 Node *rhs(void) const { return at(1)->node(); }
881 void rhs(Node *node) { return at(1)->node(node); }
882};
883
889class EltwiseMax final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMax, FixedArity<2>::Mixin>
890{
891public:
892 EltwiseMax() = default;
893
894public:
895 Node *lhs(void) const { return at(0)->node(); }
896 void lhs(Node *node) { return at(0)->node(node); }
897
898 Node *rhs(void) const { return at(1)->node(); }
899 void rhs(Node *node) { return at(1)->node(node); }
900};
901
905class EltwiseMul final : public CanonicalNodeDef<CanonicalOpcode::EltwiseMul, FixedArity<2>::Mixin>
906{
907public:
908 EltwiseMul() = default;
909
910public:
911 Node *lhs(void) const { return at(0)->node(); }
912 void lhs(Node *node) { return at(0)->node(node); }
913
914 Node *rhs(void) const { return at(1)->node(); }
915 void rhs(Node *node) { return at(1)->node(node); }
916};
917
921class EltwiseSub final : public CanonicalNodeDef<CanonicalOpcode::EltwiseSub, FixedArity<2>::Mixin>
922{
923public:
924 EltwiseSub() = default;
925
926public:
927 Node *lhs(void) const { return at(0)->node(); }
928 void lhs(Node *node) { return at(0)->node(node); }
929
930 Node *rhs(void) const { return at(1)->node(); }
931 void rhs(Node *node) { return at(1)->node(node); }
932};
933
937class EltwiseDiv final : public CanonicalNodeDef<CanonicalOpcode::EltwiseDiv, FixedArity<2>::Mixin>
938{
939public:
940 EltwiseDiv() = default;
941
942public:
943 Node *lhs(void) const { return at(0)->node(); }
944 void lhs(Node *node) { return at(0)->node(node); }
945
946 Node *rhs(void) const { return at(1)->node(); }
947 void rhs(Node *node) { return at(1)->node(node); }
948};
949
953class EltwiseSqrt final
954 : public CanonicalNodeDef<CanonicalOpcode::EltwiseSqrt, FixedArity<1>::Mixin>
955{
956public:
957 EltwiseSqrt() = default;
958
959public:
960 Node *input(void) const { return at(0)->node(); }
961 void input(Node *node) { at(0)->node(node); }
962};
963
979 : public CanonicalNodeDef<CanonicalOpcode::TensorBroadcast, FixedArity<1>::Mixin>
980{
981public:
982 TensorBroadcast() = default;
983
984public:
985 Node *input(void) const { return at(0)->node(); }
986 void input(Node *node) { at(0)->node(node); }
987
988public:
989 class Mapping final
990 {
991 public:
992 Mapping() = default;
993
994 public:
995 bool defined(const TensorAxis &axis) const;
996
997 const Dimension &dim(const TensorAxis &axis) const;
998 Dimension &dim(const TensorAxis &axis);
999
1000 private:
1001 std::map<TensorAxis, Dimension> _content;
1002 };
1003
1004 Mapping *mapping(void) { return &_mapping; }
1005 const Mapping *mapping(void) const { return &_mapping; }
1006
1007private:
1008 Mapping _mapping;
1009};
1010
1016class MatrixEncode final
1017 : public CanonicalNodeDef<CanonicalOpcode::MatrixEncode, FixedArity<1>::Mixin>
1018{
1019public:
1020 MatrixEncode() = default;
1021
1022public:
1023 Node *input(void) const { return at(0)->node(); }
1024 void input(Node *node) { at(0)->node(node); }
1025
1026public:
1027 MatrixEncoder *encoder(void) const { return _enc.get(); }
1028 void encoder(std::unique_ptr<MatrixEncoder> &&enc) { _enc = std::move(enc); }
1029
1030private:
1032 std::unique_ptr<MatrixEncoder> _enc{nullptr};
1033};
1034
1040class MatrixDecode final
1041 : public CanonicalNodeDef<CanonicalOpcode::MatrixDecode, FixedArity<1>::Mixin>
1042{
1043public:
1044 MatrixDecode() = default;
1045
1046public:
1047 Node *input(void) const { return at(0)->node(); }
1048 void input(Node *node) { at(0)->node(node); }
1049
1050public:
1051 MatrixDecoder *decoder(void) const { return _dec.get(); }
1052 void decoder(std::unique_ptr<MatrixDecoder> &&dec) { _dec = std::move(dec); }
1053
1054private:
1056 std::unique_ptr<MatrixDecoder> _dec{nullptr};
1057};
1058
1064class MatMul final : public CanonicalNodeDef<CanonicalOpcode::MatMul, FixedArity<2>::Mixin>
1065{
1066public:
1067 MatMul() = default;
1068
1069public:
1070 Node *lhs(void) const { return at(0)->node(); }
1071 void lhs(Node *node) { return at(0)->node(node); }
1072
1073 Node *rhs(void) const { return at(1)->node(); }
1074 void rhs(Node *node) { return at(1)->node(node); }
1075};
1076
1089 : public CanonicalNodeDef<CanonicalOpcode::TensorTranspose, FixedArity<1>::Mixin>
1090{
1091public:
1092 TensorTranspose() = default;
1093
1094public:
1095 Node *input(void) const { return at(0)->node(); }
1096 void input(Node *node) { return at(0)->node(node); }
1097
1098 class Perm final
1099 {
1100 public:
1101 Perm() = default;
1102
1103 public:
1104 uint32_t size() const { return _vals.size(); }
1105 void size(uint32_t size) { _vals.resize(size); }
1106
1107 const TensorAxis &axis(TensorAxis n) const { return _vals[n]; }
1108 TensorAxis &axis(TensorAxis n) { return _vals[n]; }
1109
1110 private:
1111 std::vector<TensorAxis> _vals;
1112 };
1113
1114 Perm *perm(void) { return &_perm; }
1115 const Perm *perm(void) const { return &_perm; }
1116
1117private:
1118 Perm _perm;
1119};
1120
1121} // namespace loco
1122
1123#endif // __LOCO_IR_NODES_H__
2D Average Pooling
Definition Nodes.h:337
Convention convention(void) const
Definition Nodes.h:353
Node * ifm(void) const
Definition Nodes.h:349
const Stride< 2 > * stride(void) const
Definition Nodes.h:365
const Padding2D * pad(void) const
Definition Nodes.h:357
const Window< 2 > * window(void) const
Definition Nodes.h:361
Window< 2 > * window(void)
Definition Nodes.h:362
Padding2D * pad(void)
Definition Nodes.h:358
Stride< 2 > * stride(void)
Definition Nodes.h:366
void convention(const Convention &convention)
Definition Nodes.h:354
void ifm(Node *node)
Definition Nodes.h:350
Add Feature and Bias along "depth" axis.
Definition Nodes.h:817
Node * value(void) const
Definition Nodes.h:822
Node * bias(void) const
Definition Nodes.h:825
Add Tensor and Bias.
Definition Nodes.h:781
void bias(Node *node)
Definition Nodes.h:790
Node * value(void) const
Definition Nodes.h:786
void value(Node *node)
Definition Nodes.h:787
uint32_t axis(void) const
Definition Nodes.h:792
void axis(uint32_t axis)
Definition Nodes.h:793
Node * bias(void) const
Definition Nodes.h:789
Produce a value of domain D from an input value (of domain D) and a bias.
Definition Nodes.h:770
Create a "Tensor" from a "Bias".
Definition Nodes.h:743
Node * input(void) const
Definition Nodes.h:748
void input(Node *node)
Definition Nodes.h:749
BiasDecode()=default
Create a "Bias" from a "Tensor".
Definition Nodes.h:758
BiasEncode()=default
Node * input(void) const
Definition Nodes.h:763
void input(Node *node)
Definition Nodes.h:764
Create a value from constant byte array.
Definition Nodes.h:218
uint32_t size(void) const
Return the number of reserved elements.
Definition Nodes.cpp:185
ConstGen()=default
const DataTypeImpl< DT >::Type & at(uint32_t n) const
Get the element at a given position (requires n < size(); at(n) is invalid otherwise).
Definition Nodes.cpp:198
2D Spatial Convolution
Definition Nodes.h:554
void ifm(Node *node)
Definition Nodes.h:557
const Stride< 2 > * stride(void) const
Definition Nodes.h:567
Node * ker(void) const
Definition Nodes.h:559
const Padding2D * pad(void) const
Definition Nodes.h:563
Padding2D * pad(void)
Definition Nodes.h:564
void ker(Node *node)
Definition Nodes.h:560
Node * ifm(void) const
Definition Nodes.h:556
Stride< 2 > * stride(void)
Definition Nodes.h:568
Depthwise 2D Convolution.
Definition Nodes.h:582
void ifm(Node *node)
Definition Nodes.h:585
Padding2D * pad(void)
Definition Nodes.h:592
Node * ifm(void) const
Definition Nodes.h:584
const Stride< 2 > * stride(void) const
Definition Nodes.h:595
const Padding2D * pad(void) const
Definition Nodes.h:591
void ker(Node *node)
Definition Nodes.h:588
Node * ker(void) const
Definition Nodes.h:587
Stride< 2 > * stride(void)
Definition Nodes.h:596
Create a tensor from a depthwise filter.
Definition Nodes.h:475
DepthwiseFilterDecoder * decoder(void) const
Definition Nodes.h:481
void decoder(std::unique_ptr< DepthwiseFilterDecoder > &&dec)
Definition Nodes.h:482
Node * input(void) const
Definition Nodes.h:477
void input(Node *node)
Definition Nodes.h:478
Create a depthwise filter from a tensor.
Definition Nodes.h:456
Node * input(void) const
Definition Nodes.h:458
void input(Node *node)
Definition Nodes.h:459
DepthwiseFilterEncoder * encoder(void) const
Definition Nodes.h:462
void encoder(std::unique_ptr< DepthwiseFilterEncoder > &&enc)
Definition Nodes.h:463
The value of one dimension in a tensor shape.
Definition Dimension.h:30
Elementwise Add lhs and rhs.
Definition Nodes.h:872
void lhs(Node *node)
Definition Nodes.h:878
Node * rhs(void) const
Definition Nodes.h:880
void rhs(Node *node)
Definition Nodes.h:881
Node * lhs(void) const
Definition Nodes.h:877
EltwiseAdd()=default
Elementwise Div lhs and rhs.
Definition Nodes.h:938
void lhs(Node *node)
Definition Nodes.h:944
Node * rhs(void) const
Definition Nodes.h:946
Node * lhs(void) const
Definition Nodes.h:943
void rhs(Node *node)
Definition Nodes.h:947
EltwiseDiv()=default
Elementwise Maximum of lhs and rhs.
Definition Nodes.h:890
Node * rhs(void) const
Definition Nodes.h:898
void lhs(Node *node)
Definition Nodes.h:896
EltwiseMax()=default
void rhs(Node *node)
Definition Nodes.h:899
Node * lhs(void) const
Definition Nodes.h:895
Elementwise Mul lhs and rhs.
Definition Nodes.h:906
void lhs(Node *node)
Definition Nodes.h:912
void rhs(Node *node)
Definition Nodes.h:915
EltwiseMul()=default
Node * rhs(void) const
Definition Nodes.h:914
Node * lhs(void) const
Definition Nodes.h:911
Elementwise Sqrt of input.
Definition Nodes.h:955
void input(Node *node)
Definition Nodes.h:961
Node * input(void) const
Definition Nodes.h:960
EltwiseSqrt()=default
Elementwise Sub lhs and rhs.
Definition Nodes.h:922
Node * rhs(void) const
Definition Nodes.h:930
EltwiseSub()=default
void rhs(Node *node)
Definition Nodes.h:931
Node * lhs(void) const
Definition Nodes.h:927
void lhs(Node *node)
Definition Nodes.h:928
Create a tensor from a feature map.
Definition Nodes.h:399
void decoder(std::unique_ptr< FeatureDecoder > &&dec)
Definition Nodes.h:406
Node * input(void) const
Definition Nodes.h:401
FeatureDecoder * decoder(void) const
Definition Nodes.h:405
void input(Node *node)
Definition Nodes.h:402
Create a feature map from a tensor.
Definition Nodes.h:380
void encoder(std::unique_ptr< FeatureEncoder > &&enc)
Definition Nodes.h:387
FeatureEncoder * encoder(void) const
Definition Nodes.h:386
Node * input(void) const
Definition Nodes.h:382
void input(Node *node)
Definition Nodes.h:383
Create a tensor from a filter.
Definition Nodes.h:437
Node * input(void) const
Definition Nodes.h:439
void input(Node *node)
Definition Nodes.h:440
void decoder(std::unique_ptr< FilterDecoder > &&dec)
Definition Nodes.h:444
FilterDecoder * decoder(void) const
Definition Nodes.h:443
Create a filter from a tensor.
Definition Nodes.h:418
void input(Node *node)
Definition Nodes.h:421
Node * input(void) const
Definition Nodes.h:420
FilterEncoder * encoder(void) const
Definition Nodes.h:424
void encoder(std::unique_ptr< FilterEncoder > &&enc)
Definition Nodes.h:425
Create a new value identical to its input.
Definition Nodes.h:146
Forward()=default
void input(Node *node)
Definition Nodes.h:152
Node * input(void) const
Definition Nodes.h:151
Matrix multiplication of lhs and rhs.
Definition Nodes.h:1065
MatMul()=default
Node * rhs(void) const
Definition Nodes.h:1073
Node * lhs(void) const
Definition Nodes.h:1070
void rhs(Node *node)
Definition Nodes.h:1074
void lhs(Node *node)
Definition Nodes.h:1071
Create Tensor from Matrix.
Definition Nodes.h:1042
MatrixDecoder * decoder(void) const
Definition Nodes.h:1051
Node * input(void) const
Definition Nodes.h:1047
void decoder(std::unique_ptr< MatrixDecoder > &&dec)
Definition Nodes.h:1052
void input(Node *node)
Definition Nodes.h:1048
MatrixDecode()=default
Create Matrix from Tensor.
Definition Nodes.h:1018
MatrixEncoder * encoder(void) const
Definition Nodes.h:1027
Node * input(void) const
Definition Nodes.h:1023
void encoder(std::unique_ptr< MatrixEncoder > &&enc)
Definition Nodes.h:1028
void input(Node *node)
Definition Nodes.h:1024
MatrixEncode()=default
2D Max Pooling
Definition Nodes.h:305
Window< 2 > * window(void)
Definition Nodes.h:316
const Padding2D * pad(void) const
Definition Nodes.h:311
const Stride< 2 > * stride(void) const
Definition Nodes.h:319
Stride< 2 > * stride(void)
Definition Nodes.h:320
const Window< 2 > * window(void) const
Definition Nodes.h:315
Padding2D * pad(void)
Definition Nodes.h:312
void ifm(Node *node)
Definition Nodes.h:308
Node * ifm(void) const
Definition Nodes.h:307
Logical unit of computation.
Definition Node.h:54
Create a value from user data.
Definition Nodes.h:96
GraphInputIndex index(void) const
Get associated input index.
Definition Nodes.cpp:122
Pull()=default
DataType dtype(void) const
Definition Nodes.cpp:147
bool indexed(void) const
Check whether index is initialized.
Definition Nodes.h:118
Make a value visible to the user.
Definition Nodes.h:53
void from(Node *node)
Definition Nodes.h:59
GraphOutputIndex index(void) const
Get associated output index.
Definition Nodes.cpp:58
Node * from(void) const
Definition Nodes.h:58
Push()=default
bool indexed(void) const
Check whether index is initialized.
Definition Nodes.h:79
Create a new value that rectifies its input capping the units at 6.
Definition Nodes.h:172
ReLU6()=default
Node * input(void) const
Definition Nodes.h:177
void input(Node *node)
Definition Nodes.h:178
Create a new value that rectifies its input.
Definition Nodes.h:159
ReLU()=default
void input(Node *node)
Definition Nodes.h:165
Node * input(void) const
Definition Nodes.h:164
Reshape a tensor to another tensor whose shape is known at compile time.
Definition Nodes.h:517
Computes softmax activations for Tensor domain.
Definition Nodes.h:722
Node * input(void) const
Definition Nodes.h:727
uint32_t axis(void) const
Definition Nodes.h:730
void axis(uint32_t axis)
Definition Nodes.h:731
void input(Node *node)
Definition Nodes.h:728
Computes softmax activations.
Definition Nodes.h:714
Stride configuration for N-dimensional spatial operations.
Definition Stride.h:28
Create a new value by applying tanh to its input.
Definition Nodes.h:185
void input(Node *node)
Definition Nodes.h:191
Node * input(void) const
Definition Nodes.h:190
Tanh()=default
const Dimension & dim(const TensorAxis &axis) const
Definition Nodes.cpp:236
bool defined(const TensorAxis &axis) const
Definition Nodes.cpp:231
Duplicate elements along specified axes.
Definition Nodes.h:980
Mapping * mapping(void)
Definition Nodes.h:1004
Node * input(void) const
Definition Nodes.h:985
const Mapping * mapping(void) const
Definition Nodes.h:1005
void input(Node *node)
Definition Nodes.h:986
Concatenate two tensors.
Definition Nodes.h:533
void rhs(Node *node)
Definition Nodes.h:539
uint32_t axis(void) const
Definition Nodes.h:542
void lhs(Node *node)
Definition Nodes.h:536
Node * rhs(void) const
Definition Nodes.h:538
void axis(uint32_t val)
Definition Nodes.h:543
Node * lhs(void) const
Definition Nodes.h:535
Pads a tensor with constant value.
Definition Nodes.h:852
Node * constant(void) const
Definition Nodes.h:857
Node * input(void) const
Definition Nodes.h:854
const PaddingND * padding(void) const
Definition Nodes.h:861
void input(Node *node)
Definition Nodes.h:855
PaddingND * padding(void)
Definition Nodes.h:862
void constant(Node *node)
Definition Nodes.h:858
Computes ReduceFunc operations for Tensor domain.
Definition Nodes.h:620
TensorAxisSet * axes(void)
Definition Nodes.h:627
Node * input(void) const
Definition Nodes.h:622
void input(Node *node)
Definition Nodes.h:623
ReduceFunc func(void) const
Definition Nodes.h:630
const TensorAxisSet * axes(void) const
Definition Nodes.h:626
void func(ReduceFunc func)
Definition Nodes.h:631
const TensorAxis & axis(TensorAxis n) const
Definition Nodes.h:1107
void size(uint32_t size)
Definition Nodes.h:1105
TensorAxis & axis(TensorAxis n)
Definition Nodes.h:1108
uint32_t size() const
Definition Nodes.h:1104
Permute an input.
Definition Nodes.h:1090
Perm * perm(void)
Definition Nodes.h:1114
void input(Node *node)
Definition Nodes.h:1096
const Perm * perm(void) const
Definition Nodes.h:1115
Node * input(void) const
Definition Nodes.h:1095
2D Transposed Convolution
Definition Nodes.h:688
Node * ifm(void) const
Definition Nodes.h:690
Node * ker(void) const
Definition Nodes.h:693
Padding2D * pad(void)
Definition Nodes.h:698
void ifm(Node *node)
Definition Nodes.h:691
Stride< 2 > * stride(void)
Definition Nodes.h:702
void ker(Node *node)
Definition Nodes.h:694
const Stride< 2 > * stride(void) const
Definition Nodes.h:701
const Padding2D * pad(void) const
Definition Nodes.h:697
ND Receptive Field Shape.
Definition Window.h:30
uint32_t GraphInputIndex
void link(GraphOutput *, Push *push)
Definition Nodes.cpp:65
uint32_t GraphOutputIndex
uint32_t TensorAxis
Definition TensorAxis.h:25
ReduceFunc
Reduce type functions.
Definition Nodes.h:609
Pull * pull_node(Graph *g, const GraphInputIndex &index)
Find a Pull node with a given input index.
Definition Nodes.cpp:162
ReshapeType
Definition Nodes.h:490
DataType
"scalar" value type
Definition DataType.h:27
Domain
Describe the kind of (N-dimensional) loco values.
Definition Domain.h:40
Push * push_node(Graph *g, const GraphOutputIndex &index)
Find a Push node with a given output index.
Definition Nodes.cpp:67
C++ scalar type corresponding to each DataType.
Describe how to build a tensor from a depthwise convolution filter.
Describe how to build a depthwise convolution filter from a tensor.
Describe how to build a tensor from a (convolution) feature map.
Describe how to build a (convolution) feature map from a tensor.
Describe how to build a tensor from a filter.
Definition FilterCodec.h:54
Describe how to build a (convolution) filter from a tensor.
Definition FilterCodec.h:43
Describe how to build a tensor from a matrix.
Definition MatrixCodec.h:64
Describe how to build a matrix from a tensor.
Definition MatrixCodec.h:43