ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onnx_legalizer.py
Go to the documentation of this file.
1#!/usr/bin/python3
2
3# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import onnx
18import onnx.numpy_helper
19import sys
20import numpy as np
21import re
22
23# Transform onnx model to make it compilable with our toolchain
24#
25# This code works with onnx model in proto format. See proto buffers format in
26# https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3
27#
28# More examples of handling onnx models could be found here:
29# https://github.com/onnx/onnx/tree/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/examples
30#
31# List of transformations:
32# - Replace RNN operation with unrolled subgraph
33# - Replace LSTM operation with unrolled subgraph
34
35
37 """Controls which transformations the legalizer applies
38
39 Attributes:
40 unroll_rnn (bool): default is False. If True - unrolls RNN operations
41 unroll_lstm (bool): default is False. If True - unrolls LSTM operations
42 """
43
44 unroll_rnn = False
45 unroll_lstm = False
46
47
def _reverse_str(s):
    """Return the characters of a string in reverse order.

    Args:
        s (str): string to reverse

    Returns:
        str: reversed copy of s
    """
    return s[::-1]
50
51
def _parse_tensor_name(name):
    """Splits tensor name to base part and serial number

    Most of tensor names have following format: "tensor_123".
    This function breaks name into two values: "tensor_" and 123.
    Tensor names like this: "321" are broken into "" and 321.
    Names without trailing digits get serial number 0.

    Serial number is used to create unique tensor names using given base name.

    Args:
        name (str): tensor name

    Returns:
        tuple of str, int: base name and serial number of tensor
    """
    # Walk back over the trailing run of decimal digits. str.isdecimal()
    # accepts exactly the characters int() can parse, and scanning the name
    # directly keeps base names containing newlines intact.
    split_at = len(name)
    while split_at > 0 and name[split_at - 1].isdecimal():
        split_at -= 1

    digits = name[split_at:]
    return (name[:split_at], int(digits) if digits else 0)
73
74
76 """Helper for onnx model transformation
77
78 This helper is used for convenient operation replacement in onnx model
79
80 Attributes:
81 _model (onnx.onnx_ml_pb2.ModelProto): target model that should be changed
82 _nodes_to_delete (list of onnx.onnx_ml_pb2.NodeProto): list of replaced operations
83 _insert_id (int): position at which created operations are inserted (the node list must stay topologically sorted)
84 _base_name_idx (dict from str to int): maps tensor "base" name to
85 largest existing serial num. For example model has tensors "t_1", "t_2", "t_4",
86 in that case _base_name_idx["t_"] == 4.
87 This attribute is used for unique tensor name generation.
88 """
def __init__(self, model):
    """Bind the helper to a model and index the tensor names already in use.

    Args:
        model (onnx.onnx_ml_pb2.ModelProto): target model that should be changed
    """
    self._model = model
    # operations queued for removal, see mark_for_deletion()
    self._nodes_to_delete = []
    self._insert_id = 0
    # each tensor has name containing base name and unique number. for example:
    # "abc_123": "abc_" - base name, "123" - unique number
    # if no number in name, consider it is equal to "0"

    # mapping from base names to largest given number
    self._base_name_idx = {}
    # gather name information for existing tensors
    for node in model.graph.node:
        for t in list(node.input) + list(node.output):
            base_name, number = _parse_tensor_name(t)
            known = self._base_name_idx.get(base_name, number)
            self._base_name_idx[base_name] = max(known, number)
108
def make_tensor_with_base_name(self, base_name):
    """Create unique name for given base_name

    The serial number following the largest one recorded for base_name is
    used; a base name that has not been seen yet starts at 0.

    Args:
        base_name (str): base tensor name

    Returns:
        str : unique tensor name that starts with base_name
    """
    # -1 makes an unseen base name start from serial number 0
    serial = self._base_name_idx.get(base_name, -1) + 1
    self._base_name_idx[base_name] = serial
    return base_name + str(serial)
124
def make_node(self, opcode, inputs, outputs, *p_args, **k_args):
    """Create arbitrary node and insert it in graph.

    The node is inserted at the current insert position, which is then
    advanced by one so that consecutive calls keep their creation order.

    Args:
        opcode (str): opcode name of desired operation
        inputs (list of str): names of input tensors
        outputs (list of str or int): names of existing tensors to use as output tensors for operation or
            number of tensors that should be created
        p_args: additional arguments for onnx make_node helper
        k_args: attributes for onnx node

    Returns:
        list of str: list of output tensor names

    Raises:
        TypeError: if outputs is neither an int nor a list
    """
    # bool is a subclass of int but is never a meaningful output count
    if isinstance(outputs, int) and not isinstance(outputs, bool):
        outputs = [self.make_tensor_with_base_name('') for _ in range(outputs)]
    if not isinstance(outputs, list):
        raise TypeError('outputs must be a list of tensor names or an int, got ' +
                        type(outputs).__name__)
    node = onnx.helper.make_node(opcode, inputs, outputs, *p_args, **k_args)
    self._model.graph.node.insert(self._insert_id, node)
    self._insert_id += 1
    return outputs
146
def make_split(self, input, split_sizes, axis):
    """Create Split operation and insert it in graph.

    One output tensor is created per entry of split_sizes.

    NOTE(review): the sizes are passed as the `split` attribute, which is the
    form used by Split before opset 13 - confirm the targeted opset.

    Args:
        input (str): name of input tensor
        split_sizes (list of int): list of split sizes
        axis (int): number of axis to split

    Returns:
        list: list of output tensor names
    """
    num_outputs = len(split_sizes)
    return self.make_node('Split', [input], num_outputs, axis=axis, split=split_sizes)
