-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBinaryEmbeddings.py
115 lines (98 loc) · 3.66 KB
/
BinaryEmbeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import tensorflow as tf
from Utils.utils import CFakeObject
from NN.utils import normVec
class CBinaryEmbeddings(tf.keras.layers.Layer):
  '''
  Fixed (non-trainable) embedding layer that maps byte values 0..255 to their
  8-bit binary expansion, rescaled from {0, 1} to {-1, +1}.

  Because the codebook is frozen, the layer acts as a deterministic codec:
  `encode` quantizes a color in [-1, 1] to one of 256 buckets and returns the
  corresponding code vector; `decode` recovers the nearest bucket from an
  arbitrary 8-D vector via a softmax over negative squared distances.
  '''
  def __init__(self, input_dim, output_dim, name, **kwargs):
    '''
    input_dim: must be 256 (number of codes).
    output_dim: must be 8 (bits per code).
    name: layer name; also used to namespace the embeddings variable.
    Raises ValueError for any other configuration — the binary-expansion
    codebook is hard-wired to 256 codes of 8 bits. (Plain `assert` is not
    used here because it is stripped under `python -O`.)
    '''
    super().__init__(name=name, **kwargs)
    if input_dim != 256:
      raise ValueError('Only 256 input_dim is supported')
    if output_dim != 8:
      raise ValueError('Only 8 output_dim is supported')
    self._N = input_dim
    self.output_dim = output_dim
    self._embeddings = tf.Variable(
      initial_value=self._initEmbeddings(),
      trainable=False,  # fixed codebook; never updated by the optimizer
      name='%s/embeddings' % name
    )
    # NOTE: a trainable softmax temperature (`scaleProbabilities`, a softplus
    # of a zero-initialized scalar) was tried and removed; `_score` now uses a
    # hard-coded scale of -1.
    return

  def _initEmbeddings(self):
    '''Build the (256, 8) codebook: binary expansion of 0..255, LSB first,
    with each bit mapped from {0, 1} to {-1, +1}.'''
    x = tf.range(256, dtype=tf.int32)
    x = tf.reshape(x, (256, 1))
    # Bit i of each value, extracted by shift-and-mask; least significant first.
    x_expanded = tf.reshape(
      tf.stack([tf.bitwise.right_shift(x, i) & 1 for i in range(8)], axis=-1),
      (256, 8)
    )
    x = tf.cast(x_expanded, tf.float32)
    x = x * 2.0 - 1.0  # {0, 1} -> {-1, +1}
    tf.assert_equal(tf.shape(x), (256, 8))
    return x

  def normalize(self, x):
    '''Clamp the L2 norm of each row of `x` into [1e-6, 1], keeping direction.
    (Assumes `normVec` returns (unit_direction, length) — TODO confirm
    against NN.utils.)'''
    V, L = normVec(x)
    L = tf.clip_by_value(L, clip_value_min=1e-6, clip_value_max=1.0)
    return V * L

  @property
  def embeddings(self):
    '''The frozen (256, 8) codebook variable.'''
    return self._embeddings

  def call(self, inputs):
    '''Look up code vectors for a 1-D batch of integer indices.
    inputs: int tensor of shape (B,); returns float tensor (B, 8).'''
    B = tf.shape(inputs)[0]
    tf.assert_equal(tf.shape(inputs), (B, ))
    res = tf.gather(self.embeddings, inputs)
    tf.assert_equal(tf.shape(res), (B, self.output_dim))
    return res

  def _score(self, x):
    '''Return (B, 256) per-code probabilities: softmax over the NEGATIVE
    squared euclidean distance between normalized `x` and every codebook
    entry (closer code => higher probability).'''
    x = self.normalize(x)
    # Ensure `x` is 2D: [batch_size, num_features]
    B = tf.shape(x)[0]
    tf.assert_equal(tf.shape(x), (B, self.output_dim))
    embeddings = self.embeddings
    dot_product = tf.matmul(x, embeddings, transpose_b=True)  # [B, N]
    tf.assert_equal(tf.shape(dot_product), (B, self._N))
    # Squared distance via the expansion ||e||^2 + ||x||^2 - 2<x, e>.
    embLen = tf.reduce_sum(embeddings ** 2, axis=-1, keepdims=True)
    embLen = tf.transpose(embLen)
    tf.assert_equal(tf.shape(embLen), (1, self._N))
    xLen = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
    tf.assert_equal(tf.shape(xLen), (B, 1))
    distance = embLen + xLen - 2 * dot_product
    distance = tf.maximum(distance, 0.0)  # guard tiny negatives from fp error
    tf.assert_equal(tf.shape(distance), (B, self._N))
    # Fixed temperature; a learnable softplus scale was tried and removed.
    scale = -1.
    res = tf.nn.softmax(distance * scale, axis=-1)
    tf.assert_equal(tf.shape(res), (B, self._N))
    return res

  def separability(self):
    '''No separability penalty for a frozen codebook; always 0.'''
    return 0.0

  @tf.function
  def loss(self, x, target):
    '''Per-sample sparse categorical cross-entropy between the code
    probabilities of `x` (B, 8) and integer targets (B, 1).'''
    B = tf.shape(x)[0]
    tf.assert_equal(tf.shape(x), (B, self.output_dim))
    tf.assert_equal(tf.shape(target), (B, 1))
    scores = self._score(x)
    tf.assert_equal(tf.shape(scores), (B, self._N))
    # `scores` are probabilities (softmax output), matching the default
    # from_logits=False of sparse_categorical_crossentropy.
    res = tf.losses.sparse_categorical_crossentropy(target, scores)
    return res

  def encode(self, color):
    '''Quantize color value(s) in [-1, 1] to bucket indices and code vectors.
    Returns an object with `indices` (N, 1) int32 and `embeddings` (N, 8).'''
    N = tf.size(color)
    x = tf.reshape(color, (N, 1))
    x = tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0)
    x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    idx = tf.cast(x * self._N, tf.int32)  # floor into one of N buckets
    idx = tf.clip_by_value(idx, clip_value_min=0, clip_value_max=self._N - 1)
    return CFakeObject(indices=idx, embeddings=self(idx[:, 0]))

  def decode(self, x):
    '''Map vectors (B, 8) back to the most likely bucket.
    Returns an object with `values` (B, 1) float in [-1, 1) and
    `indices` (B, 1) int32.'''
    B = tf.shape(x)[0]
    tf.assert_equal(tf.shape(x), (B, self.output_dim))
    scores = self._score(x)
    tf.assert_equal(tf.shape(scores), (B, self._N))
    idx = tf.argmax(scores, axis=-1)[..., None]
    idx = tf.cast(idx, tf.int32)
    # NOTE(review): this reconstructs the LOWER EDGE of the bucket
    # (idx / N * 2 - 1), not its center ((idx + 0.5) / N * 2 - 1), so
    # encode->decode round-trips are biased by up to one bucket width;
    # left as-is since callers may depend on the exact values.
    x = tf.cast(idx, tf.float32) / self._N
    x = x * 2.0 - 1.0
    return CFakeObject(values=x, indices=idx)