# Source code for the qkeras.callbacks module.

# Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import json
import os
import threading
import time

import keras
import keras.ops.numpy as knp

class _WriterState(threading.local):
    """Per-thread slot holding the default writer used by ``scalar()``.

    ``threading.local`` re-runs ``__init__`` in every thread that first touches
    the instance. Every thread therefore starts with ``writer = None``, not
    only the thread that imported this module.
    """

    def __init__(self):
        super().__init__()
        self.writer = None


# Module-level handle; `as_default()` swaps `.writer` in and out per thread.
_current = _WriterState()

class _JSONLWriter:
    """Minimal scalar-event writer that stores one JSON object per line.

    Each record has the keys "wall_time", "step", "tag" and "value". The file
    is ``<log_dir>/events_<unix_seconds>.jsonl``. If that name already exists
    (e.g. two writers created within the same second), a numeric suffix is
    added (``events_<unix_seconds>_<n>.jsonl``) so writers never share a file.
    """

    def __init__(self, log_dir: str):
        os.makedirs(log_dir, exist_ok=True)
        stamp = int(time.time())
        attempt = 0
        while True:
            suffix = f"_{attempt}" if attempt else ""
            self.path = os.path.join(log_dir, f"events_{stamp}{suffix}.jsonl")
            try:
                # "x" = exclusive creation: fail instead of silently sharing an
                # existing file. Line-buffered so each record is written at once.
                self._fh = open(self.path, "x", buffering=1, encoding="utf-8")
            except FileExistsError:
                attempt += 1
                continue
            break

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    @contextlib.contextmanager
    def as_default(self):
        """Make this writer the target of ``scalar()`` in the current thread.

        The previously active writer (or None) is restored on exit, even if
        the ``with`` body raises.
        """
        prev = getattr(_current, "writer", None)
        _current.writer = self
        try:
            yield self
        finally:
            _current.writer = prev

    def _write_scalar(self, tag: str, value, step: int):
        """Append one record; ``value`` -> float, ``step`` -> int, ``tag`` -> str."""
        rec = {
            "wall_time": time.time(),
            "step": int(step),
            "tag": str(tag),
            "value": float(value),
        }
        self._fh.write(json.dumps(rec) + "\n")

    def close(self):
        """Close the file. Safe to call more than once."""
        # Best effort: a failing final flush/close must not crash shutdown code,
        # but unrelated programming errors are no longer hidden.
        with contextlib.suppress(OSError):
            self._fh.close()

def create_file_writer(log_dir: str):
    """Return a new JSONL scalar writer that logs under ``log_dir``.

    The returned object provides ``as_default()``, a context manager that
    makes it the writer used by ``scalar()`` for the duration of the block.
    """
    writer = _JSONLWriter(log_dir)
    return writer
def scalar(name: str, data, step: int):
    """Record one scalar value through the currently active default writer.

    Args:
      name: Tag stored alongside the value.
      data: The value to record.
      step: The step (e.g. epoch) stored alongside the value.

    Raises:
      RuntimeError: If no writer has been made default in this thread via
        ``with writer.as_default():``.
    """
    active = getattr(_current, "writer", None)
    if active is not None:
        active._write_scalar(name, data, step)
        return
    raise RuntimeError("No default writer set. Use `with writer.as_default():`.")