162
def make_concat(self, inputs, axis):
    """Create Concat operation and insert it in graph.

    Args:
        inputs (list of str): list of tensors names to concat
        axis (int): axis number to concat

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Concat', inputs, 1, axis=axis)
    return output
174
def make_squeeze(self, input, axes):
    """Create Squeeze operation and insert it in graph.

    NOTE(review): `axes` is passed as an attribute, which is the form used by
    Squeeze before opset 13 - confirm the targeted opset.

    Args:
        input (str): name of input tensor
        axes (list of int): list of dimension containing ones to remove

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Squeeze', [input], 1, axes=axes)
    return output
186
def make_unsqueeze(self, input, axes):
    """Create Unsqueeze operation and insert it in graph.

    NOTE(review): `axes` is passed as an attribute, which is the form used by
    Unsqueeze before opset 13 - confirm the targeted opset.

    Args:
        input (str): name of input tensor
        axes (list of int): list of dimension to insert ones

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Unsqueeze', [input], 1, axes=axes)
    return output
198
def make_gemm(self, A, B, C, trans_a=False, trans_b=False):
    """Create Gemm operation and insert it in graph.

    Result tensor contains A*B + C, where A and/or B are transposed first
    when the corresponding flag is set.

    Args:
        A (str): name of tensor A
        B (str): name of tensor B
        C (str): name of tensor C
        trans_a (bool): if True, transpose tensor A before multiplication
        trans_b (bool): if True, transpose tensor B before multiplication

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Gemm', [A, B, C],
                               1,
                               transA=bool(trans_a),
                               transB=bool(trans_b))
    return output
218
def make_add(self, a, b):
    """Creates Add operation and insert it in graph.

    Args:
        a (str): name of left operand tensor
        b (str): name of right operand tensor

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Add', [a, b], 1)
    return output
230
def make_mul(self, a, b):
    """Creates Mul operation and insert it in graph.

    Args:
        a (str): name of left operand tensor
        b (str): name of right operand tensor

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Mul', [a, b], 1)
    return output
242
def make_clip(self, input, min, max):
    """Create Clip operation and insert it in graph.

    The parameter names shadow builtins but are kept because callers pass
    them by keyword.

    NOTE(review): the bounds are passed as `min`/`max` attributes, which is
    the form used by Clip before opset 11 - confirm the targeted opset.

    Args:
        input (str): input tensor name
        min (int/float): lower clip bound
        max (int/float ): upper clip bound

    Returns:
        str: output tensor name
    """
    (output,) = self.make_node('Clip', [input], 1, min=min, max=max)
    return output
255
def make_act(self, input, act_name):
    """Create activation function operation and insert it in graph.

    Args:
        input (str): input tensor name
        act_name (str): name of activation function, one of ['Relu', 'Tanh', 'Sigmoid']

    Returns:
        str: output tensor name

    Raises:
        ValueError: if act_name is not one of the supported activations
    """
    # explicit check instead of assert, so it is not stripped under -O
    if act_name not in ('Relu', 'Tanh', 'Sigmoid'):
        raise ValueError('unsupported activation function: ' + repr(act_name))
    (output,) = self.make_node(act_name, [input], 1)
    return output
268
def make_constant_tensor(self, tensor_data, base_name):
    """Creates onnx constant tensor

    The data is converted to an ONNX tensor, given a fresh unique name derived
    from base_name and appended to the initializers of the model graph.

    Args:
        tensor_data (numpy.ndarray): tensor data
        base_name (str): prefix of constant tensor name

    Returns:
        str: name of created constant tensor
    """
    tensor = onnx.numpy_helper.from_array(tensor_data)
    tensor.name = self.make_tensor_with_base_name(base_name)
    self._model.graph.initializer.append(tensor)
    return tensor.name
283
def mark_for_deletion(self, node):
    """Queue an operation for later removal from the graph.

    The node is only recorded in the helper's pending-deletion list here; it
    is not removed from the graph by this call.

    Args:
        node (onnx.onnx_ml_pb2.NodeProto): operation that has been replaced
    """
    self._nodes_to_delete.append(node)
286
def get_insert_id(self):
    """Return the position at which the next created operation is inserted.

    Returns:
        int: current insert position in the node list of the graph
    """
    return self._insert_id
289
def set_insert_id(self, insert_id):
    """Set the position at which the next created operation is inserted.

    Args:
        insert_id (int): index in the node list of the graph
    """
    self._insert_id = insert_id
292
294 for node in self._nodes_to_delete:
295 self._model.graph.node.remove(node)
296
297
class _TensorInfo:
    """Shape and element type of a single tensor

    Attributes:
        dtype (int): onnx dtype value of tensor elements
        shape (list of int): tensor dimensions
    """
    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape

    def __repr__(self):
        return '_TensorInfo(dtype={!r}, shape={!r})'.format(self.dtype, self.shape)
