Skip to content

Commit

Permalink
Add support for masking and groups in QConv2D.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726258359
Change-Id: I565b1fec70fa3007246111725aeff936f54a1c0d
  • Loading branch information
Akshaya Purohit authored and copybara-github committed Feb 13, 2025
1 parent a1fec24 commit 9477f2e
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 36 deletions.
127 changes: 91 additions & 36 deletions qkeras/qconvolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

import numpy as np
import tensorflow as tf
from tensorflow.keras import constraints
from tensorflow.keras import initializers
Expand All @@ -26,19 +28,19 @@
from tensorflow.keras.layers import Conv1D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import SeparableConv1D
from tensorflow.keras.layers import SeparableConv2D
from tensorflow.keras.layers import DepthwiseConv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import InputSpec
from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops
# from tensorflow.python.ops import array_ops
from tensorflow.keras.layers import SeparableConv1D
from tensorflow.keras.layers import SeparableConv2D

from .qlayers import get_auto_range_constraint_initializer
from .qlayers import QActivation
from .quantizers import get_quantized_initializer
from .quantizers import get_quantizer

from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops
# from tensorflow.python.ops import array_ops
from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer


Expand Down Expand Up @@ -260,32 +262,36 @@ class QConv2D(Conv2D, PrunableLayer):
# can go over [-1,+1], these values are used to set the clipping
# value of kernels and biases, respectively, instead of using the
# constraints specified by the user.
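# mask: Optional array whose first two dimensions are
#   (kernel_height, kernel_width). It is reshaped to
#   (kernel_height, kernel_width, 1, 1) and multiplied into the (quantized)
#   kernel on every call, so a zero entry disables that kernel position for
#   all input and output channels.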
#
# we refer the reader to the documentation of Conv2D in Keras for the
# other parameters.
#

def __init__(self,
filters,
kernel_size,
strides=(1, 1),
padding="valid",
data_format="channels_last",
dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer="he_normal",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_range=None,
bias_range=None,
kernel_quantizer=None,
bias_quantizer=None,
**kwargs):
def __init__(
self,
filters,
kernel_size,
strides=(1, 1),
padding="valid",
data_format="channels_last",
dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer="he_normal",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_range=None,
bias_range=None,
kernel_quantizer=None,
bias_quantizer=None,
mask=None,
**kwargs,
):

if kernel_range is not None:
warnings.warn("kernel_range is deprecated in QConv2D layer.")
Expand Down Expand Up @@ -324,6 +330,20 @@ def __init__(self,
if activation is not None:
activation = get_quantizer(activation)

if mask is not None:
shape = mask.shape
if len(shape) < 2:
raise ValueError(
"Expected shape to have rank at least 2 but provided shape has"
f" rank {len(shape)}"
)
h, w = shape[0], shape[1]
self._mask = np.reshape(
mask, (h, w, 1, 1)
) # Extend the dimension to be 4D.
else:
self._mask = None

super().__init__(
filters=filters,
kernel_size=kernel_size,
Expand All @@ -343,19 +363,44 @@ def __init__(self,
**kwargs
)

def convolution_op(self, inputs, kernel):
  """Convolve `inputs` with `kernel` using this layer's settings.

  Stride, padding, data format and dilation rate are taken from the layer's
  own attributes, so callers only supply the two tensors.

  Args:
    inputs: input tensor passed to the backend convolution.
    kernel: kernel tensor to convolve with.

  Returns:
    The result of `tf.keras.backend.conv2d`.
  """
  conv_options = dict(
      strides=self.strides,
      padding=self.padding,
      data_format=self.data_format,
      dilation_rate=self.dilation_rate,
  )
  return tf.keras.backend.conv2d(inputs, kernel, **conv_options)

@tf.function(jit_compile=True)
def _jit_compiled_convolution_op(self, inputs, kernel):
  """Delegate to `convolution_op` from inside a jit-compiled `tf.function`.

  `call` uses this variant instead of `convolution_op` when
  `self.groups > 1`.

  Args:
    inputs: input tensor forwarded to `convolution_op`.
    kernel: kernel tensor forwarded to `convolution_op`.

  Returns:
    Whatever `convolution_op` returns for the given tensors.
  """
  return self.convolution_op(inputs, kernel)

def call(self, inputs):
if self.kernel_quantizer:
quantized_kernel = self.kernel_quantizer_internal(self.kernel)
else:
quantized_kernel = self.kernel

outputs = tf.keras.backend.conv2d(
inputs,
quantized_kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate)
if self._mask is not None:
# Apply mask to kernel weights if one is provided.
quantized_kernel = quantized_kernel * self._mask

# Grouped convolutions are not fully supported on the CPU for compiled
# functions.
#
# This is a workaround taken from TF's core library. Remove when proper
# support is added.
# See definition of function "_jit_compiled_convolution_op" at
# cs/third_party/py/tf_keras/layers/convolutional/base_conv.py for more
# details.
if self.groups > 1:
outputs = self._jit_compiled_convolution_op(
inputs, tf.convert_to_tensor(quantized_kernel)
)
else:
outputs = self.convolution_op(inputs, quantized_kernel)

if self.use_bias:
if self.bias_quantizer:
Expand All @@ -364,7 +409,8 @@ def call(self, inputs):
quantized_bias = self.bias

outputs = tf.keras.backend.bias_add(
outputs, quantized_bias, data_format=self.data_format)
outputs, quantized_bias, data_format=self.data_format
)

if self.activation is not None:
return self.activation(outputs)
Expand All @@ -380,10 +426,19 @@ def get_config(self):
),
"kernel_range": self.kernel_range,
"bias_range": self.bias_range,
"mask": self._mask.tolist() if self._mask is not None else None,
}
base_config = super(QConv2D, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config):
  """Rebuild a layer from the dictionary produced by `get_config`.

  `get_config` stores the mask as a nested list; it is converted back into a
  NumPy array here before being handed to the constructor.

  Args:
    config: layer configuration dictionary. It is left unmodified.

  Returns:
    A new instance of `cls` constructed from `config`.
  """
  # Work on a shallow copy so the caller's dictionary keeps its original,
  # list-based "mask" entry instead of silently gaining an ndarray (which
  # would also make that dictionary no longer JSON-serialisable).
  config = dict(config)
  mask = config.get("mask")
  if mask is not None:
    config["mask"] = np.array(mask)
  return cls(**config)

def get_quantization_config(self):
return {
"kernel_quantizer":
Expand Down
94 changes: 94 additions & 0 deletions tests/qconvolutional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,5 +308,99 @@ def test_qconv2dtranspose():
[2., 3., 4., 4., 3., 2.] ]).reshape((1,6,6,1)).astype(np.float16)
assert_allclose(actual_output, expected_output, rtol=1e-4)


def test_masked_qconv2d_creates_correct_parameters():
  """A masked QConv2D must not add any weights beyond the kernel itself."""
  mask = np.ones((5, 5), dtype=np.float32)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(shape=(10, 10, 1)))
  model.add(QConv2D(mask=mask, filters=1, kernel_size=(5, 5), use_bias=False))

  # The mask is not a weight, so there should be no non-trainable params.
  np.testing.assert_equal(len(model.non_trainable_weights), 0)

  # The only trainable weight is the single (5, 5) kernel (bias is disabled),
  # i.e. 25 parameters.
  np.testing.assert_equal(len(model.trainable_weights), 1)
  num_trainable_params = np.prod(model.trainable_weights[0].shape)
  np.testing.assert_equal(num_trainable_params, 25)


def test_qconv2d_masks_weights():
  """Kernel positions zeroed by the mask must not contribute to the output."""
  # Arbitrary 5x5 pattern that contains exactly eleven ones.
  kernel_mask = np.array(
      [
          [1.0, 0.0, 1.0, 0.0, 1.0],
          [0.0, 0.0, 1.0, 0.0, 0.0],
          [1.0, 0.0, 1.0, 0.0, 1.0],
          [0.0, 0.0, 1.0, 0.0, 0.0],
          [1.0, 0.0, 1.0, 0.0, 1.0],
      ],
      dtype=np.float32,
  )
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(shape=(5, 5, 1)))
  model.add(
      QConv2D(
          mask=kernel_mask,
          filters=1,
          kernel_size=(5, 5),
          use_bias=False,
      )
  )

  # Force every kernel weight to one so the result depends only on the mask.
  all_ones_kernel = np.ones((5, 5, 1, 1), dtype=np.float32)
  model.layers[0].set_weights([all_ones_kernel])

  # Run inference on an all-ones input.
  all_ones_input = np.ones((1, 5, 5, 1), dtype=np.float32)
  prediction = model.predict(all_ones_input)

  # With all-ones input and weights, the single output value equals the
  # number of ones in the mask.
  expected = np.array([[[[11.0]]]], dtype=np.float32)
  np.testing.assert_array_equal(prediction, expected)


def test_masked_qconv2d_load_restore_works():
  """A masked QConv2D survives a save/load round trip, mask included."""
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(shape=(10, 10, 1)))
  model.add(
      QConv2D(
          mask=np.ones((5, 5), dtype=np.float32),
          filters=1,
          kernel_size=(5, 5),
          use_bias=False,
      )
  )

  with tempfile.TemporaryDirectory() as temp_dir:
    model_path = os.path.join(temp_dir, 'model.keras')
    # Can save the model.
    model.save(model_path)

    # Can load the model.
    custom_objects = {
        'QConv2D': QConv2D,
    }
    loaded_model = tf.keras.models.load_model(
        model_path, custom_objects=custom_objects
    )

  # The kernel weights must be identical after the round trip.
  np.testing.assert_array_equal(
      model.layers[0].weights[0], loaded_model.layers[0].weights[0]
  )

  # The mask is carried through the layer config rather than through the
  # weights, so check it separately; otherwise a dropped mask goes unnoticed.
  np.testing.assert_array_equal(
      model.layers[0].get_config()['mask'],
      loaded_model.layers[0].get_config()['mask'],
  )


def test_qconv2d_groups_works():
  """A grouped QConv2D creates a kernel and bias of the expected total size."""
  grouped_conv_kwargs = dict(
      filters=6,
      kernel_size=(1, 1),
      use_bias=True,
      groups=2,
  )
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input(shape=(10, 10, 10)))
  model.add(QConv2D(**grouped_conv_kwargs))

  # Exactly two trainable weights are expected: the kernel and the bias.
  np.testing.assert_equal(len(model.trainable_weights), 2)
  kernel, bias = model.trainable_weights
  num_trainable_params = np.prod(kernel.shape) + np.prod(bias.shape)

  # Each of the 2 groups maps 5 input channels to 3 filters with a 1x1
  # kernel, plus one bias per filter: (5*3)*2 + 6.
  expected_trainable_params = 36
  np.testing.assert_equal(num_trainable_params, expected_trainable_params)


# Allow running this test module directly as a script: hand this file to
# pytest so every test defined above is collected and executed.
if __name__ == '__main__':
  pytest.main([__file__])

0 comments on commit 9477f2e

Please sign in to comment.