Source code for qkeras.safe_eval

# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements a safe evaluation using globals()."""


import logging

import keras
from pyparsing import Group, Optional, Regex, Suppress, delimitedList


def Num(s):
    """Tries to convert string to either int or float."""
    try:
        try:
            return int(s)
        except ValueError:
            return float(s)
    except ValueError:
        # this should be always true. if it isn't int or float, it should be str
        assert (s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'")
        s = s[1:-1]
        return s


def Str(s):
    return s[1:-1]


def IsNum(s):
    try:
        try:
            int(s)
            return True
        except ValueError:
            float(s)
            return True
    except ValueError:
        return False


def IsBool(s):
    if s in ["True", "False"]:
        return True
    else:
        return False


def IsNone(s):
    return s == "None"


def Bool(s):
    return True if "True" in s else False


def ListofNums(s):
    # remove list brackets
    s = s.replace("[", "").replace("]", "")
    list_s = s.split(" ")
    return [Num(e) for e in list_s]


def IsListofNums(s):
    # remove list brackets
    s = s.replace("[", "").replace("]", "")
    list_s = s.split(" ")
    if len(list_s) > 1:
        for e in list_s:
            # if any of the elements is not a number return false
            if not IsNum(e):
                return False
        return True
    else:
        return False


def GetArg(s):
    if IsBool(s):
        return Bool(s)
    elif IsNum(s):
        return Num(s)
    elif IsNone(s):
        return None
    elif IsListofNums(s):
        return ListofNums(s)
    else:
        return Str(s)


def GetParams(s):
    """Extracts args and kwargs from string."""
    # modified from https://stackoverflow.com/questions/38799223/parse-string-to-identify-kwargs-and-args  # pylint: disable=line-too-long

    _lparen = Suppress("(")  # pylint: disable=invalid-name
    _rparen = Suppress(")")  # pylint: disable=invalid-name
    _eq = Suppress("=")  # pylint: disable=invalid-name

    data = (
        _lparen
        + Optional(
            delimitedList(Group(Regex(r"[^=,)\s]+") + Optional(_eq + Regex("[^,)]*"))))
        )
        + _rparen
    )

    items = data.parseString(s).asList()

    # need to make sure that kwargs only happen after args are processed
    args = [GetArg(i[0]) for i in items if len(i) == 1]
    kwargs = {i[0]: GetArg(i[1]) for i in items if len(i) == 2}

    # check for syntax error
    for i in range(1, len(items)):
        if (len(items[i]) == 1) and (len(items[i - 1]) == 2):
            raise SyntaxError(

                    "Error with item "
                    + str(i)
                    + " \n"
                    + "  parsing string "
                    + s
                    + "\n"
                    + "  Items: "
                    + str(items)
                    + "\n"
                    + "  Item["
                    + str(i - 1)
                    + "] :"
                    + str(items[i - 1])
                    + "\n"
                    + "  Item["
                    + str(i)
                    + "] :"
                    + str(items[i])

            )

    return args, kwargs


[docs] def safe_eval(eval_str, op_dict, *params, **kwparams): # pylint: disable=invalid-name """Replaces eval by a safe eval mechanism.""" function_split = eval_str.split("(") quantizer = op_dict.get(function_split[0], None) if len(function_split) == 2: args, kwargs = GetParams("(" + function_split[1]) else: args = [] kwargs = {} args = args + list(params) for k in kwparams: kwargs[k] = kwparams[k] # must be Keras activation object if None if quantizer is None: logging.info("keras dict %s", function_split[0]) quantizer = keras.activations.get(function_split[0]) if len(function_split) == 2 or args or kwargs: return quantizer(*args, **kwargs) elif isinstance(quantizer, type): # Check if quantizer is a class return quantizer() else: # Otherwise it is a function, so just return it return quantizer