302
303
305 """Infer tensor shapes and dtypes
306 Args:
307 model (onnx.onnx_ml_pb2.ModelProto): model to process
308
309 Returns:
310 dict from str to _TensorInfo: maps tensor name to shape and dtype information
311 """
312
313 inferred_shape_model = onnx.shape_inference.infer_shapes(model)
314
315 infos = {}
316 for tensor in list(inferred_shape_model.graph.value_info) + list(
317 inferred_shape_model.graph.input):
318 info = _TensorInfo(tensor.type.tensor_type.elem_type, [])
319 for dim in tensor.type.tensor_type.shape.dim:
320 info.shape += [dim.dim_value]
321 infos[tensor.name] = info
322
323 for tensor in list(model.graph.initializer):
324 infos[tensor.name] = _TensorInfo(tensor.data_type, tensor.dims)
325 return infos
326
327
def _dtype_to_np(dtype):
    """Convert onnx dtype value to numpy dtype class

    Only float32 is supported at the moment.

    For more types see:
    https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3#L484

    Args:
        dtype (int): onnx dtype

    Returns:
        numpy data type: numpy dtype, like np.float32

    Raises:
        NotImplementedError: if dtype is not a supported onnx dtype value
    """
    # 1 is FLOAT in the onnx tensor data type enumeration
    if dtype == 1:
        return np.float32
    raise NotImplementedError('unsupported data type')
345
346
def _generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip,
                                activation_name):
    """Generate subgraph of one direction of unrolled RNN layer

    Args:
        transformer (_ModelTransformerHelper): helper for model generation
        X (list of str): names of input tensors in sequence. Tensor shapes: [batch_size, input_size].
        W (str): name of weight tensor
        R (str): name of recurrence weight tensor
        B (str): name of bias tensor
        initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size]
        clip (float or None): range which clips input of activations
        activation_name (str): activation function

    Returns:
        list of str: names of hidden state tensors, one per element of X, in
            processing order. Empty if X is empty.
    """
    # one direction RNN:
    #
    # For details see:
    # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#RNN-7
    #
    # H = f(X*(W^T) + h*(R^T) + B)
    #
    # H - new hidden state
    # h - previous hidden state
    # X - current input
    # W - input weights matrix
    # R - recurrent weights matrix
    # B - bias (input and recurrent biases already summed by the caller)
    # f - activation function

    state_tensors = []
    previous_state_tensor = initial_h
    for frame in X:
        state_tensor = transformer.make_gemm(frame, W, B, trans_b=True)
        # Without an initial state the very first step has no recurrent term.
        if previous_state_tensor is not None:
            state_tensor = transformer.make_gemm(previous_state_tensor,
                                                 R,
                                                 state_tensor,
                                                 trans_b=True)
        if clip is not None:
            state_tensor = transformer.make_clip(state_tensor, min=-clip, max=clip)
        previous_state_tensor = transformer.make_act(state_tensor, activation_name)
        state_tensors.append(previous_state_tensor)
    return state_tensors
401
402
def _transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, activation,
                                  clip, direction, hidden_size, layout):
    """Generate Simple (forward or reverse) unrolled RNN

    The created operations write to the output tensors of original_node, so
    consumers of the original operation stay connected. Optional outputs that
    are omitted or have an empty name are not generated. The list passed as x
    is not modified.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional RNN operation to unroll
        x (list of str): list of input tensors (input tensor split along "time" dimension)
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        activation (str): name of activation function
        clip (float or None): range which clips input of activations
        direction (str): "forward" or "reverse"
        hidden_size (int): size of hidden state
        layout (int): See attribute description:
            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
    """

    inputs = original_node.input
    outputs = original_node.output

    def is_present(names, idx):
        # optional inputs/outputs are either omitted or given an empty name
        return len(names) > idx and names[idx] != ''

    is_reverse = direction == 'reverse'
    # process a private copy so the caller's list keeps its order
    frames = list(reversed(x)) if is_reverse else list(x)

    w = transformer.make_squeeze(inputs[1], axes=[0])
    r = transformer.make_squeeze(inputs[2], axes=[0])
    if is_present(inputs, 3):
        # bias holds input and recurrence biases back to back; only their sum is used
        raw_bias_tensor = transformer.make_squeeze(inputs[3], axes=[0])
        input_bias, recurrence_bias = transformer.make_split(
            raw_bias_tensor, split_sizes=[hidden_size] * 2, axis=0)
        b = transformer.make_add(input_bias, recurrence_bias)
    else:
        data_type = _dtype_to_np(tensor_infos[inputs[2]].dtype)
        b = transformer.make_constant_tensor(np.zeros(hidden_size, dtype=data_type),
                                             "zero_bias")
    if is_present(inputs, 5):
        direction_dim = layout
        initial_h = transformer.make_squeeze(inputs[5], axes=[direction_dim])
    else:
        initial_h = None

    state_tensors = _generate_one_direction_RNN(transformer, frames, w, r, b, initial_h,
                                                clip, activation)

    y_direction_dim = layout + 1
    y_h_direction_dim = layout
    seq_length_dim = layout

    want_y = is_present(outputs, 0)
    state_layout_tensors = []
    if want_y:
        state_layout_tensors = [
            transformer.make_unsqueeze(state, axes=[seq_length_dim, y_direction_dim])
            for state in state_tensors
        ]

    # use low-level interface to attach to existing tensors
    if is_present(outputs, 1):
        # last processed step holds the final hidden state for both directions
        transformer.make_node('Unsqueeze', [state_tensors[-1]], [outputs[1]],
                              axes=[y_h_direction_dim])
    if want_y:
        if is_reverse:
            # states were produced from the last frame to the first one;
            # Y has to follow the original time order
            state_layout_tensors.reverse()
        transformer.make_node('Concat',
                              state_layout_tensors, [outputs[0]],
                              axis=seq_length_dim)
