From 64664f5d858aba28d361556440083ae9e8b30d1c Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Fri, 11 Feb 2022 11:09:26 +0300
Subject: [PATCH] [one-cmds] Update onnx_legalizer.py docstrings (#8405)

This PR adds more comments with some references to onnx documents.

ONE-DCO-1.0-Signed-off-by: Alexander Efimov
---
 compiler/one-cmds/onnx_legalizer.py | 274 ++++++++++++++++++++++------
 1 file changed, 216 insertions(+), 58 deletions(-)

diff --git a/compiler/one-cmds/onnx_legalizer.py b/compiler/one-cmds/onnx_legalizer.py
index 8b8dfda4b5c..186ffb1d7dd 100755
--- a/compiler/one-cmds/onnx_legalizer.py
+++ b/compiler/one-cmds/onnx_legalizer.py
@@ -21,12 +21,26 @@ import re
 # Transform onnx model to make it compilable with our toolchain
+#
+# This code works with onnx model in proto format. See protocol buffers format in
+# https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3
+#
+# More examples of handling onnx models can be found here:
+# https://github.com/onnx/onnx/tree/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/examples
+#
 # List of transformations:
 # - Replace RNN operation with unrolled subgraph
 # - Replace LSTM operation with unrolled subgraph
 class LegalizeOptions:
+    """Controls transformations that the legalizer applies
+
+    Attributes:
+        unroll_rnn (bool): default is False. If True - unrolls RNN operations
+        unroll_lstm (bool): default is False. If True - unrolls LSTM operations
+    """
+
     unroll_rnn = False
     unroll_lstm = False
@@ -36,6 +50,20 @@ def reverse_str(s):
 def parse_tensor_name(name):
+    """Splits tensor name into base part and serial number
+
+    Most tensor names have the following format: "tensor_123".
+    This function breaks name into two values: "tensor_" and 123.
+    Tensor names like this: "321" are broken into "" and 321.
+
+    Serial number is used to create unique tensor names using given base name. 
+
+    Args:
+        name (str): tensor name
+
+    Returns:
+        tuple of str, int: base name and serial number of tensor
+    """
     rev = reverse_str(name)
     m = re.match('(\d*)(.*)', rev)
     if m.groups()[0] != '':
@@ -45,6 +73,20 @@ def parse_tensor_name(name):
 class ModelTransformerHelper:
+    """Helper for onnx model transformation
+
+    This helper is used for convenient operation replacement in onnx model
+
+    Attributes:
+        _model (onnx.onnx_ml_pb2.ModelProto): target model that should be changed
+        _nodes_to_delete (list of onnx.onnx_ml_pb2.NodeProto): list of replaced operations
+        _insert_id (int): position to insert created operations (graph should stay topologically sorted)
+        _base_name_idx (dict from str to int): maps tensor "base" name to
+        largest existing serial num. For example, if model has tensors "t_1", "t_2", "t_4",
+        in that case _base_name_idx["t_"] == 4.
+        This attribute is used for unique tensor name generation.
+    """
+
     def __init__(self, model):
         self._model = model
         self._nodes_to_delete = []
@@ -66,6 +108,14 @@ def __init__(self, model):
             self._base_name_idx[base_name] = number
     def make_tensor_with_base_name(self, base_name):
+        """ Create unique name for given base_name
+
+        Args:
+            base_name (str): base tensor name
+
+        Returns:
+            str: unique tensor name that starts with base_name
+        """
         if base_name in self._base_name_idx:
             self._base_name_idx[base_name] += 1
             return base_name + str(self._base_name_idx[base_name])
@@ -74,15 +124,18 @@
             return base_name + '0'
     def make_node(self, opcode, inputs, outputs, *p_args, **k_args):
-        """
-        Create arbitrary node and insert it in graph.
+        """Create arbitrary node and insert it in graph. 
- Parameters: + Args: opcode (str): opcode name of desired operation inputs (list of str): names of input tensors - outputs (list of str or int): names of output tensors or number of tensors that should be created + outputs (list of str or int): names of existing tensors to use as output tensors for operation or + number of tensors that should be created p_args: additional arguments for onnx make_node helper k_args: attributes for onnx node + + Returns: + list of str: list of output tensor names """ if type(outputs) == int: outputs = [self.make_tensor_with_base_name('') for i in range(outputs)] @@ -93,61 +146,57 @@ def make_node(self, opcode, inputs, outputs, *p_args, **k_args): return outputs def make_split(self, input, split_sizes, axis): - '''Create Split operation and insert it in graph. + """Create Split operation and insert it in graph. Args: input (str): name of input tensor - split_sizes (list): list of split sizes + split_sizes (list of int): list of split sizes axis (int): number of axis to split Returns: list: list of output tensor names - - ''' + """ return self.make_node( 'Split', [input], len(split_sizes), axis=axis, split=split_sizes) def make_concat(self, inputs, axis): - '''Create Concat operation and insert it in graph. + """Create Concat operation and insert it in graph. Args: - inputs (list): list of tensors names to concat + inputs (list of str): list of tensors names to concat axis (int): axis number to concat Returns: - str:: output tensor name - - ''' + str: output tensor name + """ return self.make_node('Concat', inputs, 1, axis=axis)[0] def make_squeeze(self, input, axes): - '''Create Squeeze operation and insert it in graph. + """Create Squeeze operation and insert it in graph. 
Args: input (str): name of input tensor - axes (list): list of dimension containing ones to remove + axes (list of int): list of dimension containing ones to remove Returns: str: output tensor name - - ''' + """ return self.make_node('Squeeze', [input], 1, axes=axes)[0] def make_unsqueeze(self, input, axes): - '''Create Unsqueeze operation and insert it in graph. + """Create Unsqueeze operation and insert it in graph. Args: input (str): name of input tensor - axes (list): list of dimension to insert ones + axes (list of int): list of dimension to insert ones Returns: str: output tensor name - - ''' + """ return self.make_node('Unsqueeze', [input], 1, axes=axes)[0] def make_gemm(self, A, B, C, trans_a=False, trans_b=False): - '''Create Gemm operation and insert it in graph. + """Create Gemm operation and insert it in graph. Result tensor contains A*B + C @@ -160,13 +209,12 @@ def make_gemm(self, A, B, C, trans_a=False, trans_b=False): Returns: str: output tensor name - - ''' + """ return self.make_node( 'Gemm', [A, B, C], 1, transA=bool(trans_a), transB=bool(trans_b))[0] def make_add(self, a, b): - '''Creates Add operation and insert it in graph. + """Creates Add operation and insert it in graph. Args: a (str): name of left operand tensor @@ -174,12 +222,11 @@ def make_add(self, a, b): Returns: str: output tensor name - - ''' + """ return self.make_node('Add', [a, b], 1)[0] def make_mul(self, a, b): - '''Creates Mul operation and insert it in graph. + """Creates Mul operation and insert it in graph. Args: a (str): name of left operand tensor @@ -187,12 +234,11 @@ def make_mul(self, a, b): Returns: str: output tensor name - - ''' + """ return self.make_node('Mul', [a, b], 1)[0] def make_clip(self, input, min, max): - '''Create Clip operation and insert it in graph. + """Create Clip operation and insert it in graph. 
Args: input (str): input tensor name @@ -201,12 +247,11 @@ def make_clip(self, input, min, max): Returns: str: output tensor name - - ''' + """ return self.make_node('Clip', [input], 1, min=min, max=max)[0] def make_act(self, input, act_name): - '''Create activation function operation and insert it in graph. + """Create activation function operation and insert it in graph. Args: input (str): input tensor name @@ -214,12 +259,20 @@ def make_act(self, input, act_name): Returns: str: output tensor name - - ''' + """ assert (act_name in ['Relu', 'Tanh', 'Sigmoid']) return self.make_node(act_name, [input], 1)[0] def make_constant_tensor(self, tensor_data, base_name): + """Creates onnx constant tensor + + Args: + tensor_data (numpy.ndarray): tensor data + base_name (str): prefix of constant tensor name + + Returns: + str: name of created constant tensor + """ tensor = onnx.numpy_helper.from_array(tensor_data) tensor.name = self.make_tensor_with_base_name(base_name) self._model.graph.initializer.append(tensor) @@ -246,6 +299,14 @@ def __init__(self, dtype, shape): def get_tensor_infos(model): + """Infer tensor shapes and dtypes + Args: + model (onnx.onnx_ml_pb2.ModelProto): model to process + + Returns: + dict from str to Info: maps tensor name to shape and dtype information + """ + inferred_shape_model = onnx.shape_inference.infer_shapes(model) infos = {} @@ -262,6 +323,18 @@ def get_tensor_infos(model): def dtype_to_np(dtype): + """Convert onnx dtype value to numpy dtype class + + For more types see: + https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3#L484 + + Args: + dtype (int): onnx dtype + + Returns: + numpy data type: numpy dtype, like np.float32 + """ + if dtype == 1: return np.float32 else: @@ -269,21 +342,23 @@ def dtype_to_np(dtype): def generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activation_name): - """ - This function generates subgraph that represents one direction of unrolled RNN layer - - 
Parameters:
-      transformer (ModelTransformerHelper): helper for model generation
-      X (list of str): names of input tensors in sequence. Tensor shapes: [batch_size, input_size].
-      W (list of str): name of weight tensor
-      R (list of str): name of recurrence weight tensor
-      B (list of str): name of bias tensor
-      initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size]
-      clip (float or None): range which clips input of activations
-      act (str): activation function
+    """Generate subgraph of one direction of unrolled RNN layer
+
+    Args:
+        transformer (ModelTransformerHelper): helper for model generation
+        X (list of str): names of input tensors in sequence. Tensor shapes: [batch_size, input_size].
+        W (str): name of weight tensor
+        R (str): name of recurrence weight tensor
+        B (str): name of bias tensor
+        initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size]
+        clip (float or None): range which clips input of activations
+        activation_name (str): activation function
     """
     # one direction RNN:
     #
