GH-135904: Optimize the JIT's assembly control flow (GH-135905) · python/cpython@0e5d096 · GitHub
[go: up one dir, main page]

Skip to content
</react-partial>

Commit 0e5d096

Browse files
authored
GH-135904: Optimize the JIT's assembly control flow (GH-135905)
1 parent 0141e7f commit 0e5d096

File tree

4 files changed

+352
-94
lines changed

4 files changed

+352
-94
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Perform more aggressive control-flow optimizations on the machine code
2+
templates emitted by the experimental JIT compiler.

Tools/jit/_optimizers.py

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
"""Low-level optimization of textual assembly."""
2+
3+
import dataclasses
4+
import pathlib
5+
import re
6+
import typing
7+
8+
# Same as saying "not string.startswith('')":
9+
_RE_NEVER_MATCH = re.compile(r"(?!)")
10+
# Dictionary mapping branch instructions to their inverted branch instructions.
11+
# If a branch cannot be inverted, the value is None:
12+
_X86_BRANCHES = {
13+
# https://www.felixcloutier.com/x86/jcc
14+
"ja": "jna",
15+
"jae": "jnae",
16+
"jb": "jnb",
17+
"jbe": "jnbe",
18+
"jc": "jnc",
19+
"jcxz": None,
20+
"je": "jne",
21+
"jecxz": None,
22+
"jg": "jng",
23+
"jge": "jnge",
24+
"jl": "jnl",
25+
"jle": "jnle",
26+
"jo": "jno",
27+
"jp": "jnp",
28+
"jpe": "jpo",
29+
"jrcxz": None,
30+
"js": "jns",
31+
"jz": "jnz",
32+
# https://www.felixcloutier.com/x86/loop:loopcc
33+
"loop": None,
34+
"loope": None,
35+
"loopne": None,
36+
"loopnz": None,
37+
"loopz": None,
38+
}
39+
# Update with all of the inverted branches, too:
40+
_X86_BRANCHES |= {v: k for k, v in _X86_BRANCHES.items() if v}
41+
42+
43+
@dataclasses.dataclass
44+
class _Block:
45+
label: str | None = None
46+
# Non-instruction lines like labels, directives, and comments:
47+ 6D40
noninstructions: list[str] = dataclasses.field(default_factory=list)
48+
# Instruction lines:
49+
instructions: list[str] = dataclasses.field(default_factory=list)
50+
# If this block ends in a jump, where to?
51+
target: typing.Self | None = None
52+
# The next block in the linked list:
53+
link: typing.Self | None = None
54+
# Whether control flow can fall through to the linked block above:
55+
fallthrough: bool = True
56+
# Whether this block can eventually reach the next uop (_JIT_CONTINUE):
57+
hot: bool = False
58+
59+
def resolve(self) -> typing.Self:
60+
"""Find the first non-empty block reachable from this one."""
61+
block = self
62+
while block.link and not block.instructions:
63+
block = block.link
64+
return block
65+
66+
67+
@dataclasses.dataclass
class Optimizer:
    """Several passes of analysis and optimization for textual assembly.

    Creating an instance reads the assembly file at "path" and splits it into
    a linked list of basic blocks. Calling run() applies the passes and writes
    the result back to the same path.
    """

    path: pathlib.Path
    _: dataclasses.KW_ONLY
    # Prefix used to mangle symbols on some platforms:
    prefix: str = ""
    # The first block in the linked list:
    _root: _Block = dataclasses.field(init=False, default_factory=_Block)
    # Every label defined or referenced so far, keyed by name:
    _labels: dict[str, _Block] = dataclasses.field(init=False, default_factory=dict)
    # Lines that are not instructions (no groups):
    _re_noninstructions: typing.ClassVar[re.Pattern[str]] = re.compile(
        r"\s*(?:\.|#|//|$)"
    )
    # Label definitions (one group: label):
    _re_label: typing.ClassVar[re.Pattern[str]] = re.compile(
        r'\s*(?P<label>[\w."$?@]+):'
    )
    # Subclasses override everything from here on:
    _alignment: typing.ClassVar[int] = 1
    _branches: typing.ClassVar[dict[str, str | None]] = {}
    # Branches (two groups: instruction and target):
    _re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
    # Jumps (one group: target):
    _re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
    # Returns (no groups):
    _re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH

    def __post_init__(self) -> None:
        # Split the code into a linked list of basic blocks. A basic block is an
        # optional label, followed by zero or more non-instruction lines,
        # followed by zero or more instruction lines (only the last of which may
        # be a branch, jump, or return):
        current = self._root
        for line in self._preprocess(self.path.read_text()).splitlines():
            label = self._re_label.match(line)
            if label is not None:
                # A label always starts a new block (link first, then advance):
                labeled = self._lookup_label(label["label"])
                current.link = labeled
                current = labeled
                current.noninstructions.append(line)
                continue
            if self._re_noninstructions.match(line):
                if current.instructions:
                    # Non-instruction lines after instructions start a new block:
                    fresh = _Block()
                    current.link = fresh
                    current = fresh
                current.noninstructions.append(line)
                continue
            # Anything else is an instruction:
            if current.target is not None or not current.fallthrough:
                # The current block already ended with a branch, jump, or
                # return, so this instruction starts a new block:
                fresh = _Block()
                current.link = fresh
                current = fresh
            current.instructions.append(line)
            branch = self._re_branch.match(line)
            if branch is not None:
                # A block ending in a branch has a target and fallthrough:
                current.target = self._lookup_label(branch["target"])
                assert current.fallthrough
                continue
            jump = self._re_jump.match(line)
            if jump is not None:
                # A block ending in a jump has a target and no fallthrough:
                current.target = self._lookup_label(jump["target"])
                current.fallthrough = False
                continue
            if self._re_return.match(line):
                # A block ending in a return has neither a target nor
                # fallthrough:
                assert not current.target
                current.fallthrough = False

    def _preprocess(self, text: str) -> str:
        # Hook for subclasses that need to rewrite the textual assembly before
        # it is split into blocks. The default leaves it untouched:
        return text

    @classmethod
    def _invert_branch(cls, line: str, target: str) -> str | None:
        # Rewrite a branch line so that it tests the opposite condition and
        # goes to "target" instead. Returns None if there is no inverted form.
        found = cls._re_branch.match(line)
        assert found
        opposite = cls._branches.get(found["instruction"])
        if not opposite:
            return None
        op_start, op_end = found.span("instruction")
        target_start, target_end = found.span("target")
        # Before:
        #     je FOO
        # After:
        #     jne BAR
        head = line[:op_start]
        middle = line[op_end:target_start]
        tail = line[target_end:]
        return f"{head}{opposite}{middle}{target}{tail}"

    @classmethod
    def _update_jump(cls, line: str, target: str) -> str:
        # Rewrite a jump line so that it goes to "target" instead.
        found = cls._re_jump.match(line)
        assert found
        target_start, target_end = found.span("target")
        # Before:
        #     jmp FOO
        # After:
        #     jmp BAR
        return f"{line[:target_start]}{target}{line[target_end:]}"

    def _lookup_label(self, label: str) -> _Block:
        # Return the block for "label", creating it the first time it is seen
        # (whether as a definition or as a branch/jump target):
        try:
            return self._labels[label]
        except KeyError:
            created = self._labels[label] = _Block(label)
            return created

    def _blocks(self) -> typing.Generator[_Block, None, None]:
        # Walk the linked list from the root, in order:
        node: _Block | None = self._root
        while node is not None:
            yield node
            node = node.link

    def _body(self) -> str:
        # Join every block's lines back into a single string of assembly:
        output: list[str] = []
        was_hot = True
        for block in self._blocks():
            if block.hot != was_hot:
                was_hot = block.hot
                # Make it easy to tell at a glance where cold code is:
                output.append(
                    f"# JIT: {'HOT' if was_hot else 'COLD'} ".ljust(80, "#")
                )
            output += block.noninstructions
            output += block.instructions
        return "\n".join(output)

    def _predecessors(self, block: _Block) -> typing.Generator[_Block, None, None]:
        # This is inefficient (a full scan per call), but it's never wrong:
        for candidate in self._blocks():
            jumps_here = candidate.target is block
            falls_here = candidate.fallthrough and candidate.link is block
            if jumps_here or falls_here:
                yield candidate

    def _insert_continue_label(self) -> None:
        # Find the block with the last instruction (if no block has any
        # instructions, this ends up being the first block):
        everything = list(self._blocks())
        end = next(
            (block for block in reversed(everything) if block.instructions),
            everything[0],
        )
        # Before:
        #     jmp FOO
        # After:
        #     jmp FOO
        #     .balign 8
        #     _JIT_CONTINUE:
        # (the alignment used is whatever self._alignment says)
        # This lets the assembler encode _JIT_CONTINUE jumps at build time!
        align = _Block()
        align.noninstructions.append(f"\t.balign\t{self._alignment}")
        continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE")
        assert continuation.label
        continuation.noninstructions.append(f"{continuation.label}:")
        # Splice "align" and "continuation" in right after "end":
        rest = end.link
        end.link = align
        align.link = continuation
        continuation.link = rest

    def _mark_hot_blocks(self) -> None:
        # Start with the last block, and perform a DFS to find all blocks that
        # can eventually reach it:
        pending = [list(self._blocks())[-1]]
        while pending:
            block = pending.pop()
            block.hot = True
            pending += [pre for pre in self._predecessors(block) if not pre.hot]

    def _invert_hot_branches(self) -> None:
        for branch in self._blocks():
            if branch.link is None:
                continue
            jump = branch.link.resolve()
            # Before:
            #     je HOT
            #     jmp COLD
            # After:
            #     jne COLD
            #     jmp HOT
            # The block must end with a branch to hot code...
            if branch.target is None or not branch.fallthrough:
                continue
            if not branch.target.hot:
                continue
            # ...followed by a jump to cold code with no other predecessors:
            if jump.target is None or jump.fallthrough or jump.target.hot:
                continue
            if len(jump.instructions) != 1:
                continue
            only_reached_from_branch = list(self._predecessors(jump)) == [branch]
            if not only_reached_from_branch:
                continue
            cold_label = jump.target.label
            hot_label = branch.target.label
            assert cold_label
            assert hot_label
            inverted = self._invert_branch(branch.instructions[-1], cold_label)
            # Check to see if the branch can even be inverted:
            if inverted is None:
                continue
            branch.instructions[-1] = inverted
            jump.instructions[-1] = self._update_jump(
                jump.instructions[-1], hot_label
            )
            branch.target, jump.target = jump.target, branch.target
            jump.hot = True

    def _remove_redundant_jumps(self) -> None:
        # Zero-length jumps can be introduced by _insert_continue_label and
        # _invert_hot_branches:
        for block in self._blocks():
            if block.target is None or block.link is None:
                continue
            # Before:
            #     jmp FOO
            #     FOO:
            # After:
            #     FOO:
            if block.target.resolve() is not block.link.resolve():
                continue
            block.target = None
            block.fallthrough = True
            block.instructions.pop()

    def run(self) -> None:
        """Run this optimizer."""
        passes = (
            self._insert_continue_label,
            self._mark_hot_blocks,
            self._invert_hot_branches,
            self._remove_redundant_jumps,
        )
        for run_pass in passes:
            run_pass()
        self.path.write_text(self._body())
