8000 feat: support encrypted mul div (#690) · zama-ai/concrete-ml@a1bd9b8 · GitHub
[go: up one dir, main page]

Skip to content

Commit a1bd9b8

Browse files
authored
feat: support encrypted mul div (#690)
1 parent 77ced60 commit a1bd9b8

File tree

5 files changed

+332
-27
lines changed

5 files changed

+332
-27
lines changed

src/concrete/ml/pytest/torch_models.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,3 +1600,43 @@ def forward(self, x): # pylint: disable-next=no-self-use
16001600
Tuple[torch.Tensor. torch.Tensor]: Outputs of the network.
16011601
"""
16021602
return x, x.unsqueeze(0)
1603+
1604+
1605+
class TorchDivide(torch.nn.Module):
1606+
"""Torch model that performs a encrypted division between two inputs."""
1607+
1608+
def __init__(self, input_output, activation_function): # pylint: disable=unused-argument
1609+
super().__init__()
1610+
1611+
@staticmethod
1612+
def forward(x, y):
1613+
"""Forward pass.
1614+
1615+
Args:
1616+
x (torch.Tensor): The first input tensor.
1617+
y (torch.Tensor): The second input tensor.
1618+
1619+
Returns:
1620+
torch.Tensor: The result of the division.
1621+
"""
1622+
return x / y
1623+
1624+
1625+
class TorchMultiply(torch.nn.Module):
1626+
"""Torch model that performs a encrypted multiplication between two inputs."""
1627+
1628+
def __init__(self, input_output, activation_function): # pylint: disable=unused-argument
1629+
super().__init__()
1630+
1631+
@staticmethod
1632+
def forward(x, y):
1633+
"""Forward pass.
1634+
1635+
Args:
1636+
x (torch.Tensor): The first input tensor.
1637+
y (torch.Tensor): The second input tensor.
1638+
1639+
Returns:
1640+
torch.Tensor: The result of the multiplication.
1641+
"""
1642+
return x * y

src/concrete/ml/quantization/base_quantized_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ def dump_dict(self) -> Dict:
257257
if attribute_name not in metadata:
258258
metadata[attribute_name] = attribute_value
259259

260+
# Sort the metadata to be sure it is always saved in the same order
261+
metadata = dict(sorted(metadata.items()))
262+
260263
return metadata
261264

262265
@staticmethod

src/concrete/ml/quantization/quantized_ops.py

Lines changed: 260 additions & 8 deletions
+
calibrate_rounding: bool = False,
Original file line numberDiff line numberDiff line change
@@ -1577,25 +1577,277 @@ class QuantizedOr(QuantizedOpUnivariateOfEncrypted, QuantizedOp):
15771577
_impl_for_op_named: str = "Or"
15781578

15791579

1580-
class QuantizedDiv(QuantizedOpUnivariateOfEncrypted, QuantizedOp):
1581-
"""Div operator /.
1580+
class QuantizedDiv(QuantizedMixingOp):
1581+
"""Quantized Division operator.
15821582
1583-
This operation is not really working as a quantized operation. It just works when things got
1584-
fused, as in e.g., Act(x) = 1000 / (x + 42))
1583+
Can divide either two variables (both encrypted) or a variable and a constant
15851584
"""
15861585

15871586
_impl_for_op_named: str = "Div"
15881587

1588+
def __init__(
1589+
self,
1590+
*args,
1591+
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
1592+
**kwargs,
1593+
) -> None:
1594+
super().__init__(*args, rounding_threshold_bits=rounding_threshold_bits, **kwargs)
1595+
self.divider_quantizer: Optional[UniformQuantizer] = None
1596+
self.min_non_zero_value: Optional[numpy.float64] = None
1597+
1598+
def calibrate(self, *inputs: numpy.ndarray) -> numpy.ndarray:
1599+
"""Create corresponding QuantizedArray for the output of the activation function.
1600+
1601+
Args:
1602+
*inputs (numpy.ndarray): Calibration sample inputs.
1603+
1604+
Returns:
1605+
numpy.ndarray: the output values for the provided calibration samples.
1606+
"""
1607+
1608+
# If the op can not be fused and that the two inputs are not constants
1609+
# we need to compute the quantizer of the divider since we are doing
1610+
# an encrypted division where both numerator and denominator are encrypted
1611+
if not self.can_fuse() and len(inputs) == 2:
1612+
1613+
# FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4556
1614+
min_non_zero_index = numpy.abs(inputs[1]).argmin(axis=None)
1615+
min_non_zero_value = inputs[1].flat[min_non_zero_index]
1616+
1617+
# mypy
1618+
assert min_non_zero_value is not None and min_non_zero_value != 0
1619+
self.min_non_zero_value = min_non_zero_value
1620+
1621+
q_array_divider = QuantizedArray(self.n_bits, 1 / inputs[1])
1622+
1623+
# Store the quantizer of the divider
1624+
self.divider_quantizer = q_array_divider.quantizer
1625+
1626+
return super().calibrate(*inputs)
1627+
1628+
def q_impl(
1629+
self,
1630+
*q_inputs: ONNXOpInputOutputType,
1631
1632+
**attrs,
1633+
) -> ONNXOpInputOutputType:
1634+
1635+
# If the op can be fused we perform the op in the clear
1636+
if self.can_fuse():
1637+
return super().q_impl(*q_inputs, **attrs)
1638+
1639+
# For mypy
1640+
assert self.output_quant_params is not None
1641+
assert self.output_quant_params.scale is not None
1642+
assert self.output_quant_params.zero_point is not None
1643+
1644+
prepared_inputs = self._prepare_inputs_with_constants(
1645+
*q_inputs, calibrate=False, quantize_actual_values=True
1646+
)
1647+
1648+
q_input_0: QuantizedArray = prepared_inputs[0]
1649+
q_input_1: QuantizedArray = prepared_inputs[1]
15891650

1590-
class QuantizedMul(QuantizedOpUnivariateOfEncrypted, QuantizedOp):
1591-
"""Multiplication operator.
1651+
assert q_input_0.quantizer.scale is not None
1652+
assert q_input_0.quantizer.zero_point is not None
1653+
1654+
assert q_input_1.quantizer.scale is not None
1655+
assert q_input_1.quantizer.zero_point is not None
1656+
1657+
# Dequantize
1658+
input_1 = q_input_1.dequant()
1659+
1660+
# Replace input_1 with min_non_zero_value if input_1 is 0
1661+
# mypy
1662+
assert self.min_non_zero_value is not None
1663+
input_1 = numpy.where(input_1 == 0, self.min_non_zero_value, input_1)
1664+
1665+
# Compute the inverse of input_1
1666+
input_1_inv = 1.0 / input_1
1667+
1668+
# Re-quantize the inverse using the same quantization parameters as q_input_1
1669+
# mypy
1670+
assert self.divider_quantizer is not None
1671+
# FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4556
1672+
q_input_1_inv_rescaled = self.divider_quantizer.quant(input_1_inv)
1673+
1674+
# The product of quantized encrypted integer values
1675+
product_q_values = q_input_0.qvalues * q_input_1_inv_rescaled
1676+
1677+
# mypy
1678+
assert q_input_0.quantizer.zero_point is not None
1679+
assert q_input_1.quantizer.zero_point is not None
1680+
assert self.divider_quantizer.zero_point is not None
1681+
1682+
# Integer quantized multiplication need adjustment based on the zero points.
1683+
if q_input_0.quantizer.zero_point:
1684+
product_q_values -= q_input_0.quantizer.zero_point * (
1685+
q_input_1_inv_rescaled - self.divider_quantizer.zero_point
1686+
)
1687+
if self.divider_quantizer.zero_point:
1688+
product_q_values -= self.divider_quantizer.zero_point * (
1689+
q_input_0.qvalues - q_input_0.quantizer.zero_point
1690+
)
1691+
1692+
# mypy
1693+
assert self.divider_quantizer.scale is not None
1694+
assert self.divider_quantizer.zero_point is not None
1695+
1696+
# Compute the scale and zero point based on the scale and zero point
1697+
# of the two quantized values multiplied together
1698+
new_scale = q_input_0.quantizer.scale * self.divider_quantizer.scale
1699+
new_zero_point = q_input_0.quantizer.zero_point * self.divider_quantizer.zero_point
1700+
1701+
if self.produces_graph_output:
1702+
return self.make_output_quant_parameters(product_q_values, new_scale, new_zero_point)
1703+
1704+
with tag(self.op_instance_name + ".rounding"):
1705+
# Apply Concrete rounding (if relevant)
1706+
product_q_values = self.cnp_round(product_q_values, calibrate_rounding)
1707+
1708+
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4546
1709+
# De-quantize the product
1710+
dequant_product = (product_q_values - new_zero_point) * new_scale
1711+
1712+
# Return the raw float values without re-quantizing them to the new scale, as any
1713+
# following Gemm/Add/Conv will quantize them with _prepare_inputs_with_constants(...)
1714+
return QuantizedArray(
1715+
self.n_bits,
1716+
dequant_product,
1717+
value_is_float=True,
1718+
options=self._get_output_quant_opts(),
1719+
stats=self.output_quant_stats,
1720+
params=self.output_quant_params,
1721+
)
1722+
1723+
def can_fuse(self) -> bool:
1724+
"""Determine if this op can be fused.
15921725
1593-
Only multiplies an encrypted tensor with a float constant for now. This operation will
1594-
be fused to a (potentially larger) TLU.
1726+
Div operation can be computed in float and fused if it operates over inputs produced
1727+
by a single integer tensor.
1728+
1729+
Returns:
1730+
bool: Whether the number of integer input tensors allows computing this op as a TLU
1731+
"""
1732+
1733+
return len(self._int_input_names) == 1
1734+
1735+
1736+
class QuantizedMul(QuantizedMixingOp):
1737+
"""Quantized Multiplication operator.
1738+
1739+
Can multiply either two variables (both encrypted) or a variable and a constant
15951740
"""
15961741

15971742
_impl_for_op_named: str = "Mul"
15981743

1744+
def q_impl(
1745+
self,
1746+
*q_inputs: ONNXOpInputOutputType,
1747+
calibrate_rounding: bool = False,
1748+
**attrs,
1749+
) -> ONNXOpInputOutputType:
1750+
1751+
# If either input is a RawOpOutput or if the op can be fused,
1752+
# perform the op in the TLU using FP32
1753+
if (
1754+
len(q_inputs) == 1
1755+
or isinstance(q_inputs[0], RawOpOutput)
1756+
or isinstance(q_inputs[1], RawOpOutput)
1757+
or self.can_fuse()
1758+
):
1759+
return super().q_impl(*q_inputs, **attrs)
1760+
1761+
# For mypy
1762+
assert self.output_quant_params is not None
1763+
assert self.output_quant_params.scale is not None
1764+
assert self.output_quant_params.zero_point is not None
1765+
1766+
prepared_inputs = self._prepare_inputs_with_constants(
1767+
*q_inputs, calibrate=False, quantize_actual_values=True
1768+
)
1769+
1770+
q_input_0: QuantizedArray = prepared_inputs[0]
1771+
q_input_1: QuantizedArray = prepared_inputs[1]
1772+
1773+
assert q_input_0.quantizer.scale is not None
1774+
assert q_input_0.quantizer.zero_point is not None
1775+
1776+
assert q_input_1.quantizer.scale is not None
1777+
assert q_input_1.quantizer.zero_point is not None
1778+
1779+
# Remove the manual encrypted multiplication when we
1780+
# can handle input precision with rounding
1781+
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
1782+
@univariate
1783+
def copy_function(x):
1784+
return x
1785+
1786+
input0_q_values = q_input_0.qvalues
1787+
input1_q_values = q_input_1.qvalues
1788+
1789+
with tag(self.op_instance_name + ".enc_mul_rounding_input0"):
1790+
input0_q_values = self.cnp_round(
1791+
input0_q_values, calibrate_rounding, rounding_operation_id="input0_q_values_copy"
1792+
)
1793+
1794+
with tag(self.op_instance_name + ".enc_mul_rounding_input1"):
1795+
input1_q_values = self.cnp_round(
1796+
input1_q_values, calibrate_rounding, rounding_operation_id="input1_q_values_copy"
1797+
)
1798+
1799+
input0_q_values = copy_function(input0_q_values)
1800+
input1_q_values = copy_function(input1_q_values)
1801+
1802+
# The product of quantized encrypted integer values
1803+
product_q_values = input0_q_values * input1_q_values
1804+
1805+
# Integer quantized multiplication need adjustment based on the zero points.
1806+
if q_input_0.quantizer.zero_point:
1807+
product_q_values -= q_input_0.quantizer.zero_point * input1_q_values
1808+
if q_input_1.quantizer.zero_point:
1809+
product_q_values -= q_input_1.quantizer.zero_point * input0_q_values
1810+
1811+
# Compute the scale and zero point based on the scale and zero point
1812+
# of the two quantized values multiplied together
1813+
new_scale = q_input_0.quantizer.scale * q_input_1.quantizer.scale
1814+
new_zero_point = q_input_0.quantizer.zero_point * q_input_1.quantizer.zero_point
1815+
1816+
if self.produces_graph_output:
1817+
return self.make_output_quant_parameters(product_q_values, new_scale, new_zero_point)
1818+
1819+
with tag(self.op_instance_name + ".rounding"):
1820+
# Apply Concrete rounding (if relevant)
1821+
product_q_values = self.cnp_round(
1822+
product_q_values, calibrate_rounding, rounding_operation_id="product_q_values"
1823+
)
1824+
1825+
# De-quantize the product
1826+
dequant_product = (product_q_values + new_zero_point) * new_scale
1827+
1828+
# Return the raw float values without re-quantizing them to the new scale, as any
1829+
# following Gemm/Add/Conv will quantize them with _prepare_inputs_with_constants(...)
1830+
return QuantizedArray(
1831+
self.n_bits,
1832+
dequant_product,
1833+
value_is_float=True,
1834+
options=self._get_output_quant_opts(),
1835+
stats=self.output_quant_stats,
1836+
params=self.output_quant_params,
1837+
)
1838+
1839+
def can_fuse(self) -> bool:
1840+
"""Determine if this op can be fused.
1841+
1842+
Mul operation can be computed in float and fused if it operates over inputs produced
1843+
by a single integer tensor.
1844+
1845+
Returns:
1846+
bool: Whether the number of integer input tensors allows computing this op as a TLU
1847+
"""
1848+
1849+
return len(self._int_input_names) == 1
1850+
15991851

16001852
class QuantizedSub(QuantizedAdd):
16011853
"""Subtraction operator.

0 commit comments

Comments
 (0)
0