class QNoiseScheduler(keras.callbacks.Callback):
    """Schedules gradual quantization-noise training per step (or epoch).

    It updates the ``qnoise_factor`` of the model's quantizers so that
    quantization noise is introduced gradually during training. The idea is
    adapted from "https://arxiv.org/pdf/1903.01061.pdf".
    """

    def __init__(
        self,
        start,
        finish,
        freq_type="epoch",
        update_freq=1,
        initial_step_or_epoch=0,
        exponent=3.0,
        use_ste=True,
        log_dir=None,
    ):
        """Initializes this QNoiseScheduler.

        Args:
          start: Int. The step (epoch) at which the gradual training starts;
            before it the qnoise_factor is 0.0.
          finish: Int. The step (epoch) at which the gradual training ends;
            from it on the qnoise_factor is 1.0. When start and finish are
            equal, the qnoise_factor jumps from 0.0 to 1.0 at that step
            (epoch).
          freq_type: Str. "step" or "epoch". Whether the qnoise_factor is
            updated at the beginning of every training batch or every epoch.
          update_freq: Int >= 1. The qnoise_factor is recomputed only when the
            current step (epoch) is a multiple of this value.
          initial_step_or_epoch: Int. Step or epoch at which training starts
            (offset added to the internal iteration counter).
          exponent: Float. Exponent in the qnoise_factor calculation; it
            controls the rate of the gradual qnoise_factor change.
          use_ste: Bool. Whether the quantizers use the "straight-through
            estimator" (STE) method.
          log_dir: Str or None. If given, the qnoise_factor is written to a
            JSONL event file in this directory at every epoch end.

        Raises:
          ValueError: If start > finish, freq_type is unsupported, or
            update_freq <= 0.
        """
        super().__init__()
        self.start = start
        self.finish = finish
        if start > finish:
            raise ValueError(
                f"start {start} must be less than or equal to finish {finish}"
            )
        supported_freq_type = ["step", "epoch"]
        if freq_type not in supported_freq_type:
            raise ValueError(
                f"Invalid frequency type {freq_type}. only {supported_freq_type} are "
                "supported."
            )
        # update_freq is used as a modulus in update_qnoise_factor; 0 would
        # otherwise raise ZeroDivisionError in the middle of training.
        if update_freq <= 0:
            raise ValueError(f"update_freq must be >= 1, got {update_freq}")
        self.freq_type = freq_type
        self.update_freq = update_freq
        self.initial_step_or_epoch = initial_step_or_epoch
        self.exponent = exponent
        self.qnoise_factor = None
        self.use_ste = use_ste
        self.quantizers = None
        self.summary_writer = create_file_writer(log_dir) if log_dir else None
        # Plain Python int: the counter only drives host-side scheduling math,
        # so a backend tensor would just add per-batch overhead.
        self.num_iters = 0

    def calculate_qnoise_factor(self, freq):
        """Returns the qnoise_factor for the given step (epoch).

        Args:
          freq: The current step (or epoch).

        Returns:
          Float qnoise_factor: 0.0 before `start`, 1.0 from `finish` on (or
          whenever start == finish and freq >= start), and
          ``1 - ((finish - freq) / (finish - start)) ** exponent`` in between.
        """
        if freq < self.start:
            return 0.0
        if freq <= self.finish and self.start != self.finish:
            # Fraction of the schedule still remaining, in [0, 1].
            remaining = float(self.finish - freq) / float(self.finish - self.start)
            return 1.0 - remaining ** self.exponent
        return 1.0

    def set_qnoise_factor(self, quantizer, qnoise_factor):
        """Set self.qnoise_factor and update the qnoise_factor of the quantizer."""
        # Updating the qnoise_factor of the quantizer.
        quantizer.update_qnoise_factor(qnoise_factor)
        # Updating the qnoise_factor of the callback.
        self.qnoise_factor = qnoise_factor

    def set_quantizers(self):
        """Configure every quantizer in self.quantizers for scheduled training.

        Sets `use_ste` and `use_variables` where the quantizer has them,
        rebuilds already-built quantizers whose qnoise_factor is not a
        keras.Variable, and resets each qnoise_factor to 0.0.
        This must be called before building the quantizers.
        """
        for quantizer in self.quantizers:
            if hasattr(quantizer, "use_ste"):
                quantizer.use_ste = self.use_ste
            if hasattr(quantizer, "use_variables"):
                quantizer.use_variables = True
            if hasattr(quantizer, "built"):
                # If the quantizer has been built but not using keras.Variable
                # then it builds again to create keras.Variables.
                if quantizer.built and not isinstance(
                    quantizer.qnoise_factor, keras.Variable
                ):
                    quantizer.build(use_variables=True)
            # Set the qnoise_factor to 0.0 to pretrain without quantization.
            self.set_qnoise_factor(quantizer, qnoise_factor=0.0)

    def get_quantizers(self, model):
        """Returns a list of quantizers with qnoise_factor in the model.

        Args:
          model: model to get a list of quantizers with qnoise_factor.

        Returns:
          A list of quantizers (found in each layer's `quantizers` list or
          single `quantizer` attribute) that have a `qnoise_factor` attribute.
        """
        all_quantizers = []
        for layer in model.layers:
            # A list of attributes holding the quantizer(s).
            for attr in ["quantizers", "quantizer"]:
                if hasattr(layer, attr):
                    quantizers = getattr(layer, attr)
                    quantizers = quantizers if attr == "quantizers" else [quantizers]
                    for quantizer in quantizers:
                        if hasattr(quantizer, "qnoise_factor"):
                            all_quantizers.append(quantizer)
        return all_quantizers

    def update_qnoise_factor(self, freq):
        """Update the qnoise_factor of the model.

        The qnoise_factor is only recomputed when `freq` is a multiple of
        `update_freq`; the iteration counter advances on every call.

        Args:
          freq: The current step (epoch) to calculate the qnoise_factor.
        """
        if freq % self.update_freq == 0:
            new_qnoise_factor = self.calculate_qnoise_factor(freq)
            for quantizer in self.quantizers:
                # Updates the qnoise factors of the quantizers in the model.
                self.set_qnoise_factor(quantizer, new_qnoise_factor)
        self.num_iters += 1

    def on_train_begin(self, logs=None):
        if not self.quantizers:
            # Build a list of quantizers which is used for updating qnoise_factor.
            self.quantizers = self.get_quantizers(self.model)
            self.set_quantizers()

    def on_epoch_begin(self, epoch, logs=None):
        if self.freq_type == "epoch":
            self.update_qnoise_factor(self.initial_step_or_epoch + self.num_iters)

    def on_epoch_end(self, epoch, logs=None):
        if self.summary_writer is not None and self.qnoise_factor is not None:
            with self.summary_writer.as_default():
                scalar("qnoise_factor", data=self.qnoise_factor, step=epoch)

    def on_train_batch_begin(self, batch, logs=None):
        if self.freq_type == "step":
            self.update_qnoise_factor(self.initial_step_or_epoch + self.num_iters)