458
459
def _transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, activations,
                                 clip, hidden_size, layout):
    """Generate Bidirectional unrolled RNN

    The created operations write to the output tensors of original_node, so
    consumers of the original operation stay connected. Optional outputs that
    are omitted or have an empty name are not generated. The list passed as x
    is not modified.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional RNN operation to unroll
        x (list of str): list of input tensors (input tensor split along "time" dimension)
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        activations (list of str): list of len (2) containing names of forward and reverse activations
        clip (float or None): range which clips input of activations
        hidden_size (int): size of hidden state
        layout (int): See attribute description:
            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
    """

    inputs = original_node.input
    outputs = original_node.output

    def is_present(names, idx):
        # optional inputs/outputs are either omitted or given an empty name
        return len(names) > idx and names[idx] != ''

    # weights of both directions are stacked along axis 0
    w_bi = transformer.make_split(inputs[1], split_sizes=[1, 1], axis=0)
    r_bi = transformer.make_split(inputs[2], split_sizes=[1, 1], axis=0)
    w = []
    r = []
    for d in range(2):
        w.append(transformer.make_squeeze(w_bi[d], axes=[0]))
        r.append(transformer.make_squeeze(r_bi[d], axes=[0]))

    if is_present(inputs, 3):
        b = []
        raw_bias_tensors = transformer.make_split(inputs[3], split_sizes=[1, 1], axis=0)
        for d in range(2):
            raw_bias_tensors_squeezed = transformer.make_squeeze(raw_bias_tensors[d],
                                                                 axes=[0])
            # input and recurrence biases are stored back to back; only their sum is used
            input_bias, recurrence_bias = transformer.make_split(
                raw_bias_tensors_squeezed, split_sizes=[hidden_size] * 2, axis=0)
            b.append(transformer.make_add(input_bias, recurrence_bias))
    else:
        data_type = _dtype_to_np(tensor_infos[inputs[2]].dtype)
        zero_bias = transformer.make_constant_tensor(
            np.zeros(hidden_size, dtype=data_type), "zero_bias")
        # one zero constant serves both directions
        b = [zero_bias, zero_bias]

    initial_h = [None, None]
    if is_present(inputs, 5):
        direction_dim = layout
        initial_h_bi = transformer.make_split(inputs[5],
                                              split_sizes=[1, 1],
                                              axis=direction_dim)
        initial_h = [
            transformer.make_squeeze(initial_h_bi[d], axes=[direction_dim])
            for d in range(2)
        ]

    state_f_tensors = _generate_one_direction_RNN(transformer, x, w[0], r[0], b[0],
                                                  initial_h[0], clip, activations[0])
    # backward direction consumes the frames in reverse time order
    state_b_tensors = _generate_one_direction_RNN(transformer, list(reversed(x)), w[1],
                                                  r[1], b[1], initial_h[1], clip,
                                                  activations[1])
    # bring backward states back to time order: element t belongs to frame t
    state_b_tensors.reverse()

    y_direction_dim = layout + 1
    y_h_direction_dim = layout
    seq_length_dim = layout

    want_y = is_present(outputs, 0)
    state_layout_tensors = []
    if want_y:
        for state_f, state_b in zip(state_f_tensors, state_b_tensors):
            state_layout_tensors_f = transformer.make_unsqueeze(
                state_f, axes=[seq_length_dim, y_direction_dim])
            state_layout_tensors_b = transformer.make_unsqueeze(
                state_b, axes=[seq_length_dim, y_direction_dim])
            state_layout_tensors.append(
                transformer.make_concat([state_layout_tensors_f, state_layout_tensors_b],
                                        axis=y_direction_dim))

    # use low-level interface to attach to existing tensors
    if is_present(outputs, 1):
        # final states: last frame for forward, first frame for backward direction
        last_f_state_layout_tensor = transformer.make_unsqueeze(
            state_f_tensors[-1], axes=[y_h_direction_dim])
        last_b_state_layout_tensor = transformer.make_unsqueeze(
            state_b_tensors[0], axes=[y_h_direction_dim])
        transformer.make_node('Concat',
                              [last_f_state_layout_tensor, last_b_state_layout_tensor],
                              [outputs[1]],
                              axis=y_h_direction_dim)

    if want_y:
        transformer.make_node('Concat',
                              state_layout_tensors, [outputs[0]],
                              axis=seq_length_dim)
550
551
def _legalize_RNN(transformer, tensor_infos, node):
    """Unroll RNN operation

    The input is split into frames along the time dimension, an unrolled
    subgraph is generated for the requested direction(s) and the original
    operation is marked for deletion.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        node (onnx.onnx_ml_pb2.NodeProto): RNN operation to unroll

    Raises:
        NotImplementedError: if the operation uses sequence_lens, activation
            alpha/beta parameters, an unsupported activation function or a
            sequence length that is not statically known
        RuntimeError: if the direction attribute has an unknown value
    """
    inputs = node.input
    if len(inputs) > 4 and inputs[4] != '':
        raise NotImplementedError('Variadic length of output is not supported')
    # attributes
    activation_alpha = []
    activation_beta = []
    activations = ['Tanh', 'Tanh']
    clip = None
    direction = 'forward'
    hidden_size = 0
    layout = 0

    for attr in node.attribute:
        if attr.name == 'activation_alpha':
            activation_alpha = attr.floats
        elif attr.name == 'activation_beta':
            activation_beta = attr.floats
        elif attr.name == 'activations':
            activations = [item.decode('UTF-8') for item in attr.strings]
        elif attr.name == 'clip':
            clip = attr.f
        elif attr.name == 'direction':
            direction = attr.s.decode('UTF-8')
        elif attr.name == 'hidden_size':
            hidden_size = attr.i
        elif attr.name == 'layout':
            layout = attr.i

    if len(activation_alpha) > 0 or len(activation_beta) > 0:
        raise NotImplementedError('Unsupported parameters for RNN activations')

    for act in activations:
        if act not in ['Relu', 'Tanh', 'Sigmoid']:
            raise NotImplementedError('Unsupported activation function')

    seq_length_dim = layout
    seq_length = tensor_infos[inputs[0]].shape[seq_length_dim]
    if seq_length <= 0:
        # nothing can be unrolled without a statically known number of frames
        raise NotImplementedError('Unknown or zero sequence length is not supported')
    if hidden_size == 0:
        hidden_size = tensor_infos[inputs[2]].shape[2]

    input_split_tensor = transformer.make_split(inputs[0],
                                                split_sizes=[1] * seq_length,
                                                axis=seq_length_dim)
    # every frame keeps a time dimension of size 1 after the split; drop it
    # (its position depends on layout)
    x = [
        transformer.make_squeeze(input_frame_tensor, axes=[seq_length_dim])
        for input_frame_tensor in input_split_tensor
    ]

    if direction in ['forward', 'reverse']:
        _transform_unidirectional_RNN(transformer, node, x, tensor_infos, activations[0],
                                      clip, direction, hidden_size, layout)
    elif direction == 'bidirectional':
        _transform_bidirectional_RNN(transformer, node, x, tensor_infos, activations,
                                     clip, hidden_size, layout)
    else:
        raise RuntimeError('Unknown RNN type')

    transformer.mark_for_deletion(node)
