# Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import json
import os
import threading
import time
import keras
import keras.ops.numpy as knp
# Thread-local holder for the "default" writer that `scalar()` logs to.
# Each thread sees its own `writer` attribute.
_current = threading.local()
# NOTE: this assignment only initializes `writer` on the thread that imports
# the module; on any other thread the attribute is absent, which is why the
# readers in this file use `getattr(_current, "writer", None)`.
_current.writer = None
class _JSONLWriter:
    """Append-only scalar event writer that stores one JSON object per line.

    Each writer owns a single file named `events_<unix-seconds>.jsonl` inside
    `log_dir`. The file is opened in append mode when the writer is created
    and stays open until `close()` is called.
    """

    def __init__(self, log_dir: str):
        """Create `log_dir` if needed and open the event file for appending.

        Args:
          log_dir: Directory that receives the event file.
        """
        os.makedirs(log_dir, exist_ok=True)
        fname = f"events_{int(time.time())}.jsonl"
        self.path = os.path.join(log_dir, fname)
        # Line-buffered (buffering=1) so each record is flushed as soon as its
        # trailing newline is written.
        self._fh = open(self.path, "a", buffering=1, encoding="utf-8")

    @contextlib.contextmanager
    def as_default(self):
        """Make this writer the current thread's default for the `with` body.

        The previously active writer (if any) is restored on exit, even when
        the body raises, so the contexts can be nested.

        Yields:
          This writer.
        """
        prev = getattr(_current, "writer", None)
        _current.writer = self
        try:
            yield self
        finally:
            _current.writer = prev

    def _write_scalar(self, tag: str, value, step: int):
        """Append one scalar record to the event file.

        Args:
          tag: Name of the scalar; stored as `str(tag)`.
          value: Scalar value; stored as `float(value)`.
          step: Step (or epoch) index; stored as `int(step)`.
        """
        rec = {
            "wall_time": time.time(),
            "step": int(step),
            "tag": str(tag),
            "value": float(value),
        }
        self._fh.write(json.dumps(rec) + "\n")

    def close(self):
        """Close the event file; safe to call more than once.

        Closing is best-effort: an I/O failure while flushing/closing is
        ignored, but unrelated errors are no longer silently swallowed.
        """
        try:
            self._fh.close()
        except OSError:
            # Nothing useful can be done if the final flush fails.
            pass
def create_file_writer(log_dir: str):
    """Build a JSONL event writer that logs into `log_dir`.

    Args:
      log_dir: Directory in which the writer creates its event file.

    Returns:
      A writer object exposing an `.as_default()` context manager.
    """
    writer = _JSONLWriter(log_dir)
    return writer
def scalar(name: str, data, step: int):
    """Record a scalar value through the currently active default writer.

    Args:
      name: Tag under which the value is logged.
      data: Scalar value to log.
      step: Step (or epoch) index associated with the value.

    Raises:
      RuntimeError: If no writer has been made the default for the calling
        thread via `writer.as_default()`.
    """
    active = getattr(_current, "writer", None)
    if active is not None:
        active._write_scalar(name, data, step)
        return
    raise RuntimeError("No default writer set. Use `with writer.as_default():`.")
