8000 [mypyc] Use a native unboxed representation for floats (#14880) · python/mypy@d05974b · GitHub
[go: up one dir, main page]

Skip to content

Commit d05974b

Browse files
authored
[mypyc] Use a native unboxed representation for floats (#14880)
Instead of each float value being a heap-allocated Python object, use unboxed C doubles to represent floats. This makes float operations much faster, and this also significantly reduces memory use of floats (when not stored in Python containers, which always use a boxed representation). Update IR to support float arithmetic and comparison ops, and float literals. Also add a few primitives corresponding to common math functions, such as `math.sqrt`. These don't require any boxing or unboxing. (I will add more of these in follow-up PRs.) Use -113.0 as an overlapping error value for floats. This is similar to native ints. Reuse much of the infrastructure we have to support overlapping error values with native ints (e.g. various bitmaps). Also improve support for negative float literals. There are two backward compatibility breaks worth highlighting. First, assigning an int value to a float variable is disallowed within mypyc, since narrowing down to a different value representation is inefficient and can lose precision. Second, information about float subclasses is lost during unboxing. This makes the bm_float benchmark about 5x faster and the raytrace benchmark about 3x faster. Closes mypyc/mypyc#966 (I'll create separate issues for remaining open issues).
1 parent 486a51b commit d05974b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2251
-123
lines changed

mypy/constant_fold.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
6464
value = constant_fold_expr(expr.expr, cur_mod_id)
6565
if isinstance(value, int):
6666
return constant_fold_unary_int_op(expr.op, value)
67+
if isinstance(value, float):
68+
return constant_fold_unary_float_op(expr.op, value)
6769
return None
6870

6971

@@ -110,6 +112,14 @@ def constant_fold_unary_int_op(op: str, value: int) -> int | None:
110112
return None
111113

112114

115+
def constant_fold_unary_float_op(op: str, value: float) -> float | None:
116+
if op == "-":
117+
return -value
118+
elif op == "+":
119+
return value
120+
return None
121+
122+
113123
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
114124
if op == "+":
115125
return left + right

mypyc/analysis/dataflow.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
ComparisonOp,
1919
ControlOp,
2020
Extend,
21+
Float,
22+
FloatComparisonOp,
23+
FloatNeg,
24+
FloatOp,
2125
GetAttr,
2226
GetElementPtr,
2327
Goto,
@@ -245,9 +249,18 @@ def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]:
245249
def visit_int_op(self, op: IntOp) -> GenAndKill[T]:
246250
return self.visit_register_op(op)
247251

252+
def visit_float_op(self, op: FloatOp) -> GenAndKill[T]:
253+
return self.visit_register_op(op)
254+
255+
def visit_float_neg(self, op: FloatNeg) -> GenAndKill[T]:
256+
return self.visit_register_op(op)
257+
248258
def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]:
249259
return self.visit_register_op(op)
250260

261+
def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]:
262+
return self.visit_register_op(op)
263+
251264
def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]:
252265
return self.visit_register_op(op)
253266