619
620
def _generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P, clip,
                                 act, dtype, hidden_size, batch_size):
    """Generate subgraph for one direction of unrolled LSTM layer

    Args:
        transformer (_ModelTransformerHelper): helper for model generation
        X (list of str): names of tensors in input sequence. Each tensor shape: [batch_size, input_size]
        W (str): name of concatenated weight tensor: [input, output, forget, cell]
        R (str): name of concatenated recurrence weights tensor: [input, output, forget, cell]
        B (str or None): name of concatenated bias tensor: [input, output, forget, cell].
            Zero biases are used if None
        initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size].
            Zero state is used if None
        initial_c (str or None): name of tensor containing initial cell state. Shape [batch_size, hidden_size].
            Zero state is used if None
        P (str or None): name of concatenated peephole tensor: [input, output, forget].
            Zero peepholes are used if None
        clip (float or None): range which clips input of activations
        act (dict of str): activation functions {'f': 'Sigmoid', 'g': 'Tanh', 'h': 'Tanh'}
        dtype (numpy dtype): data type used in created LSTM operation
        hidden_size (int): hidden dimension
        batch_size (int): batch dimension

    Returns:
        tuple of (list of str, str): names of hidden state tensors, one per
            element of X in processing order, and name of the last cell state tensor
    """
    # one direction LSTM:
    #
    # For details see:
    # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#LSTM-7
    #
    # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
    # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    # Ct = ft (.) Ct-1 + it (.) ct
    # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    # Ht = ot (.) h(Ct)
    #
    # X - input tensor
    # i - input gate
    # o - output gate
    # f - forget gate
    # c - cell gate
    # t - time step (t-1 means previous time step)
    # W[iofc] - W parameter weight matrix for input, output, forget, and cell gates
    # R[iofc] - R recurrence weight matrix for input, output, forget, and cell gates
    # Wb[iofc] - W bias vectors for input, output, forget, and cell gates
    # Rb[iofc] - R bias vectors for input, output, forget, and cell gates
    # P[iof] - P peephole weight vector for input, output, and forget gates
    # H - Hidden state

    gates = ('i', 'o', 'f', 'c')

    w_gate = dict(
        zip(gates, transformer.make_split(W, split_sizes=[hidden_size] * 4, axis=0)))
    r_gate = dict(
        zip(gates, transformer.make_split(R, split_sizes=[hidden_size] * 4, axis=0)))

    if B is not None:
        # B holds Wb[iofc] followed by Rb[iofc]; only the per-gate sums are needed
        separate_b_tensors = transformer.make_split(B,
                                                    split_sizes=[hidden_size] * 8,
                                                    axis=0)
        b_tensors = [
            transformer.make_add(separate_b_tensors[i], separate_b_tensors[i + 4])
            for i in range(4)
        ]
    else:
        zero_b = transformer.make_constant_tensor(np.zeros((hidden_size), dtype=dtype),
                                                  'zero_b')
        b_tensors = [zero_b] * 4
    b_gate = dict(zip(gates, b_tensors))

    if initial_h is not None:
        previous_h_state_tensor = initial_h
    else:
        previous_h_state_tensor = transformer.make_constant_tensor(
            np.zeros((batch_size, hidden_size), dtype=dtype), 'initial_h')

    if initial_c is not None:
        previous_c_state_tensor = initial_c
    else:
        previous_c_state_tensor = transformer.make_constant_tensor(
            np.zeros((batch_size, hidden_size), dtype=dtype), 'initial_c')

    if P is not None:
        p_tensors = transformer.make_split(P, split_sizes=[hidden_size] * 3, axis=0)
        p_gate = dict(zip(('i', 'o', 'f'), p_tensors))
    else:
        zero = transformer.make_constant_tensor(np.zeros((hidden_size), dtype=dtype),
                                                'zero_peephole')
        p_gate = {'i': zero, 'o': zero, 'f': zero}

    def pre_activation(gate, frame, hidden):
        # Xt*(W^T) + Wb + Rb, then + Ht-1*(R^T)
        partial = transformer.make_gemm(frame, w_gate[gate], b_gate[gate], trans_b=True)
        return transformer.make_gemm(hidden, r_gate[gate], partial, trans_b=True)

    def add_peephole(tensor, gate, cell):
        # + P (.) C
        return transformer.make_add(tensor, transformer.make_mul(p_gate[gate], cell))

    def activate(tensor, act_name):
        # optional clipping is applied to the input of the activation
        if clip is not None:
            tensor = transformer.make_clip(tensor, min=-clip, max=clip)
        return transformer.make_act(tensor, act_name)

    state_h_tensors = []
    for frame in X:
        # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
        it = pre_activation('i', frame, previous_h_state_tensor)
        it = activate(add_peephole(it, 'i', previous_c_state_tensor), act['f'])

        # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
        ft = pre_activation('f', frame, previous_h_state_tensor)
        ft = activate(add_peephole(ft, 'f', previous_c_state_tensor), act['f'])

        # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
        ct = activate(pre_activation('c', frame, previous_h_state_tensor), act['g'])

        # Ct = ft (.) Ct-1 + it (.) ct
        ft_Ct = transformer.make_mul(ft, previous_c_state_tensor)
        it_ct = transformer.make_mul(it, ct)
        Ct = transformer.make_add(ft_Ct, it_ct)
        previous_c_state_tensor = Ct

        # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
        ot = pre_activation('o', frame, previous_h_state_tensor)
        ot = activate(add_peephole(ot, 'o', Ct), act['f'])

        # Ht = ot (.) h(Ct)
        Ht = transformer.make_mul(ot, transformer.make_act(Ct, act['h']))
        previous_h_state_tensor = Ht
        state_h_tensors.append(Ht)

    return (state_h_tensors, previous_c_state_tensor)