+    # For details see:
+    # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#RNN-7
+    #
     # H = f(X*(W^T) + h*(R^T) + B)
     #
     # H - new hidden state
@@ -321,6 +396,21 @@ def generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activat
 def transform_unidirectional_RNN(transformer, original_node, x, tensor_infos,
                                  activations, clip, direction, hidden_size, layout):
+    """Generate Simple (forward or reverse) unrolled RNN
+
+    Args:
+        transformer (ModelTransformerHelper): transformation helper
+        original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional RNN operation to unroll
+        x (list of str): list of input tensors (input tensor split along "time" dimension)
+        tensor_infos (dict from str to Info): dict maps tensor name to its shape and dtype info
+        activations (list of str): list containing name of activation function in first element
+        clip (float or 
None): range which clips input of activations
+        direction (str): "forward" or "reverse"
+        hidden_size (int): size of hidden state
+        layout (int): See attribute description:
+            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
+    """
+
     inputs = original_node.input
     outputs = original_node.output
     if direction == 'reverse':
@@ -362,6 +452,20 @@ def transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, ac
 def transform_bidirectional_RNN(transformer, original_node, x, tensor_infos,
                                 activations, clip, hidden_size, layout):
+    """Generate Bidirectional unrolled RNN
+
+    Args:
+        transformer (ModelTransformerHelper): transformation helper
+        original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional RNN operation to unroll
+        x (list of str): list of input tensors (input tensor split along "time" dimension)
+        tensor_infos (dict from str to Info): dict maps tensor name to its shape and dtype info
+        activations (list of str): list of length 2 containing names of forward and reverse activations
+        clip (float or None): range which clips input of activations
+        hidden_size (int): size of hidden state
+        layout (int): See attribute description:
+            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
+    """
+
     inputs = original_node.input
     outputs = original_node.output
     w_bi = transformer.make_split(inputs[1], split_sizes=[1, 1], axis=0)