283+
284+
285+
class OptimizerAArch64(Optimizer):  # pylint: disable = too-few-public-methods
    """aarch64-apple-darwin/aarch64-pc-windows-msvc/aarch64-unknown-linux-gnu"""

    # Only the alignment and the jump pattern are overridden here; every other
    # pattern is inherited unchanged from Optimizer.
    # TODO: @diegorusso
    _alignment = 8
    # Matches a "b" instruction and captures its target:
    # https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch-
    _re_jump = re.compile(r"\s*b\s+(?P<target>[\w.]+)")
292+
293+
294+
class OptimizerX86(Optimizer):  # pylint: disable = too-few-public-methods
    """i686-pc-windows-msvc/x86_64-apple-darwin/x86_64-unknown-linux-gnu"""

    # Branch mnemonics and their inverted forms (None if not invertible):
    _branches = _X86_BRANCHES
    # Any known branch mnemonic followed by its target. Alternatives that are
    # prefixes of others (like "ja" and "jae") are fine, since the mnemonic
    # must be followed by whitespace:
    _re_branch = re.compile(
        rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
    )
    # https://www.felixcloutier.com/x86/jmp
    _re_jump = re.compile(r"\s*jmp\s+(?P<target>[\w.]+)")
    # https://www.felixcloutier.com/x86/ret
    # AT&T syntax allows an operand-size suffix (retw/retl/retq). A bare
    # "ret\b" misses those, which leaves blocks ending in them marked as
    # falling through:
    _re_return = re.compile(r"\s*ret[lqw]?\b")
305+
306+
307+
class OptimizerX8664Windows(OptimizerX86):  # pylint: disable = too-few-public-methods
    """x86_64-pc-windows-msvc"""

    def _preprocess(self, text: str) -> str:
        # Rewrite indirect jumps through the import table (to symbols starting
        # with the mangled "_JIT_" prefix) as direct jumps to the symbol:
        preprocessed = super()._preprocess(text)
        # Before:
        #     rex64 jmpq *__imp__JIT_CONTINUE(%rip)
        # After:
        #     jmp _JIT_CONTINUE
        far_indirect_jump = re.compile(
            rf"rex64\s+jmpq\s+\*__imp_(?P<target>{self.prefix}_JIT_\w+)\(%rip\)"
        )
        return far_indirect_jump.sub(r"jmp\t\g<target>", preprocessed)

0 commit comments

Comments
 (0)
0