763
764
def _transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos,
                                   activations, clip, direction, hidden_size, layout):
    """Generate Simple (forward or reverse) unrolled LSTM

    The created operations write to the output tensors of original_node, so
    consumers of the original operation stay connected. Optional outputs that
    are omitted or have an empty name are not generated. The list passed as x
    is not modified.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional LSTM operation to unroll
        x (list of str): list of input tensors (input tensor split along "time" dimension)
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        activations (list of str): list of length 3 containing names of activation functions
        clip (float or None): range which clips input of activations
        direction (str): "forward" or "reverse"
        hidden_size (int): size of hidden state
        layout (int): See attribute description:
            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37
    """

    inputs = original_node.input
    outputs = original_node.output

    def is_present(names, idx):
        # optional inputs/outputs are either omitted or given an empty name
        return len(names) > idx and names[idx] != ''

    is_reverse = direction == 'reverse'
    # process a private copy so the caller's list keeps its order
    frames = list(reversed(x)) if is_reverse else list(x)

    w = transformer.make_squeeze(inputs[1], axes=[0])
    r = transformer.make_squeeze(inputs[2], axes=[0])

    b = None
    if is_present(inputs, 3):
        b = transformer.make_squeeze(inputs[3], axes=[0])

    direction_dim = layout

    initial_h = None
    if is_present(inputs, 5):
        initial_h = transformer.make_squeeze(inputs[5], axes=[direction_dim])

    initial_c = None
    if is_present(inputs, 6):
        initial_c = transformer.make_squeeze(inputs[6], axes=[direction_dim])

    p = None
    if is_present(inputs, 7):
        p = transformer.make_squeeze(inputs[7], axes=[0])

    dtype = _dtype_to_np(tensor_infos[inputs[0]].dtype)
    # batch dimension of the input is 1 for layout 0 and 0 for layout 1
    batch_size = tensor_infos[inputs[0]].shape[1 - layout]

    act = {'f': activations[0], 'g': activations[1], 'h': activations[2]}

    state_h_tensors, state_c_tensor = _generate_one_direction_LSTM(
        transformer, frames, w, r, b, initial_h, initial_c, p, clip, act, dtype,
        hidden_size, batch_size)

    y_direction_dim = layout + 1
    y_h_direction_dim = layout
    seq_length_dim = layout

    want_y = is_present(outputs, 0)
    state_layout_tensors = []
    if want_y:
        state_layout_tensors = [
            transformer.make_unsqueeze(h_state, axes=[seq_length_dim, y_direction_dim])
            for h_state in state_h_tensors
        ]

    # use low-level interface to attach to existing tensors
    if is_present(outputs, 1):
        transformer.make_node('Unsqueeze', [state_h_tensors[-1]], [outputs[1]],
                              axes=[y_h_direction_dim])
    if is_present(outputs, 2):
        transformer.make_node('Unsqueeze', [state_c_tensor], [outputs[2]],
                              axes=[y_h_direction_dim])
    if want_y:
        if is_reverse:
            # states were produced from the last frame to the first one;
            # Y has to follow the original time order
            state_layout_tensors.reverse()
        transformer.make_node('Concat',
                              state_layout_tensors, [outputs[0]],
                              axis=seq_length_dim)