@@ -436,6 +540,13 @@ def transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, act
 def legalize_RNN(transformer, tensor_infos, node):
+    """Unroll RNN operation
+
+    Args:
+        transformer (ModelTransformerHelper): transformation helper
+        tensor_infos (dict from str to Info): dict maps tensor name to its shape and dtype info
+        node (onnx.onnx_ml_pb2.NodeProto): RNN operation to unroll
+    """
     inputs = node.input
     outputs = node.output
     if len(inputs) > 4 and inputs[4] != '':
@@ -497,27 +608,27 @@ def 
legalize_RNN(transformer, tensor_infos, node): def generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P, clip, act, dtype, hidden_size, batch_size): - """ - This function generates subgraph that represents one direction of unrolled LSTM layer + """Generate subgraph for one direction of unrolled LSTM layer - Parameters: + Args: transformer (ModelTransformerHelper): helper for model generation - tensor_infos (dict of Info): shapes and dtypes of tensors - X (list of str): names of input tensors in sequence. Tensor shapes: [batch_size, input_size] - W (list of str): name of concatenated weight tensor: [input, output, forget, cell] - R (list of str): name of concatenated recurrence weights tensor: [input, output, forget, cell] - B (list of str): name of concatenated bias tensor: [input, output, forget, cell] + X (list of str): names of tensors in input sequence. Each tensor shape: [batch_size, input_size] + W (str): name of concatenated weight tensor: [input, output, forget, cell] + R (str): name of concatenated recurrence weights tensor: [input, output, forget, cell] + B (str): name of concatenated bias tensor: [input, output, forget, cell] initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size] initial_c (str or None): name of tensor containing initial cell state. 
Shape [batch_size, hidden_size] - P (list of str): name of concatenated peephole tensor: [input, output, forget] + P (str or None): name of concatenated peephole tensor: [input, output, forget] clip (float or None): range which clips input of activations act (dict of str): activation functions {'f': 'Sigmoid', 'g': 'Tanh', 'h': 'Tanh'} + dtype (numpy dtype): data type used in created LSTM operation hidden_size (int): hidden dimension batch_size (int): batch dimension """ # one direction LSTM: # - # From onnx Operators.onnx + # For details see: + # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#LSTM-7 # # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) @@ -640,6 +751,21 @@ def generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P def transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos, activations, clip, direction, hidden_size, layout): + """Generate Simple (forward or reverse) unrolled LSTM + + Args: + transformer (ModelTransformerHelper): transformation helper + original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional LSTM operation to unroll + x (list of str): list of input tensors (input tensor split along "time" dimension) + tensor_infos (dict from str to Info): dict maps tensor name to it's shape and dtype info + activations (list of str): list of length 3 containing names of activation functions + clip (float or None): range which clips input of activations + direction (str): "forward" or "reverse" + hidden_size (int): size of hidden state + layout (int): See attribute description: + https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37 + """ + inputs = original_node.input outputs = original_node.output if direction == 'reverse': @@ -696,6 +822,20 @@ def transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos, def 
transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos,
                                  activations, clip, hidden_size, layout):
