Align UT with triton fill_kv_cache_quant kernel (#2644) · InferenceNexus/lmdeploy@8e794b7 · GitHub
[go: up one dir, main page]

10000
Skip to content

Commit 8e794b7

Browse files
authored
Align UT with triton fill_kv_cache_quant kernel (InternLM#2644)
1 parent d8f9e35 commit 8e794b7

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

tests/pytorch/kernel/test_fill_kv_cache.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,13 @@ def _div_up(a, b):
     return (a + b - 1) // b
 
 
-def precise_round(x: torch.Tensor):
-    return x.sign() * (x.abs() + 0.5).floor()
-
-
 def quant(kv: torch.Tensor, nbits: int = 8):
     """Quant kv on the head_dim."""
     amax = kv.amax(dim=-1, keepdim=True)
     amin = kv.amin(dim=-1, keepdim=True)
     scales = (amax - amin) / (2**nbits - 1)
     zeros = -amin / scales
-    q_kv = precise_round((kv - amin) / scales).to(torch.uint8)
+    q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)
     if nbits == 4:
         q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)
         q_kv = q_kv1 + q_kv2 * 16

0 commit comments

Comments
 (0)
0