835
836
def _transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos,
                                  activations, clip, hidden_size, layout):
    """Generate Bidirectional unrolled LSTM

    Splits the weight inputs (and the optional bias, initial state and peephole
    inputs) of the node into a forward and a backward half, generates one unrolled
    LSTM per direction and attaches the results to outputs 0, 1 and 2 of the
    original node.

    Note: `x` is reversed in place before the backward direction is generated.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional LSTM operation to unroll
        x (list of str): list of input tensors (input tensor split along "time" dimension)
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        activations (list of str): list of length 6, containing names of forward and reverse activations
        clip (float or None): range which clips input of activations
        hidden_size (int): size of hidden state
        layout (int): See attribute description:
            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37
    """

    node_inputs = original_node.input
    node_outputs = original_node.output

    def has_input(index):
        # optional inputs are either missing at the tail or given as an empty name
        return len(node_inputs) > index and node_inputs[index] != ''

    def split_by_direction(tensor_name, axis):
        # cut the tensor in two along the direction axis, then drop that axis
        halves = transformer.make_split(tensor_name, split_sizes=[1, 1], axis=axis)
        return [transformer.make_squeeze(half, axes=[axis]) for half in halves]

    # W and R are mandatory; both splits are emitted before any of the squeezes
    w_halves = transformer.make_split(node_inputs[1], split_sizes=[1, 1], axis=0)
    r_halves = transformer.make_split(node_inputs[2], split_sizes=[1, 1], axis=0)
    w = []
    r = []
    for direction_id in range(2):
        w.append(transformer.make_squeeze(w_halves[direction_id], axes=[0]))
        r.append(transformer.make_squeeze(r_halves[direction_id], axes=[0]))

    # optional inputs: [None, None] tells the generator that the input is absent
    b = split_by_direction(node_inputs[3], 0) if has_input(3) else [None, None]
    initial_h = split_by_direction(node_inputs[5], layout) if has_input(5) else [None, None]
    initial_c = split_by_direction(node_inputs[6], layout) if has_input(6) else [None, None]
    p = split_by_direction(node_inputs[7], 0) if has_input(7) else [None, None]

    input_info = tensor_infos[node_inputs[0]]
    dtype = _dtype_to_np(input_info.dtype)
    batch_size = input_info.shape[1 - layout]

    # first three activation names serve the forward pass, last three the backward pass
    gate_names = ('f', 'g', 'h')
    act = [{gate: activations[offset + position]
            for position, gate in enumerate(gate_names)} for offset in (0, 3)]

    forward_h_states, forward_c_state = _generate_one_direction_LSTM(
        transformer, x, w[0], r[0], b[0], initial_h[0], initial_c[0], p[0], clip, act[0],
        dtype, hidden_size, batch_size)
    # backward direction consumes the frames in opposite order ...
    x.reverse()
    backward_h_states, backward_c_state = _generate_one_direction_LSTM(
        transformer, x, w[1], r[1], b[1], initial_h[1], initial_c[1], p[1], clip, act[1],
        dtype, hidden_size, batch_size)
    # ... and its states are flipped back so they pair up with the forward states
    backward_h_states.reverse()

    time_axis = layout
    y_direction_axis = layout + 1
    state_direction_axis = layout

    per_step_outputs = []
    for forward_state, backward_state in zip(forward_h_states, backward_h_states):
        forward_expanded = transformer.make_unsqueeze(forward_state,
                                                      axes=[time_axis, y_direction_axis])
        backward_expanded = transformer.make_unsqueeze(backward_state,
                                                       axes=[time_axis, y_direction_axis])
        per_step_outputs.append(
            transformer.make_concat([forward_expanded, backward_expanded],
                                    axis=y_direction_axis))

    # after the flip above the last generated backward state sits at index 0
    last_forward_h = transformer.make_unsqueeze(forward_h_states[-1],
                                                axes=[state_direction_axis])
    last_backward_h = transformer.make_unsqueeze(backward_h_states[0],
                                                 axes=[state_direction_axis])

    # use low-level interface to attach to existing tensors
    Y_h = node_outputs[1]
    transformer.make_node('Concat', [last_forward_h, last_backward_h], [Y_h],
                          axis=state_direction_axis)

    forward_c_expanded = transformer.make_unsqueeze(forward_c_state,
                                                    axes=[state_direction_axis])
    backward_c_expanded = transformer.make_unsqueeze(backward_c_state,
                                                     axes=[state_direction_axis])
    Y_c = node_outputs[2]
    transformer.make_node('Concat', [forward_c_expanded, backward_c_expanded], [Y_c],
                          axis=state_direction_axis)

    Y = node_outputs[0]
    transformer.make_node('Concat', per_step_outputs, [Y], axis=time_axis)
