# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Definition of normalization quantization package."""
import warnings
import keras
import keras.ops.numpy as knp
from keras import constraints, initializers, layers, regularizers
from keras.saving import register_keras_serializable
from keras.utils import serialize_keras_object
from .ops_portable import constant_bool_value
from .qlayers import get_auto_range_constraint_initializer, get_quantizer
@register_keras_serializable(package="qkeras")
class QBatchNormalization(layers.BatchNormalization):
    """Quantized Batch Normalization layer.

    For training, mean and variance are not quantized.
    For inference, the quantized moving mean and moving variance are used.

        output = (x - mean) / sqrt(var + epsilon) * quantized_gamma
                 + quantized_beta

    Each term is quantized by its own quantizer argument (``beta_quantizer``,
    ``gamma_quantizer``, ``mean_quantizer``, ``variance_quantizer``). A falsy
    value such as ``None`` leaves that term unquantized.
    ``inverse_quantizer`` instead quantizes the combined factor
    ``gamma / sqrt(var + epsilon)``. It requires ``gamma_quantizer`` and
    ``variance_quantizer`` to be ``None`` so no value is quantized twice.

    Notes:
      * ``activation`` is stored on the layer but is neither applied in
        ``call`` nor included in ``get_config``.
      * ``gamma_range`` and ``beta_range`` are deprecated: they only emit a
        warning and are echoed back by ``get_config``.
      * The ``fused``, ``renorm``, ``virtual_batch_size`` and ``adjustment``
        keyword arguments are accepted and dropped (with a warning when
        truthy).
      * When ``trainable`` is False the layer always runs in inference mode.
    """

    # Keyword arguments this layer does not support. They are always removed
    # from ``kwargs`` so they never reach the Keras base-class constructor.
    _UNSUPPORTED_KWARGS = ("fused", "renorm", "virtual_batch_size", "adjustment")

    def __init__(
        self,
        axis=-1,
        momentum=0.99,
        epsilon=1e-3,
        center=True,
        scale=True,
        activation=None,
        beta_initializer="zeros",
        gamma_initializer="ones",
        moving_mean_initializer="zeros",
        moving_variance_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_quantizer="quantized_po2(5)",
        gamma_quantizer="quantized_relu_po2(6, 2048)",
        mean_quantizer="quantized_po2(5)",
        # use quantized_po2 and enforce quadratic approximation
        # to get an even exponent for sqrt
        variance_quantizer="quantized_relu_po2(6, quadratic_approximation=True)",
        inverse_quantizer=None,
        gamma_constraint=None,
        beta_constraint=None,
        beta_range=None,
        gamma_range=None,
        **kwargs,
    ):
        """Builds the quantizers and forwards the remaining arguments.

        Raises:
          AssertionError: if ``inverse_quantizer`` is set while either
            ``gamma_quantizer`` or ``variance_quantizer`` is not ``None``.
        """
        if gamma_range is not None:
            warnings.warn("gamma_range is deprecated in QBatchNormalization layer.")
        if beta_range is not None:
            warnings.warn("beta_range is deprecated in QBatchNormalization layer.")
        self.gamma_range = gamma_range
        self.beta_range = beta_range
        self.activation = activation

        self.beta_quantizer = beta_quantizer
        self.gamma_quantizer = gamma_quantizer
        self.mean_quantizer = mean_quantizer
        self.variance_quantizer = variance_quantizer
        self.inverse_quantizer = inverse_quantizer
        if self.inverse_quantizer is not None:
            # Kept as an assert (not a raise) so the exception type callers
            # may observe is unchanged.
            assert self.variance_quantizer is None and self.gamma_quantizer is None, (
                "If using the inverse quantizer, the gamma and variance quantizers "
                "should not be used in order to avoid quantizing a value twice."
            )

        self.beta_quantizer_internal = get_quantizer(self.beta_quantizer)
        self.gamma_quantizer_internal = get_quantizer(self.gamma_quantizer)
        self.mean_quantizer_internal = get_quantizer(self.mean_quantizer)
        self.variance_quantizer_internal = get_quantizer(self.variance_quantizer)
        self.inverse_quantizer_internal = get_quantizer(self.inverse_quantizer)

        for quantizer in (
            self.gamma_quantizer_internal,
            self.variance_quantizer_internal,
        ):
            if hasattr(quantizer, "_set_trainable_parameter"):
                quantizer._set_trainable_parameter()

        self.quantizers = [
            self.gamma_quantizer_internal,
            self.beta_quantizer_internal,
            self.mean_quantizer_internal,
            self.variance_quantizer_internal,
            self.inverse_quantizer_internal,
        ]

        # Derive constraints/initializers from the quantizer's range.
        if scale and self.gamma_quantizer:
            gamma_constraint, gamma_initializer = get_auto_range_constraint_initializer(
                self.gamma_quantizer_internal, gamma_constraint, gamma_initializer
            )
        if center and self.beta_quantizer:
            beta_constraint, beta_initializer = get_auto_range_constraint_initializer(
                self.beta_quantizer_internal, beta_constraint, beta_initializer
            )

        # Pop unconditionally: a falsy value (e.g. fused=False, common in old
        # configs) must not be forwarded to the base class either.
        for unsupported in self._UNSUPPORTED_KWARGS:
            if kwargs.pop(unsupported, None):
                warnings.warn(
                    f"batch normalization {unsupported} is disabled "
                    "in qkeras qnormalization.py."
                )

        super().__init__(
            axis=axis,
            momentum=momentum,
            epsilon=epsilon,
            center=center,
            scale=scale,
            beta_initializer=beta_initializer,
            gamma_initializer=gamma_initializer,
            moving_mean_initializer=moving_mean_initializer,
            moving_variance_initializer=moving_variance_initializer,
            beta_regularizer=beta_regularizer,
            gamma_regularizer=gamma_regularizer,
            beta_constraint=beta_constraint,
            gamma_constraint=gamma_constraint,
            **kwargs,
        )

    @staticmethod
    def _maybe_quantize(enabled, quantizer, value):
        """Returns ``quantizer(value)`` if ``enabled`` is truthy, else ``value``."""
        return quantizer(value) if enabled else value

    def _update_moving_average(self, variable, value):
        """Assigns ``variable * momentum + value * (1 - momentum)`` to it."""
        # Cast so a lower-precision batch statistic (mixed precision) can be
        # combined with and assigned to the variable.
        value = keras.ops.cast(value, variable.dtype)
        variable.assign(variable * self.momentum + value * (1.0 - self.momentum))

    def call(self, inputs, training=False):
        """Normalizes ``inputs`` using quantized statistics and parameters.

        Args:
          inputs: input tensor.
          training: whether to normalize with the current batch statistics
            (and update the moving averages) or with the moving statistics.
            Ignored (treated as False) when the layer is not trainable.

        Returns:
          A tensor with the same shape and dtype as ``inputs``.
        """
        # A frozen layer runs in inference mode, like the Keras base layer:
        # no batch statistics and no moving-average updates.
        if not self.trainable:
            training = False

        quantized_gamma = self._maybe_quantize(
            self.scale and self.gamma_quantizer,
            self.gamma_quantizer_internal,
            self.gamma,
        )
        quantized_beta = self._maybe_quantize(
            self.center and self.beta_quantizer,
            self.beta_quantizer_internal,
            self.beta,
        )
        quantized_moving_mean = self._maybe_quantize(
            self.mean_quantizer, self.mean_quantizer_internal, self.moving_mean
        )
        quantized_moving_variance = self._maybe_quantize(
            self.variance_quantizer,
            self.variance_quantizer_internal,
            self.moving_variance,
        )

        # Compute the axes along which to reduce the mean / variance.
        ndims = len(inputs.shape)
        axes = self.axis if isinstance(self.axis, (list, tuple)) else [self.axis]
        # `self.axis` may be negative (the constructor default is -1). Without
        # normalization the membership test below would also reduce over the
        # channel axis, yielding scalar instead of per-channel statistics.
        axes = [a % ndims for a in axes]
        reduction_axes = [i for i in range(ndims) if i not in axes]

        # Per-channel vectors broadcast naturally when the channel axis is
        # last. Otherwise they must be reshaped so the channel dimension lines
        # up with axes[0] (e.g. (C,) -> (1, C, 1, 1) for NCHW).
        needs_broadcast = reduction_axes != list(range(ndims - 1))
        broadcast_shape = [1] * ndims
        broadcast_shape[axes[0]] = -1

        def _broadcast(v):
            if v is not None and len(v.shape) != ndims and needs_broadcast:
                return keras.ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = constant_bool_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            quantized_mean = quantized_moving_mean
            quantized_variance = quantized_moving_variance
        else:
            # Some of the computations here are not necessary when training is
            # False but not a constant. However, this keeps the code simpler.
            mean, variance = keras.ops.moments(
                keras.ops.cast(inputs, self.compute_dtype),
                axes=reduction_axes,
                keepdims=len(axes) > 1,
            )
            # If `training` turns out to be False at run time, use the moving
            # statistics; the update below then leaves them (effectively)
            # unchanged since m * k + m * (1 - k) == m.
            mean = keras.ops.where(
                training, mean, keras.ops.convert_to_tensor(self.moving_mean)
            )
            variance = keras.ops.where(
                training, variance, keras.ops.convert_to_tensor(self.moving_variance)
            )
            quantized_mean = self._maybe_quantize(
                self.mean_quantizer, self.mean_quantizer_internal, mean
            )
            quantized_variance = self._maybe_quantize(
                self.variance_quantizer, self.variance_quantizer_internal, variance
            )
            # The moving averages track the unquantized statistics.
            self._update_moving_average(self.moving_mean, mean)
            self._update_moving_average(self.moving_variance, variance)

        dtype = inputs.dtype
        quantized_mean = _broadcast(keras.ops.cast(quantized_mean, dtype))
        quantized_variance = _broadcast(keras.ops.cast(quantized_variance, dtype))
        if offset is not None:
            offset = keras.ops.cast(offset, dtype)
        if scale is not None:
            scale = keras.ops.cast(scale, dtype)

        # Fold gamma into 1 / sqrt(var + eps) so the inverse quantizer (if
        # any) sees the combined per-channel multiplier.
        inv = 1 / keras.ops.sqrt(quantized_variance + self.epsilon)
        if scale is not None:
            inv = inv * scale
        if self.inverse_quantizer_internal is not None:
            inv = self.inverse_quantizer_internal(inv)

        # Forward pass: x * inv + (beta - mean * inv).
        shift = (
            offset - quantized_mean * inv
            if offset is not None
            else -quantized_mean * inv
        )
        outputs = inputs * keras.ops.cast(inv, dtype) + keras.ops.cast(shift, dtype)
        # If some components of the shape got lost due to adjustments, fix that.
        return keras.ops.reshape(outputs, keras.ops.shape(inputs))

    def get_config(self):
        """Returns the base-layer config extended with the quantizer settings."""
        config = {
            "axis": self.axis,
            "momentum": self.momentum,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_quantizer": serialize_keras_object(self.beta_quantizer_internal),
            "gamma_quantizer": serialize_keras_object(self.gamma_quantizer_internal),
            "mean_quantizer": serialize_keras_object(self.mean_quantizer_internal),
            "variance_quantizer": serialize_keras_object(
                self.variance_quantizer_internal
            ),
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(self.gamma_initializer),
            "moving_mean_initializer": initializers.serialize(
                self.moving_mean_initializer
            ),
            "moving_variance_initializer": initializers.serialize(
                self.moving_variance_initializer
            ),
            "inverse_quantizer": serialize_keras_object(
                self.inverse_quantizer_internal
            ),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
            "beta_range": self.beta_range,
            "gamma_range": self.gamma_range,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Batch normalization preserves the input shape."""
        return input_shape

    def get_quantizers(self):
        """Returns [gamma, beta, mean, variance, inverse] quantizers (may be None)."""
        return self.quantizers

    def get_prunable_weights(self):
        """This layer exposes no prunable weights."""
        return []