IntQuant

Calculates the integer-quantized values of one input data tensor (X) and produces one output data tensor (Y). Additionally, it takes three inputs that define the scale, zero-point and bit-width of the quantization. The scale and zero-point may be scalars or tensors with the same number of dimensions as the input data tensor, e.g. for tensor-wise or channel-wise quantization. The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute rounding_mode defines how quantized values are rounded.
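The computation can be summarized as the formula below, read off the sample implementation at the end of this document. It is a summary of that reference code, not an independent normative definition. Here s is scale, z is zeropt, b is bitwidth, and round is the function selected by rounding_mode:

```latex
\begin{aligned}
y_{\mathrm{int}} &= \mathrm{round}\!\left(\mathrm{clamp}\!\left(\tfrac{x}{s} + z,\; q_{\min},\; q_{\max}\right)\right) \\
y &= \left(y_{\mathrm{int}} - z\right) \cdot s \\
q_{\min} &= \begin{cases} -2^{b-1} + \mathrm{narrow} & \text{if signed} = 1 \\ 0 & \text{if signed} = 0 \end{cases} \\
q_{\max} &= \begin{cases} 2^{b-1} - 1 & \text{if signed} = 1 \\ 2^{b} - 1 - \mathrm{narrow} & \text{if signed} = 0 \end{cases}
\end{aligned}
```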

Notes:

  • This operator was previously named Quant but is renamed to IntQuant to distinguish it from FloatQuant. For a transition period, qonnx will transparently handle Quant as IntQuant for backwards compatibility reasons, but only IntQuant should be used for new models.

  • This operator does not support binary or bipolar quantization; use the simpler BipolarQuant node for that purpose.
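Plugging bitwidth=1 and signed=1 into the range formulas used by the min_int and max_int helpers of the sample implementation shows why. The representable integers are only -1 and 0, so the +1 level needed for bipolar values does not exist. The snippet below is a small illustration, not part of the operator definition:

```python
# Range of a signed, non-narrow 1-bit IntQuant, using the same formulas
# as min_int/max_int in the sample implementation.
bit_width = 1
q_min = -(2 ** (bit_width - 1))     # -(2**0) = -1
q_max = (2 ** (bit_width - 1)) - 1  # 2**0 - 1 = 0
levels = list(range(q_min, q_max + 1))
print(levels)  # [-1, 0]: the bipolar level +1 is missing
```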

Version

The description of this operator in this document corresponds to qonnx.custom_ops.general opset version 1.

Attributes

signed : int (default is 1)
Defines whether the quantization includes a sign bit. E.g. at 8b: unsigned=[0, 255] vs signed=[-128, 127].
narrow : int (default is 0)
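Defines whether the value range is narrowed by one level. E.g. at 8b signed: regular=[-128, 127] vs narrow=[-127, 127]. In the sample implementation, narrow=1 also applies when signed=0, where it lowers the maximum instead, e.g. [0, 254] at 8b.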
rounding_mode : string (default is "ROUND")
Defines how rounding is applied during quantization. Available options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP and HALF_DOWN, described in the table below. Mode names may be given in upper or lower case.
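The effect of signed and narrow on the integer range can be sketched with a small helper. `int_range` is a hypothetical name used only for illustration; it mirrors the min_int/max_int functions of the sample implementation, including the unsigned narrow case that lowers the maximum:

```python
from typing import Tuple


def int_range(bit_width: int, signed: bool, narrow: bool) -> Tuple[int, int]:
    """Hypothetical helper: return (q_min, q_max) like min_int/max_int."""
    if signed:
        # narrow drops the most negative level, e.g. -128 at 8 bits
        q_min = -(2 ** (bit_width - 1)) + int(narrow)
        q_max = 2 ** (bit_width - 1) - 1
    else:
        # narrow drops the largest level, e.g. 255 at 8 bits
        q_min = 0
        q_max = 2 ** bit_width - 1 - int(narrow)
    return q_min, q_max


for signed in (True, False):
    for narrow in (False, True):
        print(f"signed={int(signed)} narrow={int(narrow)}:", int_range(8, signed, narrow))
```

At 8 bits this prints (-128, 127), (-127, 127), (0, 255) and (0, 254), matching the attribute examples above.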

Inputs

X (differentiable) : tensor(float32)
Input tensor to quantize
scale : float32, tensor(float32)
The scale factor, either as a global scalar or as a tensor with the same number of dimensions as X
zeropt : float32, tensor(float32)
The zero-point, either as a global scalar or as a tensor with the same number of dimensions as X
bitwidth : int32, float32
The number of bits used by the quantization; must be a positive integer. If the float32 dtype is used for convenience, it must still represent a positive integer number of bits.
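The shape rules for scale and the constraint on bitwidth can be illustrated with NumPy broadcasting. This sketch assumes an NCHW layout where channel-wise quantization means one scale per entry of axis 1; `check_bitwidth` is a hypothetical validator, not part of qonnx:

```python
import numpy as np

# Hypothetical NCHW activation: batch 2, 3 channels, 4x4 spatial.
x = np.random.randn(2, 3, 4, 4).astype(np.float32)

# Tensor-wise quantization: a single scalar scale.
scale_tensor = np.float32(0.1)
# Channel-wise quantization: one scale per channel, with the same number
# of dimensions as x so that it broadcasts along axis 1.
scale_channel = np.array([0.1, 0.2, 0.4], dtype=np.float32).reshape(1, 3, 1, 1)
assert scale_channel.ndim == x.ndim
print((x / scale_channel).shape)  # (2, 3, 4, 4)


def check_bitwidth(bitwidth) -> int:
    """Hypothetical check: accept int32 or float32 values holding a positive integer."""
    b = float(np.asarray(bitwidth))
    if b <= 0 or not b.is_integer():
        raise ValueError(f"bitwidth must be a positive integer, got {b}")
    return int(b)


print(check_bitwidth(np.float32(4.0)))  # 4
```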

Outputs

Y (differentiable) : tensor(float32)
Quantized output tensor, expressed in the same floating-point scale as X

Rounding modes

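| Number \ ROUNDING_MODE | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
|------------------------|-----------------|------|-------|----|------|---------|-----------|
| 5.5                    | 6               | 6    | 5     | 6  | 5    | 6       | 5         |
| 2.5                    | 2               | 3    | 2     | 3  | 2    | 3       | 2         |
| 1.6                    | 2               | 2    | 1     | 2  | 1    | 2       | 2         |
| 1.1                    | 1               | 2    | 1     | 2  | 1    | 1       | 1         |
| 1.0                    | 1               | 1    | 1     | 1  | 1    | 1       | 1         |
| -1.0                   | -1              | -1   | -1    | -1 | -1   | -1      | -1        |
| -1.1                   | -1              | -1   | -2    | -2 | -1   | -1      | -1        |
| -1.6                   | -2              | -1   | -2    | -2 | -1   | -2      | -2        |
| -2.5                   | -2              | -2   | -3    | -3 | -2   | -3      | -2        |
| -5.5                   | -6              | -5   | -6    | -6 | -5   | -6      | -5        |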
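Each rounding mode can be expressed with standard NumPy functions. The mapping below is a sketch chosen for illustration (UP, HALF_UP and HALF_DOWN are built from np.sign, np.abs, np.ceil and np.floor); the script checks every column of the table above against it:

```python
import numpy as np

# Sketch: rounding modes as NumPy expressions (an assumption for illustration).
ROUNDING = {
    "ROUND": np.round,  # NumPy rounds exact halves to the nearest even integer
    "CEIL": np.ceil,
    "FLOOR": np.floor,
    "UP": lambda v: np.sign(v) * np.ceil(np.abs(v)),              # away from zero
    "DOWN": np.fix,                                               # toward zero
    "HALF_UP": lambda v: np.sign(v) * np.floor(np.abs(v) + 0.5),  # ties away from zero
    "HALF_DOWN": lambda v: np.sign(v) * np.ceil(np.abs(v) - 0.5), # ties toward zero
}


def resolve(mode: str):
    # Mode names may be given in upper or lower case.
    return ROUNDING[mode.upper()]


x = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
# Columns of the rounding table, in the same row order as x.
expected = {
    "ROUND":     [6, 2, 2, 1, 1, -1, -1, -2, -2, -6],
    "CEIL":      [6, 3, 2, 2, 1, -1, -1, -1, -2, -5],
    "FLOOR":     [5, 2, 1, 1, 1, -1, -2, -2, -3, -6],
    "UP":        [6, 3, 2, 2, 1, -1, -2, -2, -3, -6],
    "DOWN":      [5, 2, 1, 1, 1, -1, -1, -1, -2, -5],
    "HALF_UP":   [6, 3, 2, 1, 1, -1, -1, -2, -3, -6],
    "HALF_DOWN": [5, 2, 2, 1, 1, -1, -1, -2, -2, -5],
}
for mode, column in expected.items():
    # Look up via the lower-case name to exercise case-insensitivity.
    assert np.array_equal(resolve(mode.lower())(x), column), mode
print("all rounding modes match the table")
```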

Examples

IntQuant
from onnx import helper
import numpy as np

# Define node settings and input
x = np.random.randn(100).astype(np.float32)*10.
scale = np.array(1., dtype=np.float32)
zeropt = np.array(0., dtype=np.float32)
bitwidth = np.array(4, dtype=np.int32)
signed = 1
narrow = 0
rounding_mode = "ROUND"

# Create node
node = helper.make_node(
    'IntQuant',
    domain='finn.custom_op.general',
    inputs=['x', 'scale', 'zeropt', 'bitwidth'],
    outputs=['y'],
    narrow=narrow,
    signed=signed,
    rounding_mode=rounding_mode,
)

# Execute the same settings with the reference implementation (quant);
# see the sample implementation for details on quant.
output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode)

# Execute node and compare; `expect` is the node-test helper from ONNX's
# backend test suite (onnx.backend.test.case.node).
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_intquant')
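With the settings used in this example (scale 1, zero-point 0, 4 signed bits, narrow=0, ROUND), scaling and re-scaling are identities, so the expected output can be predicted with plain NumPy: clamp to the 4-bit signed range [-8, 7], then round halves to even. This is a sketch specialized to these settings only, not a general implementation:

```python
import numpy as np

# Same kind of input as the example above, seeded for reproducibility.
rng = np.random.default_rng(0)
x = rng.standard_normal(100).astype(np.float32) * 10.0

# scale=1 and zeropt=0: IntQuant reduces to clamping then rounding.
y = np.round(np.clip(x, -8, 7))

assert y.min() >= -8 and y.max() <= 7
print(np.unique(y))  # a subset of the 16 levels -8, -7, ..., 7
```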

Sample Implementation

IntQuant
# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode):
    # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
    # Scaling
    y_int = inp_tensor / scale
    y_int = y_int + zeropt
    # Clamping to the representable integer range; np.asarray accepts the
    # bounds as either Python ints or NumPy scalars
    min_int_val = min_int(signed, narrow, bitwidth)
    max_int_val = max_int(signed, narrow, bitwidth)
    y_int = np.where(y_int > max_int_val, np.asarray(max_int_val, dtype=y_int.dtype), y_int)
    y_int = np.where(y_int < min_int_val, np.asarray(min_int_val, dtype=y_int.dtype), y_int)
    # Rounding
    rounding_fx = resolve_rounding_mode(rounding_mode)
    y_int = rounding_fx(y_int)

    # Re-scaling
    out_tensor = y_int - zeropt
    out_tensor = out_tensor * scale

    return out_tensor

def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
    """Compute the minimum integer representable by a given number of bits.
    Args:
        signed (bool): Indicates whether the represented integer is signed or not.
        narrow_range (bool): Indicates whether to narrow the minimum value
        represented by 1.
        bit_width (int): Number of bits available for the representation.
    Returns:
        int: Minimum integer that can be represented according to
        the input arguments.
    Examples:
        >>> min_int(signed=True, narrow_range=True, bit_width=8)
        -127
        >>> min_int(signed=False, narrow_range=True, bit_width=8)
        0
        >>> min_int(signed=True, narrow_range=False, bit_width=8)
        -128
        >>> min_int(signed=False, narrow_range=False, bit_width=8)
        0
    """
    if signed and narrow_range:
        value = -(2 ** (bit_width - 1)) + 1
    elif signed and not narrow_range:
        value = -(2 ** (bit_width - 1))
    else:
        # multiplying by bit_width keeps the result in the same type as bit_width
        value = 0 * bit_width
    return value


def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
    """Compute the maximum integer representable by a given number of bits.
    Args:
        signed (bool): Indicates whether the represented integer is signed or not.
        narrow_range (bool): Indicates whether to narrow the maximum unsigned value
        represented by 1.
        bit_width (int): Number of bits available for the representation.
    Returns:
        int: Maximum integer that can be represented according to
        the input arguments.
    Examples:
        >>> max_int(signed=True, narrow_range=True, bit_width=8)
        127
        >>> max_int(signed=False, narrow_range=True, bit_width=8)
        254
        >>> max_int(signed=True, narrow_range=False, bit_width=8)
        127
        >>> max_int(signed=False, narrow_range=False, bit_width=8)
        255
    """
    if not signed and not narrow_range:
        value = (2 ** bit_width) - 1
    elif not signed and narrow_range:
        value = (2 ** bit_width) - 2
    else:
        value = (2 ** (bit_width - 1)) - 1
    return value

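def resolve_rounding_mode(mode_string):
    """Resolve the rounding mode string of IntQuant and Trunc ops
    to the corresponding numpy functions. Mode names are case-insensitive."""
    mode = mode_string.upper()
    if mode == "ROUND":
        # np.round rounds exact halves to the nearest even integer
        return np.round
    elif mode == "CEIL":
        return np.ceil
    elif mode == "FLOOR":
        return np.floor
    elif mode == "UP":
        # round away from zero
        return lambda x: np.sign(x) * np.ceil(np.abs(x))
    elif mode == "DOWN":
        # round toward zero
        return np.fix
    elif mode == "HALF_UP":
        # nearest integer, ties away from zero
        return lambda x: np.sign(x) * np.floor(np.abs(x) + 0.5)
    elif mode == "HALF_DOWN":
        # nearest integer, ties toward zero
        return lambda x: np.sign(x) * np.ceil(np.abs(x) - 0.5)
    else:
        raise ValueError(f"Could not resolve rounding mode called: {mode_string}")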