945
946
def _legalize_LSTM(transformer, tensor_infos, node):
    """Unroll LSTM operation

    Reads the attributes of the LSTM node, rejects configurations that can not be
    unrolled, splits the input tensor into one frame per time step and delegates
    generation of the unrolled subgraph to the unidirectional or bidirectional
    transformation. The original node is marked for deletion afterwards.

    Args:
        transformer (_ModelTransformerHelper): transformation helper
        tensor_infos (dict from str to _TensorInfo): dict maps tensor name to it's shape and dtype info
        node (onnx.onnx_ml_pb2.NodeProto): LSTM operation to unroll

    Raises:
        NotImplementedError: optional input #4 (sequence_lens in the ONNX LSTM
            specification) is set, activation_alpha/activation_beta are given,
            an activation other than Relu/Tanh/Sigmoid is requested, or
            input_forget is not 0
        RuntimeError: direction attribute has an unknown value
    """
    inputs = node.input
    if len(inputs) > 4 and inputs[4] != '':
        raise NotImplementedError('Variadic length of output is not supported')

    # attributes and their default values
    activation_alpha = []
    activation_beta = []
    activations = ['Sigmoid', 'Tanh', 'Tanh'] * 2
    clip = None
    direction = 'forward'
    hidden_size = 0
    input_forget = 0
    layout = 0

    for attr in node.attribute:
        if attr.name == 'activation_alpha':
            activation_alpha = attr.floats
        elif attr.name == 'activation_beta':
            activation_beta = attr.floats
        elif attr.name == 'activations':
            activations = [item.decode('UTF-8') for item in attr.strings]
        elif attr.name == 'clip':
            clip = attr.f
        elif attr.name == 'direction':
            direction = attr.s.decode('UTF-8')
        elif attr.name == 'hidden_size':
            hidden_size = attr.i
        elif attr.name == 'input_forget':
            input_forget = attr.i
        elif attr.name == 'layout':
            layout = attr.i

    if len(activation_alpha) > 0 or len(activation_beta) > 0:
        raise NotImplementedError('Unsupported parameters for LSTM activations')

    for act in activations:
        if act not in ['Relu', 'Tanh', 'Sigmoid']:
            raise NotImplementedError('Unsupported activation function')

    if input_forget != 0:
        raise NotImplementedError('Unsupported input_forget attribute value')

    # reject unknown directions before any node is generated, so a failed
    # legalization does not leave orphan Split/Squeeze nodes in the graph
    if direction not in ['forward', 'reverse', 'bidirectional']:
        raise RuntimeError('Unknown LSTM type')

    # time axis of the input tensor is selected by the layout attribute
    seq_length_dim = layout
    seq_length = tensor_infos[inputs[0]].shape[seq_length_dim]
    if hidden_size == 0:
        # hidden_size attribute is absent: take it from the last axis of the R input
        hidden_size = tensor_infos[inputs[2]].shape[2]

    input_split_tensor = transformer.make_split(inputs[0],
                                                split_sizes=[1] * seq_length,
                                                axis=seq_length_dim)
    # Every frame has size 1 along the time axis; remove exactly that axis.
    # Squeezing axis 0 would hit the batch axis when layout == 1.
    x = [
        transformer.make_squeeze(input_frame_tensor, axes=[seq_length_dim])
        for input_frame_tensor in input_split_tensor
    ]

    if direction == 'bidirectional':
        _transform_bidirectional_LSTM(transformer, node, x, tensor_infos, activations,
                                      clip, hidden_size, layout)
    else:
        _transform_unidirectional_LSTM(transformer, node, x, tensor_infos, activations,
                                       clip, direction, hidden_size, layout)

    transformer.mark_for_deletion(node)
1020
1021
def _get_default_opset_version(model):
    """Return version of the default ONNX operator set imported by the model

    A model may import several operator sets (one per domain) in any order,
    so the entry of the default domain ('' or 'ai.onnx') is looked up explicitly.
    If there is no such entry, version of the first imported set is returned.

    Args:
        model (onnx.onnx_ml_pb2.ModelProto): target model

    Returns:
        int: operator set version
    """
    for opset in model.opset_import:
        if opset.domain in ('', 'ai.onnx'):
            return opset.version
    return model.opset_import[0].version


def legalize(model, options):
    """Replace selected operations in onnx model

    Replaces operations, selected by given options with different operation sequences.
    For example remove unsupported parts of graph with sequences of supported operations.

    Note that graph is changed in place.

    Args:
        model (onnx.onnx_ml_pb2.ModelProto): target model
        options (LegalizeOptions): selects transformations to apply

    Raises:
        NotImplementedError: a node has to be unrolled, but the model uses
            operator set version 13 or greater
    """
    tensor_infos = _get_tensor_infos(model)

    transformer = _ModelTransformerHelper(model)

    node_id = 0
    # graph grows while nodes are unrolled, so its length is re-read on every step
    while node_id < len(model.graph.node):
        node = model.graph.node[node_id]

        legalize_node = None
        if node.op_type == 'RNN' and options.unroll_rnn:
            legalize_node = _legalize_RNN
        elif node.op_type == 'LSTM' and options.unroll_lstm:
            legalize_node = _legalize_LSTM

        if legalize_node is not None:
            # opset version is required by split operation
            if _get_default_opset_version(model) >= 13:
                raise NotImplementedError(
                    'Can not generate code with opcode version 13 and greater')
            transformer.set_insert_id(node_id)
            legalize_node(transformer, tensor_infos, node)
            # continue after the nodes inserted by the transformation
            node_id = transformer.get_insert_id()
        node_id += 1

    transformer.delete_marked_nodes()
1059
1060
# Stand-alone mode: unroll every RNN and LSTM of the input model and save the result
if __name__ == '__main__':
    if len(sys.argv) < 3:
        print(
            'usage: ./legalize_onnx.py <path to input model> <path to output model>\n'
            '\n'
            ' In stand-alone utility mode this tool provides basic funtionality\n'
            ' If you want to have more control over applied transformations, use this legalizer as a library'
        )
        # sys.exit is always available, unlike the exit helper installed by the site module
        sys.exit(1)
    options = LegalizeOptions()
    options.unroll_lstm = True
    options.unroll_rnn = True
    model = onnx.load(sys.argv[1])
    legalize(model, options)
    onnx.save(model, sys.argv[2])
make_node(self, opcode, inputs, outputs, *p_args, **k_args)
make_constant_tensor(self, tensor_data, base_name)
make_gemm(self, A, B, C, trans_a=False, trans_b=False)
make_split(self, input, split_sizes, axis)
__init__(self, dtype, shape)
_legalize_LSTM(transformer, tensor_infos, node)
_transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, activation, clip, direction, hidden_size, layout)
_transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos, activations, clip, hidden_size, layout)
_legalize_RNN(transformer, tensor_infos, node)
_transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, activations, clip, hidden_size, layout)
_transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos, activations, clip, direction, hidden_size, layout)
_generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activation_name)
_generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P, clip, act, dtype, hidden_size, batch_size)
legalize(model, options)