+    """Generate Bidirectional unrolled LSTM
+
+    Args:
+        transformer (ModelTransformerHelper): transformation helper
+        original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional LSTM operation to unroll
+        x (list of str): list of input tensors (input tensor split along "time" dimension)
+        tensor_infos (dict from str to Info): dict maps tensor name to its shape and dtype info
+        activations (list of str): list of length 6, containing names of forward and reverse activations
+        clip (float or None): range which clips input of activations
+        hidden_size (int): size of hidden state
+        layout (int): See attribute description:
+            https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37
+    """
+
     inputs = original_node.input
     outputs = original_node.output
@@ -787,6 +927,13 @@ def transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos, ac
 def legalize_LSTM(transformer, tensor_infos, node):
+    """Unroll LSTM operation
+
+    Args:
+        transformer (ModelTransformerHelper): transformation helper
+        tensor_infos (dict from str to Info): dict maps tensor name to its shape and dtype info
+        node (onnx.onnx_ml_pb2.NodeProto): LSTM operation to unroll
+    """
     inputs = node.input
     outputs = node.output
     if len(inputs) > 4 and inputs[4] != '':
@@ -853,6 +1000,17 @@ def legalize_LSTM(transformer, tensor_infos, node):
 def legalize(model, options):
+    """Replace selected operations in onnx model
+
+    Replaces operations selected by given options with different operation sequences.
+    For example, replaces unsupported parts of graph with sequences of supported operations.
+
+    Note that the graph is changed in place.
+
+    Args:
+        model (onnx.onnx_ml_pb2.ModelProto): target model
+        options (LegalizeOptions): options that control which transformations are applied
+    """
     tensor_infos = get_tensor_infos(model)
     transformer = ModelTransformerHelper(model)