8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d8f9e35 commit 8e794b7Copy full SHA for 8e794b7
tests/pytorch/kernel/test_fill_kv_cache.py
@@ -8,17 +8,13 @@ def _div_up(a, b):
8
return (a + b - 1) // b
9
10
11
-def precise_round(x: torch.Tensor):
12
- return x.sign() * (x.abs() + 0.5).floor()
13
-
14
15
def quant(kv: torch.Tensor, nbits: int = 8):
16
"""Quant kv on the head_dim."""
17
amax = kv.amax(dim=-1, keepdim=True)
18
amin = kv.amin(dim=-1, keepdim=True)
19
scales = (amax - amin) / (2**nbits - 1)
20
zeros = -amin / scales
21
- q_kv = precise_round((kv - amin) / scales).to(torch.uint8)
+ q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)
22
if nbits == 4:
23
q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)
24
q_kv = q_kv1 + q_kv2 * 16
0 commit comments