# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" "create divider quantizer."""
import abc
import copy
from absl import logging
from qkeras.qtools.quantized_operators import divider_impl, quantizer_impl
class UnacceptedQuantizerError(ValueError):
    """Raised when a quantizer combination has no divider implementation."""
class IDivider(abc.ABC):
    """Factory that selects a divider implementation for a pair of quantizers.

    The selection is a lookup in ``divider_impl_table``, indexed as
    ``divider_impl_table[numerator_quantizer.mode][denominator_quantizer.mode]``.
    """

    def __init__(self):
        """Builds the ``(divider implementation class, output quantizer)`` table.

        Each entry pairs the divider implementation class with a template of
        the output quantizer (the output datatype is attached in the table).
        ``(None, None)`` marks a mode combination that is not accepted.
        """

        def float_entry(bits):
            # bits=quantizer_impl.FLOATINGPOINT_BITS uses the default bits for
            # the float result; bits=None decides f16/f32 according to the
            # input quantizer.
            return (
                divider_impl.FloatingPointDivider,
                quantizer_impl.FloatingPoint(bits=bits),
            )

        def row(first, second):
            # Every row shares the same tail: denominator modes 2-4 are not
            # accepted, and denominator mode 5 always uses a floating-point
            # divider with bits=None. Only the first two columns vary per row.
            unsupported = (None, None)
            return [
                first,
                second,
                unsupported,
                unsupported,
                unsupported,
                float_entry(None),
            ]

        # When qbits is the denominator (column 0), rows 0-4 use the default
        # bits for the float result.
        default_float_bits = quantizer_impl.FLOATINGPOINT_BITS
        self.divider_impl_table = [
            # Numerator mode 0.
            row(
                float_entry(default_float_bits),
                (divider_impl.Shifter, quantizer_impl.QuantizedBits()),
            ),
            # Numerator mode 1.
            row(
                float_entry(default_float_bits),
                (divider_impl.Subtractor, quantizer_impl.PowerOfTwo()),
            ),
            # Numerator mode 2.
            row(
                float_entry(default_float_bits),
                (divider_impl.Shifter, quantizer_impl.QuantizedBits()),
            ),
            # Numerator mode 3.
            row(
                float_entry(default_float_bits),
                (divider_impl.Shifter, quantizer_impl.PowerOfTwo()),
            ),
            # Numerator mode 4.
            row(
                float_entry(default_float_bits),
                (divider_impl.Shifter, quantizer_impl.PowerOfTwo()),
            ),
            # Numerator mode 5: every accepted denominator yields a
            # floating-point divider whose bits are decided from the inputs.
            row(float_entry(None), float_entry(None)),
        ]

    def make_quantizer(
        self,
        numerator_quantizer: quantizer_impl.IQuantizer,
        denominator_quantizer: quantizer_impl.IQuantizer,
    ):
        """Creates the divider implementation for two quantizers.

        Args:
          numerator_quantizer: Quantizer of the numerator. Its ``mode`` selects
            the table row. It is not modified; a deep copy is passed on.
          denominator_quantizer: Quantizer of the denominator. Its ``mode``
            selects the table column. It is not modified; a deep copy is
            passed on.

        Returns:
          An instance of the selected divider implementation class. It is
          constructed from deep copies of both input quantizers and of the
          table's output quantizer template.

        Raises:
          UnacceptedQuantizerError: if a mode lies outside the table, or the
            mode combination is marked as not accepted.
        """
        mode1 = numerator_quantizer.mode
        mode2 = denominator_quantizer.mode

        # Bounds-check explicitly. A plain list lookup would raise a bare
        # IndexError, and a negative mode would silently wrap around to the
        # end of the table.
        if not 0 <= mode1 < len(self.divider_impl_table):
            raise UnacceptedQuantizerError(
                f"numerator quantizer {numerator_quantizer.name} not accepted!"
            )
        row = self.divider_impl_table[mode1]
        if not 0 <= mode2 < len(row):
            raise UnacceptedQuantizerError(
                f"denominator quantizer {denominator_quantizer.name} not accepted!"
            )

        divider_impl_class, output_quantizer = row[mode2]
        if divider_impl_class is None:
            raise UnacceptedQuantizerError(
                f"denominator quantizer {denominator_quantizer.name} not accepted!"
            )

        logging.debug(
            "divider implemented as class %s", divider_impl_class.implemented_as()
        )

        # Hand over local copies, so changes made downstream cannot alter the
        # caller's quantizers or the output quantizer template in the table.
        return divider_impl_class(
            copy.deepcopy(numerator_quantizer),
            copy.deepcopy(denominator_quantizer),
            copy.deepcopy(output_quantizer),
        )