@@ -444,7 +457,7 @@ def analyze_undefined_regs(
444457
def non_trivial_sources(op: Op) -> set[Value]:
445458
result = set()
446459
for source in op.sources():
447-
if not isinstance(source, Integer):
460+
if not isinstance(source, (Integer, Float)):
448461
result.add(source)
449462
return result
450463

@@ -454,7 +467,7 @@ def visit_branch(self, op: Branch) -> GenAndKill[Value]:
454467
return non_trivial_sources(op), set()
455468

456469
def visit_return(self, op: Return) -> GenAndKill[Value]:
457-
if not isinstance(op.value, Integer):
470+
if not isinstance(op.value, (Integer, Float)):
458471
return {op.value}, set()
459472
else:
460473
return set(), set()

mypyc/analysis/ircheck.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
ControlOp,
1717
DecRef,
1818
Extend,
19+
FloatComparisonOp,
20+
FloatNeg,
21+
FloatOp,
1922
GetAttr,
2023
GetElementPtr,
2124
Goto,
@@ -43,6 +46,7 @@
4346
TupleSet,
4447
Unbox,
4548
Unreachable,
49+
Value,
4650
)
4751
from mypyc.ir.pprint import format_func
4852
from mypyc.ir.rtypes import (
@@ -54,6 +58,7 @@
5458
bytes_rprimitive,
5559
dict_rprimitive,
5660
int_rprimitive,
61+
is_float_rprimitive,
5762
is_object_rprimitive,
5863
list_rprimitive,
5964
range_rprimitive,
@@ -221,6 +226,14 @@ def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
221226
if not can_coerce_to(t, s) or not can_coerce_to(s, t):
222227
self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")
223228

229+
def expect_float(self, op: Op, v: Value) -> None:
230+
if not is_float_rprimitive(v.type):
231+
self.fail(op, f"Float expected (actual type is {v.type})")
232+
233+
def expect_non_float(self, op: Op, v: Value) -> None:
234+
if is_float_rprimitive(v.type):
235+
self.fail(op, "Float not expected")
236+
224237
def visit_goto(self, op: Goto) -> None:
225238
self.check_control_op_targets(op)
226239

@@ -376,10 +389,24 @@ def visit_load_global(self, op: LoadGlobal) -> None:
376389
pass
377390

378391
def visit_int_op(self, op: IntOp) -> None:
379-
pass
392+
self.expect_non_float(op, op.lhs)
393+
self.expect_non_float(op, op.rhs)
380394

381395
def visit_comparison_op(self, op: ComparisonOp) -> None:
382396
self.check_compatibility(op, op.lhs.type, op.rhs.type)
397+
self.expect_non_float(op, op.lhs)
398+
self.expect_non_float(op, op.rhs)
399+
400+
def visit_float_op(self, op: FloatOp) -> None:
401+
self.expect_float(op, op.lhs)
402+
self.expect_float(op, op.rhs)
403+
404+
def visit_float_neg(self, op: FloatNeg) -> None:
405+
self.expect_float(op, op.src)
406+
407+
def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
408+
self.expect_float(op, op.lhs)
409+
self.expect_float(op, op.rhs)
383410

384411
def visit_load_mem(self, op: LoadMem) -> None:
385412
pass

mypyc/analysis/selfleaks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
Cast,
1515
ComparisonOp,
1616
Extend,
17+
FloatComparisonOp,
18+
FloatNeg,
19+
FloatOp,
1720
GetAttr,
1821
GetElementPtr,
1922
Goto,
@@ -160,6 +163,15 @@ def visit_int_op(self, op: IntOp) -> GenAndKill:
160163
def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill:
161164
return CLEAN
162165

166+
def visit_float_op(self, op: FloatOp) -> GenAndKill:
167+
return CLEAN
168+
169+
def visit_float_neg(self, op: FloatNeg) -> GenAndKill:
170+
return CLEAN
171+
172+
def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill:
173+
return CLEAN
174+
163175
def visit_load_mem(self, op: LoadMem) -> GenAndKill:
164176
return CLEAN
165177

mypyc/codegen/emit.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,16 @@ def emit_unbox(
895895
self.emit_line(f"{dest} = CPyLong_AsInt32({src});")
896896
# TODO: Handle 'optional'
897897
# TODO: Handle 'failure'
898+
elif is_float_rprimitive(typ):
899+
if declare_dest:
900+
self.emit_line("double {};".format(dest))
901+
# TODO: Don't use __float__ and __index__
902+
self.emit_line(f"{dest} = PyFloat_AsDouble({src});")
903+
self.emit_lines(
904+
f"if ({dest} == -1.0 && PyErr_Occurred()) {{", f"{dest} = -113.0;", "}"
905+
)
906+
# TODO: Handle 'optional'
907+
# TODO: Handle 'failure'
898908
elif isinstance(typ, RTuple):
899909
self.declare_tuple_struct(typ)
900910
if declare_dest:
@@ -983,6 +993,8 @@ def emit_box(
983993
self.emit_line(f"{declaration}{dest} = PyLong_FromLong({src});")
984994
elif is_int64_rprimitive(typ):
985995
self.emit_line(f"{declaration}{dest} = PyLong_FromLongLong({src});")
996+
elif is_float_rprimitive(typ):
997+
self.emit_line(f"{declaration}{dest} = PyFloat_FromDouble({src});")
986998
elif isinstance(typ, RTuple):
987999
self.declare_tuple_struct(typ)
9881000
self.emit_line(f"{declaration}{dest} = PyTuple_New({len(typ.types)});")

mypyc/codegen/emitclass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,7 @@ def generate_readonly_getter(
10041004
emitter.ctype_spaced(rtype), NATIVE_PREFIX, func_ir.cname(emitter.names)
10051005
)
10061006
)
1007+
emitter.emit_error_check("retval", rtype, "return NULL;")
10071008
emitter.emit_box("retval", "retbox", rtype, declare_dest=True)
10081009
emitter.emit_line("return retbox;")
10091010
else:

mypyc/codegen/emitfunc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
ComparisonOp,
2626
DecRef,
2727
Extend,
28+
Float,
29+
FloatComparisonOp,
30+
FloatNeg,
31+
FloatOp,
2832
GetAttr,
2933
GetElementPtr,
3034
Goto,
@@ -671,6 +675,27 @@ def visit_comparison_op(self, op: ComparisonOp) -> None:
671675
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
672676
self.emit_line(f"{dest} = {lhs_cast}{lhs} {op.op_str[op.op]} {rhs_cast}{rhs};")
673677

678+
def visit_float_op(self, op: FloatOp) -> None:
679+
dest = self.reg(op)
680+
lhs = self.reg(op.lhs)
681+
rhs = self.reg(op.rhs)
682+
if op.op != FloatOp.MOD:
683+
self.emit_line("%s = %s %s %s;" % (dest, lhs, op.op_str[op.op], rhs))
684+
else:
685+
# TODO: This may set errno as a side effect, that is a little sketchy.
686+
self.emit_line("%s = fmod(%s, %s);" % (dest, lhs, rhs))
687+
688+
def visit_float_neg(self, op: FloatNeg) -> None:
689+
dest = self.reg(op)
690+
src = self.reg(op.src)
691+
self.emit_line(f"{dest} = -{src};")
692+
693+
def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
694+
dest = self.reg(op)
695+
lhs = self.reg(op.lhs)
696+
rhs = self.reg(op.rhs)
697+
self.emit_line("%s = %s %s %s;" % (dest, lhs, op.op_str[op.op], rhs))
698+
674699
def visit_load_mem(self, op: LoadMem) -> None:
675700
dest = self.reg(op)
676701
src = self.reg(op.src)
@@ -732,6 +757,13 @@ def reg(self, reg: Value) -> str:
732757
elif val <= -(1 << 31):
733758
s += "LL"
734759
return s
760+
elif isinstance(reg, Float):
761+
r = repr(reg.value)
762+
if r == "inf":
763+
return "INFINITY"
764+
elif r == "-inf":
765+
return "-INFINITY"
766+
return r
735767
else:
736768
return self.emitter.reg(reg)
737769

mypyc/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"getargs.c",
7070
"getargsfast.c",
7171
"int_ops.c",
72+
"float_ops.c",
7273
"str_ops.c",
7374
"bytes_ops.c",
7475
"list_ops.c",

0 commit comments

Comments
 (0)
0