@@ -1577,25 +1577,277 @@ class QuantizedOr(QuantizedOpUnivariateOfEncrypted, QuantizedOp):
15771577 _impl_for_op_named : str = "Or"
15781578
15791579
1580- class QuantizedDiv (QuantizedOpUnivariateOfEncrypted , QuantizedOp ):
1581- """Div operator / .
1580+ class QuantizedDiv (QuantizedMixingOp ):
1581+ """Quantized Division operator .
15821582
1583- This operation is not really working as a quantized operation. It just works when things got
1584- fused, as in e.g., Act(x) = 1000 / (x + 42))
1583+ Can divide either two variables (both encrypted) or a variable and a constant
15851584 """
15861585
15871586 _impl_for_op_named : str = "Div"
15881587
1588+ def __init__ (
1589+ self ,
1590+ * args ,
1591+ rounding_threshold_bits : Union [None , int , Dict [str , Union [str , int ]]] = None ,
1592+ ** kwargs ,
1593+ ) -> None :
1594+ super ().__init__ (* args , rounding_threshold_bits = rounding_threshold_bits , ** kwargs )
1595+ self .divider_quantizer : Optional [UniformQuantizer ] = None
1596+ self .min_non_zero_value : Optional [numpy .float64 ] = None
1597+
1598+ def calibrate (self , * inputs : numpy .ndarray ) -> numpy .ndarray :
1599+ """Create corresponding QuantizedArray for the output of the activation function.
1600+
1601+ Args:
1602+ *inputs (numpy.ndarray): Calibration sample inputs.
1603+
1604+ Returns:
1605+ numpy.ndarray: the output values for the provided calibration samples.
1606+ """
1607+
1608+ # If the op can not be fused and that the two inputs are not constants
1609+ # we need to compute the quantizer of the divider since we are doing
1610+ # an encrypted division where both numerator and denominator are encrypted
1611+ if not self .can_fuse () and len (inputs ) == 2 :
1612+
1613+ # FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4556
1614+ min_non_zero_index = numpy .abs (inputs [1 ]).argmin (axis = None )
1615+ min_non_zero_value = inputs [1 ].flat [min_non_zero_index ]
1616+
1617+ # mypy
1618+ assert min_non_zero_value is not None and min_non_zero_value != 0
1619+ self .min_non_zero_value = min_non_zero_value
1620+
1621+ q_array_divider = QuantizedArray (self .n_bits , 1 / inputs [1 ])
1622+
1623+ # Store the quantizer of the divider
1624+ self .divider_quantizer = q_array_divider .quantizer
1625+
1626+ return super ().calibrate (* inputs )
1627+
1628+ def q_impl (
1629+ self ,
1630+ * q_inputs : ONNXOpInputOutputType ,
1631+ calibrate_rounding : bool = False ,
1632+ ** attrs ,
1633+ ) -> ONNXOpInputOutputType :
1634+
1635+ # If the op can be fused we perform the op in the clear
1636+ if self .can_fuse ():
1637+ return super ().q_impl (* q_inputs , ** attrs )
1638+
1639+ # For mypy
1640+ assert self .output_quant_params is not None
1641+ assert self .output_quant_params .scale is not None
1642+ assert self .output_quant_params .zero_point is not None
1643+
1644+ prepared_inputs = self ._prepare_inputs_with_constants (
1645+ * q_inputs , calibrate = False , quantize_actual_values = True
1646+ )
1647+
1648+ q_input_0 : QuantizedArray = prepared_inputs [0 ]
1649+ q_input_1 : QuantizedArray = prepared_inputs [1 ]
15891650
1590- class QuantizedMul (QuantizedOpUnivariateOfEncrypted , QuantizedOp ):
1591- """Multiplication operator.
1651+ assert q_input_0 .quantizer .scale is not None
1652+ assert q_input_0 .quantizer .zero_point is not None
1653+
1654+ assert q_input_1 .quantizer .scale is not None
1655+ assert q_input_1 .quantizer .zero_point is not None
1656+
1657+ # Dequantize
1658+ input_1 = q_input_1 .dequant ()
1659+
1660+ # Replace input_1 with min_non_zero_value if input_1 is 0
1661+ # mypy
1662+ assert self .min_non_zero_value is not None
1663+ input_1 = numpy .where (input_1 == 0 , self .min_non_zero_value , input_1 )
1664+
1665+ # Compute the inverse of input_1
1666+ input_1_inv = 1.0 / input_1
1667+
1668+ # Re-quantize the inverse using the same quantization parameters as q_input_1
1669+ # mypy
1670+ assert self .divider_quantizer is not None
1671+ # FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4556
1672+ q_input_1_inv_rescaled = self .divider_quantizer .quant (input_1_inv )
1673+
1674+ # The product of quantized encrypted integer values
1675+ product_q_values = q_input_0 .qvalues * q_input_1_inv_rescaled
1676+
1677+ # mypy
1678+ assert q_input_0 .quantizer .zero_point is not None
1679+ assert q_input_1 .quantizer .zero_point is not None
1680+ assert self .divider_quantizer .zero_point is not None
1681+
1682+ # Integer quantized multiplication need adjustment based on the zero points.
1683+ if q_input_0 .quantizer .zero_point :
1684+ product_q_values -= q_input_0 .quantizer .zero_point * (
1685+ q_input_1_inv_rescaled - self .divider_quantizer .zero_point
1686+ )
1687+ if self .divider_quantizer .zero_point :
1688+ product_q_values -= self .divider_quantizer .zero_point * (
1689+ q_input_0 .qvalues - q_input_0 .quantizer .zero_point
1690+ )
1691+
1692+ # mypy
1693+ assert self .divider_quantizer .scale is not None
1694+ assert self .divider_quantizer .zero_point is not None
1695+
1696+ # Compute the scale and zero point based on the scale and zero point
1697+ # of the two quantized values multiplied together
1698+ new_scale = q_input_0 .quantizer .scale * self .divider_quantizer .scale
1699+ new_zero_point = q_input_0 .quantizer .zero_point * self .divider_quantizer .zero_point
1700+
1701+ if self .produces_graph_output :
1702+ return self .make_output_quant_parameters (product_q_values , new_scale , new_zero_point )
1703+
1704+ with tag (self .op_instance_name + ".rounding" ):
1705+ # Apply Concrete rounding (if relevant)
1706+ product_q_values = self .cnp_round (product_q_values , calibrate_rounding )
1707+
1708+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4546
1709+ # De-quantize the product
1710+ dequant_product = (product_q_values - new_zero_point ) * new_scale
1711+
1712+ # Return the raw float values without re-quantizing them to the new scale, as any
1713+ # following Gemm/Add/Conv will quantize them with _prepare_inputs_with_constants(...)
1714+ return QuantizedArray (
1715+ self .n_bits ,
1716+ dequant_product ,
1717+ value_is_float = True ,
1718+ options = self ._get_output_quant_opts (),
1719+ stats = self .output_quant_stats ,
1720+ params = self .output_quant_params ,
1721+ )
1722+
1723+ def can_fuse (self ) -> bool :
1724+ """Determine if this op can be fused.
15921725
1593- Only multiplies an encrypted tensor with a float constant for now. This operation will
1594- be fused to a (potentially larger) TLU.
1726+ Div operation can be computed in float and fused if it operates over inputs produced
1727+ by a single integer tensor.
1728+
1729+ Returns:
1730+ bool: Whether the number of integer input tensors allows computing this op as a TLU
1731+ """
1732+
1733+ return len (self ._int_input_names ) == 1
1734+
1735+
1736+ class QuantizedMul (QuantizedMixingOp ):
1737+ """Quantized Multiplication operator.
1738+
1739+ Can multiply either two variables (both encrypted) or a variable and a constant
15951740 """
15961741
15971742 _impl_for_op_named : str = "Mul"
15981743
1744+ def q_impl (
1745+ self ,
1746+ * q_inputs : ONNXOpInputOutputType ,
1747+ calibrate_rounding : bool = False ,
1748+ ** attrs ,
1749+ ) -> ONNXOpInputOutputType :
1750+
1751+ # If either input is a RawOpOutput or if the op can be fused,
1752+ # perform the op in the TLU using FP32
1753+ if (
1754+ len (q_inputs ) == 1
1755+ or isinstance (q_inputs [0 ], RawOpOutput )
1756+ or isinstance (q_inputs [1 ], RawOpOutput )
1757+ or self .can_fuse ()
1758+ ):
1759+ return super ().q_impl (* q_inputs , ** attrs )
1760+
1761+ # For mypy
1762+ assert self .output_quant_params is not None
1763+ assert self .output_quant_params .scale is not None
1764+ assert self .output_quant_params .zero_point is not None
1765+
1766+ prepared_inputs = self ._prepare_inputs_with_constants (
1767+ * q_inputs , calibrate = False , quantize_actual_values = True
1768+ )
1769+
1770+ q_input_0 : QuantizedArray = prepared_inputs [0 ]
1771+ q_input_1 : QuantizedArray = prepared_inputs [1 ]
1772+
1773+ assert q_input_0 .quantizer .scale is not None
1774+ assert q_input_0 .quantizer .zero_point is not None
1775+
1776+ assert q_input_1 .quantizer .scale is not None
1777+ assert q_input_1 .quantizer .zero_point is not None
1778+
1779+ # Remove the manual encrypted multiplication when we
1780+ # can handle input precision with rounding
1781+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
1782+ @univariate
1783+ def copy_function (x ):
1784+ return x
1785+
1786+ input0_q_values = q_input_0 .qvalues
1787+ input1_q_values = q_input_1 .qvalues
1788+
1789+ with tag (self .op_instance_name + ".enc_mul_rounding_input0" ):
1790+ input0_q_values = self .cnp_round (
1791+ input0_q_values , calibrate_rounding , rounding_operation_id = "input0_q_values_copy"
1792+ )
1793+
1794+ with tag (self .op_instance_name + ".enc_mul_rounding_input1" ):
1795+ input1_q_values = self .cnp_round (
1796+ input1_q_values , calibrate_rounding , rounding_operation_id = "input1_q_values_copy"
1797+ )
1798+
1799+ input0_q_values = copy_function (input0_q_values )
1800+ input1_q_values = copy_function (input1_q_values )
1801+
1802+ # The product of quantized encrypted integer values
1803+ product_q_values = input0_q_values * input1_q_values
1804+
1805+ # Integer quantized multiplication need adjustment based on the zero points.
1806+ if q_input_0 .quantizer .zero_point :
1807+ product_q_values -= q_input_0 .quantizer .zero_point * input1_q_values
1808+ if q_input_1 .quantizer .zero_point :
1809+ product_q_values -= q_input_1 .quantizer .zero_point * input0_q_values
1810+
1811+ # Compute the scale and zero point based on the scale and zero point
1812+ # of the two quantized values multiplied together
1813+ new_scale = q_input_0 .quantizer .scale * q_input_1 .quantizer .scale
1814+ new_zero_point = q_input_0 .quantizer .zero_point * q_input_1 .quantizer .zero_point
1815+
1816+ if self .produces_graph_output :
1817+ return self .make_output_quant_parameters (product_q_values , new_scale , new_zero_point )
1818+
1819+ with tag (self .op_instance_name + ".rounding" ):
1820+ # Apply Concrete rounding (if relevant)
1821+ product_q_values = self .cnp_round (
1822+ product_q_values , calibrate_rounding , rounding_operation_id = "product_q_values"
1823+ )
1824+
1825+ # De-quantize the product
1826+ dequant_product = (product_q_values + new_zero_point ) * new_scale
1827+
1828+ # Return the raw float values without re-quantizing them to the new scale, as any
1829+ # following Gemm/Add/Conv will quantize them with _prepare_inputs_with_constants(...)
1830+ return QuantizedArray (
1831+ self .n_bits ,
1832+ dequant_product ,
1833+ value_is_float = True ,
1834+ options = self ._get_output_quant_opts (),
1835+ stats = self .output_quant_stats ,
1836+ params = self .output_quant_params ,
1837+ )
1838+
1839+ def can_fuse (self ) -> bool :
1840+ """Determine if this op can be fused.
1841+
1842+ Mul operation can be computed in float and fused if it operates over inputs produced
1843+ by a single integer tensor.
1844+
1845+ Returns:
1846+ bool: Whether the number of integer input tensors allows computing this op as a TLU
1847+ """
1848+
1849+ return len (self ._int_input_names ) == 1
1850+
15991851
16001852class QuantizedSub (QuantizedAdd ):
16011853 """Subtraction operator.
0 commit comments