Skip to content

Commit

Permalink
Dragan: random uniform fix (openvinotoolkit#8716)
Browse files Browse the repository at this point in the history
* fix random uniform input constant layout

* added test for 4D
  • Loading branch information
sadolini authored Nov 26, 2021
1 parent 7a16de2 commit 5b3c3bc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions model-optimizer/extensions/ops/random_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.middle.passes.convert_data_type import np_data_type_to_destination_type
from mo.ops.op import Op

Expand Down Expand Up @@ -51,6 +52,8 @@ def infer(node: Node):
node.in_node(1)['correct_data_type'] = True
node.in_node(2)['correct_data_type'] = True

PermuteInputs().set_input_permutation(node.in_node(0), node, 'output:0', 'shape')


class AttributedRandomUniform(Op):
""" RandomUniform operation that generates a sequence of random values from uniform distribution.
Expand Down
3 changes: 2 additions & 1 deletion tests/layer_tests/tensorflow_tests/test_tf_RandomUniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, m
x = tf.compat.v1.placeholder(input_type, tf_x_shape, 'Input')
if global_seed is not None:
tf.compat.v1.random.set_random_seed(global_seed)
random_uniform = tf.random.uniform(x_shape, seed=op_seed, dtype=input_type, minval=min_val,
random_uniform = tf.random.uniform(tf_x_shape, seed=op_seed, dtype=input_type, minval=min_val,
maxval=max_val) + x

tf.compat.v1.global_variables_initializer()
Expand Down Expand Up @@ -80,6 +80,7 @@ def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, m
dict(global_seed=32465, op_seed=48971, min_val=0.0, max_val=1.0, x_shape=[3, 7], input_type=tf.float32),
marks=pytest.mark.precommit),
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[6], input_type=tf.float32),
dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[1, 2, 1, 1], input_type=tf.float32),
dict(global_seed=78132, op_seed=None, min_val=-200, max_val=-50, x_shape=[5, 8], input_type=tf.int32),
dict(global_seed=4571, op_seed=48971, min_val=1.5, max_val=2.3, x_shape=[7], input_type=tf.float32),
dict(global_seed=32465, op_seed=12335, min_val=-150, max_val=-100, x_shape=[18], input_type=tf.int32)]
Expand Down

0 comments on commit 5b3c3bc

Please sign in to comment.