ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Reshape.h File Reference
#include "Shape.h"
#include <cstdint>

Go to the source code of this file.

Functions

bool reshapePrepare (const Shape &input, const int32_t *targetDims, const int32_t targetDimsSize, Shape *output)
 
bool reshapeGeneric (const void *inputData, const Shape &inputShape, void *outputData, const Shape &outputShape)
 

Function Documentation

◆ reshapeGeneric()

bool reshapeGeneric ( const void *  inputData,
const Shape & inputShape,
void *  outputData,
const Shape & outputShape
)

Definition at line 67 of file Reshape.cpp.

69{
70 size_t count = sizeOfData(inputShape.type, inputShape.dimensions);
71 memcpy(outputData, inputData, count);
72 return true;
73}
uint32_t sizeOfData(const Operand &operand)
Definition Operand.h:56
OperandType type
Definition Shape.h:29
std::vector< uint32_t > dimensions
Definition Shape.h:30

References Shape::dimensions, sizeOfData(), and Shape::type.

◆ reshapePrepare()

bool reshapePrepare ( const Shape & input,
const int32_t *  targetDims,
const int32_t  targetDimsSize,
Shape * output
)

Definition at line 24 of file Reshape.cpp.

26{
27 // Reshape allows one of the targetDims components to have the
28 // special -1 value, meaning it will be calculated automatically based on the
29 // input. Here we calculate what that dimension should be so that the number
30 // of output elements is the same as the number of input elements.
31 int32_t numInputElements = (int32_t)getNumberOfElements(input);
32
33 std::vector<uint32_t> outDims(targetDimsSize);
34 int32_t numOutputElements = 1;
35 int32_t strechDim = -1;
36 for (int32_t i = 0; i < targetDimsSize; ++i)
37 {
38 int32_t value = targetDims[i];
39 if (value == -1)
40 {
41 ASSERT(strechDim == -1);
42 strechDim = i;
43 }
44 else
45 {
46 numOutputElements *= value;
47 outDims[i] = (uint32_t)value;
48 }
49 }
50 if (strechDim != -1)
51 {
52 int32_t strechValue = numInputElements / numOutputElements;
53 outDims[strechDim] = (uint32_t)strechValue;
54 numOutputElements *= strechValue;
55 }
56
57 ASSERT(numInputElements == numOutputElements);
58
59 output->type = input.type;
60 output->dimensions = outDims;
61 output->offset = input.offset;
62 output->scale = input.scale;
63
64 return true;
65}
#define ASSERT(v)
Definition Assert.h:24
uint32_t getNumberOfElements(const Shape &shape)
Definition Shape.cpp:48

References ASSERT, and getNumberOfElements().