class QNoiseScheduler(keras.callbacks.Callback):
    """Schedules the gradual quantization noise training for each step (or epoch).

    It updates the qnoise_factor in the quantizers to gradually introduce the
    quantization noise during training.

    The idea was adopted from "https://arxiv.org/pdf/1903.01061.pdf"
    """

    def __init__(
        self,
        start,
        finish,
        freq_type="epoch",
        update_freq=1,
        initial_step_or_epoch=0,
        exponent=3.0,
        use_ste=True,
        log_dir=None,
    ):
        """Initializes this QNoiseScheduler.

        Args:
          start: Int. The step (epoch) to start the gradual training.
          finish: Int. The step (epoch) to finish the gradual training. When
            the start and the finish are equal, the qnoise_factor will be 1.0
            in the beginning of the training.
          freq_type: Str. "step" or "epoch". It sets the qnoise_factor update
            frequency type.
          update_freq: Int. Updating frequency of the qnoise_factor.
          initial_step_or_epoch: Int. Step or epoch at which to start training.
          exponent: Float. It is the exponent in the qnoise_factor calculation.
            It controls the rate of the gradual qnoise_factor change.
          use_ste: Bool. Whether to use "straight-through estimator" (STE)
            method or not.
          log_dir: Str. log directory to save qnoise_factor every epoch end.
            Nothing is logged when it is None or empty.

        Raises:
          ValueError: If `start` is greater than `finish`, or if `freq_type`
            is neither "step" nor "epoch".
        """
        super().__init__()

        # Validate before storing anything: the schedule ramps from `start`
        # up to `finish`, so an inverted range is rejected.
        if start > finish:
            raise ValueError(
                f"start {start} must be less than or equal to finish {finish}"
            )
        supported_freq_type = ["step", "epoch"]
        if freq_type not in supported_freq_type:
            raise ValueError(
                f"Invalid frequency type {freq_type}. only {supported_freq_type} are "
                "supported."
            )

        self.start = start
        self.finish = finish
        self.freq_type = freq_type
        self.update_freq = update_freq
        self.initial_step_or_epoch = initial_step_or_epoch
        self.exponent = exponent
        # Last factor pushed to the quantizers; None until the first update.
        self.qnoise_factor = None
        self.use_ste = use_ste
        # Filled in lazily by on_train_begin() unless set by the caller.
        self.quantizers = None
        self.summary_writer = create_file_writer(log_dir) if log_dir else None
        # Number of update_qnoise_factor() calls made so far; added to
        # initial_step_or_epoch to obtain the current step (epoch).
        self.num_iters = knp.array(0, dtype="int64")

    def calculate_qnoise_factor(self, freq):
        """Returns calculated qnoise_factor based on the current step (epoch)
        and the schedule parameters.

        The factor is 0.0 before `start`, then follows
        `1 - ((finish - freq) / (finish - start)) ** exponent` while
        `start <= freq <= finish`, and is 1.0 afterwards (or from `start` on
        when `start == finish`).

        Args:
          freq: The current step (or epoch) to calculate the qnoise_factor.

        Returns:
          qnoise_factor : calculated qnoise_factor. The literal floats 0.0 and
          1.0 outside the ramp; inside the ramp it is whatever
          `1.0 - knp.power(...)` yields.
        """
        if freq < self.start:
            return 0.0
        # `start != finish` guards the division below against a zero span.
        if freq <= self.finish and self.start != self.finish:
            remaining = float(self.finish - freq) / float(self.finish - self.start)
            return 1.0 - knp.power(remaining, self.exponent)
        return 1.0

    def set_qnoise_factor(self, quantizer, qnoise_factor):
        """Set self.qnoise_factor and update the qnoise_factor of the quantizer.

        Args:
          quantizer: Quantizer exposing `update_qnoise_factor()`.
          qnoise_factor: New factor to apply.
        """
        # Updating the qnoise_factor of the quantizer.
        quantizer.update_qnoise_factor(qnoise_factor)
        # Updating the qnoise_factor of the callback.
        self.qnoise_factor = qnoise_factor

    def set_quantizers(self):
        """Set quantizers to update the qnoise_factor.

        Intended to be called before building the quantizers. Quantizers that
        are already built without a keras.Variable qnoise_factor are built
        again with `use_variables=True`. Every quantizer's qnoise_factor is
        then reset to 0.0.
        """
        for quantizer in self.quantizers:
            if hasattr(quantizer, "use_ste"):
                quantizer.use_ste = self.use_ste
            if hasattr(quantizer, "use_variables"):
                quantizer.use_variables = True
            if hasattr(quantizer, "built"):
                # If the quantizer has been built but not using keras.Variable
                # then it builds again to create keras.Variables.
                if quantizer.built and not isinstance(
                    quantizer.qnoise_factor, keras.Variable
                ):
                    quantizer.build(use_variables=True)

            # Set the qnoise_factor to 0.0 to pretrain without quantization.
            self.set_qnoise_factor(quantizer, qnoise_factor=0.0)

    def get_quantizers(self, model):
        """Returns a list of quantizers with qnoise_factor in the model.

        Only the direct entries of `model.layers` are inspected, through
        their `quantizers` (iterable) and `quantizer` (single object)
        attributes.

        Args:
          model: model to get a list of quantizers with qnoise_factor.

        Returns:
          A list of quantizers with the qnoise_factor variable.
        """
        all_quantizers = []
        for layer in model.layers:
            # A list of attributes holding the quantizer(s).
            for attr in ["quantizers", "quantizer"]:
                if not hasattr(layer, attr):
                    continue
                held = getattr(layer, attr)
                # Normalize the single-quantizer attribute to a list.
                candidates = held if attr == "quantizers" else [held]
                for quantizer in candidates:
                    if hasattr(quantizer, "qnoise_factor"):
                        all_quantizers.append(quantizer)
        return all_quantizers

    def update_qnoise_factor(self, freq):
        """Update the qnoise_factor of the model.

        The factor is only recomputed when `freq` is a multiple of
        `self.update_freq`; the internal call counter advances on every call.

        Args:
          freq: The current step (epoch) to calculate the qnoise_factor.
        """
        # Update the qnoise_factor at the frequency of self.update_freq.
        if freq % self.update_freq == 0:
            new_qnoise_factor = self.calculate_qnoise_factor(freq)
            for quantizer in self.quantizers:
                # Updates the qnoise factors of the quantizers in the model.
                self.set_qnoise_factor(quantizer, new_qnoise_factor)
        self.num_iters += 1

    def on_train_begin(self, logs=None):
        """Collects the model's quantizers (if not set yet) and prepares them."""
        if not self.quantizers:
            # Build a list of quantizers which is used for updating
            # qnoise_factor.
            self.quantizers = self.get_quantizers(self.model)
        self.set_quantizers()

    def on_epoch_begin(self, epoch, logs=None):
        """Advances the schedule once per epoch when freq_type is "epoch"."""
        if self.freq_type == "epoch":
            # The schedule position comes from the callback's own counter,
            # not from the `epoch` argument.
            self.update_qnoise_factor(self.initial_step_or_epoch + self.num_iters)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the current qnoise_factor when a log directory was given."""
        if self.summary_writer and self.qnoise_factor is not None:
            with self.summary_writer.as_default():
                scalar("qnoise_factor", data=self.qnoise_factor, step=epoch)

    def on_train_batch_begin(self, batch, logs=None):
        """Advances the schedule once per batch when freq_type is "step"."""
        if self.freq_type == "step":
            # The schedule position comes from the callback's own counter,
            # not from the `batch` argument.
            self.update_qnoise_factor(self.initial_step_or_epoch + self.num_iters)