forked from twitter/the-algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparse_inputs.py
24 lines (19 loc) · 966 Bytes
/
sparse_inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import tensorflow.compat.v1 as tf
def create_sparse_tensor(batch_size, input_size, num_values, dtype=tf.float32):
    """Build a random 2-D ``tf.SparseTensor`` of shape ``(batch_size, input_size)``.

    Args:
        batch_size: Number of rows in the dense shape.
        input_size: Number of columns in the dense shape.
        num_values: Number of (index, value) entries to generate.
        dtype: TensorFlow dtype of the values; they are produced through
            ``dtype.as_numpy_dtype``.

    Returns:
        A ``tf.SparseTensor`` whose indices are in row-major order and whose
        values are drawn uniformly from ``[0, 1)`` and cast to ``dtype``.

    Note:
        Positions are sampled with replacement, so the same (row, column)
        pair can appear more than once in the result.
    """
    # Draw flat positions in [0, batch_size * input_size). The dtype is pinned
    # to int64 because tf.SparseTensor requires int64 indices, while NumPy's
    # default integer width is platform-dependent.
    flat_positions = np.random.randint(
        batch_size * input_size, size=num_values, dtype=np.int64)
    # Sorting the flat positions yields row-major ordering once they are
    # split into (row, column) pairs.
    flat_positions = np.sort(flat_positions)
    row_ids, col_ids = np.divmod(flat_positions, input_size)
    index_pairs = np.stack([row_ids, col_ids], axis=1)

    random_values = np.random.random(num_values).astype(dtype.as_numpy_dtype)

    return tf.SparseTensor(
        indices=tf.constant(index_pairs),
        values=tf.constant(random_values),
        dense_shape=(batch_size, input_size),
    )
def create_reference_input(sparse_input, use_binary_values):
    """Return the reference form of ``sparse_input``.

    Args:
        sparse_input: Sparse tensor exposing ``indices``, ``values`` and
            ``dense_shape`` attributes.
        use_binary_values: When truthy, every stored value is replaced by one;
            when falsy, the input is returned as-is.

    Returns:
        ``sparse_input`` itself when ``use_binary_values`` is falsy, otherwise
        a new ``tf.SparseTensor`` with the same indices and dense shape whose
        values are all ones.
    """
    # Nothing to convert: hand back the very same object.
    if not use_binary_values:
        return sparse_input

    # Keep the sparsity pattern, but replace each stored value with one.
    return tf.SparseTensor(
        indices=sparse_input.indices,
        values=tf.ones_like(sparse_input.values),
        dense_shape=sparse_input.dense_shape,
    )