8000 Script to generate NJT OpInfo testing report · pytorch/pytorch@0c62344 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c62344

Browse files
committed
Script to generate NJT OpInfo testing report
ghstack-source-id: 91ca403 Pull Request resolved: #143311
1 parent ff4f796 commit 0c62344

File tree

1 file changed

+376
-0
lines changed

1 file changed

+376
-0
lines changed

test/gen_njt_op_report.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import Dict, List, Optional
4+
5+
from test_nestedtensor import (
6+
BACKWARD_SKIPS_AND_XFAILS,
7+
COMPILE_BACKWARD_SKIPS_AND_XFAILS,
8+
COMPILE_FORWARD_SKIPS_AND_XFAILS,
9+
FORWARD_SKIPS_AND_XFAILS,
10+
)
11+
12+
import torch
13+
from torch.nested._internal.nested_tensor import NestedTensor
14+
from torch.testing._internal.opinfo.core import SampleRule, SkipRule, XFailRule
15+
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
16+
from torch.utils._pytree import tree_flatten
17+
18+
19+
# Contains info about a single SampleInput
20+
@dataclass
21+
class SampleInfo:
22+
# the first rule that was matched; could be a skip or an xfail. None == success
23+
matched_rule: Optional[SampleRule] = None
24+
# are all NJT inputs contiguous?
25+
contiguous: bool = True
26+
27+
28+
# Contains info about all SampleInputs for a single test set (e.g. for forward tests)
29+
@dataclass
30+
class TestResultInfo:
31+
device_type: str = None
32+
dtype: torch.dtype = torch.float32
33+
sample_infos: List[SampleInfo] = None
34+
35+
def __post_init__(self):
36+
if self.device_type is None:
37+
raise ValueError("device_type must be set")
38+
if self.sample_infos is None:
39+
raise ValueError("sample_infos must be set")
40+
41+
def num_xfails(self, contiguous=None):
42+
if contiguous is None:
43+
return len(
44+
[
45+
si
46+
for si in self.sample_infos
47+
if isinstance(si.matched_rule, XFailRule)
48+
]
49+
)
50+
else:
51+
return len(
52+
[
53+
si
54+
for si in self.sample_infos
55+
if isinstance(si.matched_rule, XFailRule)
56+
and si.contiguous == contiguous
57+
]
58+
)
59+
60+
def num_skips(self, contiguous=None):
61+
if contiguous is None:
62+
return len(
63+
[
64+
si
65+
for si in self.sample_infos
66+
if isinstance(si.matched_rule, SkipRule)
67+
]
68+
)
69+
else:
70+
return len(
71+
[
72+
si
73+
for si in self.sample_infos
74+
if isinstance(si.matched_rule, SkipRule)
75+
and si.contiguous == contiguous
76+
]
77+
)
78+
79+
def num_samples(self, contiguous=None):
80+
if contiguous is None:
81+
return len(self.sample_infos)
82+
else:
83+
return len([si for si in self.sample_infos if si.contiguous == contiguous])
84+
85+
def success_rate(self, contiguous=None):
86+
num_samples = self.num_samples(contiguous)
87+
if num_samples == 0:
88+
# avoid division by 0
89+
return None
90+
return 1 - (
91+
float(self.num_xfails(contiguous) + self.num_skips(contiguous))
92+
/ float(num_samples)
93+
)
94+
95+
96+
# Status around op support for NJT OpInfo tests.
97+
# We're only able to get info for those ops that have an NJT-compatible OpInfo entry.
98+
# The op may not have one or it may not support float32, which means the tests won't run (yet).
99+
class OpInfoStatus(Enum):
100+
VALID_OPINFO = 1
101+
NO_OPINFO = 2
102+
NO_FLOAT32_SUPPORT = 3
103+
104+
105+
# Contains info about all test sets for a given op
106+
@dataclass
107+
class OpTestResultInfo:
108+
status: OpInfoStatus = None
109+
# mapping from test set name (e.g. "forward") -> test set info.
110+
# should be None if this isn't able to be calculated.
111+
test_results: Dict[str, TestResultInfo] = None
112+
113+
def __post_init__(self):
114+
if self.status is None:
115+
raise ValueError("status must be set")
116+
117+
118+
@dataclass
119+
class TestSet:
120+
name: str = None
121+
skips_and_xfails: List[SampleRule] = None
122+
needs_requires_grad: bool = False
123+
124+
def __post_init__(self):
125+
if self.name is None:
126+
raise ValueError("name must be set")
127+
if self.skips_and_xfails is None:
128+
raise ValueError("skips_and_xfails must be set")
129+
130+
131+
TEST_SETS = [
132+
TestSet(
133+
name="forward",
134+
skips_and_xfails=FORWARD_SKIPS_AND_XFAILS,
135+
needs_requires_grad=False,
136+
),
137+
TestSet(
138+
name="compile_forward",
139+
skips_and_xfails=COMPILE_FORWARD_SKIPS_AND_XFAILS,
140+
needs_requires_grad=False,
141+
),
142+
TestSet(
143+
name="backward",
144+
skips_and_xfails=BACKWARD_SKIPS_AND_XFAILS,
145+
needs_requires_grad=True,
146+
),
147+
TestSet(
148+
name="compile_backward",
149+
skips_and_xfails=COMPILE_BACKWARD_SKIPS_AND_XFAILS,
150+
needs_requires_grad=True,
151+
),
152+
]
153+
154+
155+
def write_header(f):
156+
f.write(
157+
"""
158+
<html>
159+
<head>
160+
<title>NJT OpInfo Testing</title>
161+
<style>
162+
table {
163+
border-collapse: collapse;
164+
width: 98%;
165+
color: #333;
166+
font-family: Arial, sans-serif;
167+
font-size: 13px;
168+
text-align: left;
169+
border-radius: 10px;
170+
box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
171+
margin-top: 0px;
172+
}
173+
body: {
174+
background-color: #fdfdfd;
175+
}
176+
thead, tbody: {
177+
display: block;
178+
}
179+
table thead {
180+
position: sticky;
181+
top: 0;
182+
background-color: #232a25;
183+
border-radius: 10px;
184+
}
185+
table th {
186+
background-color: #232a25;
187+
color: #ddd;
188+
font-weight: bold;
189+
padding: 10px;
190+
text-transform: uppercase;
191+
letter-spacing: 1px;
192+
border-bottom: 1px solid #ccc;
193+
}
194+
table th:first-child {
195+
border-radius: 10px 0 0 0;
196+
box-shadow: 0 -2.1rem 0 .6rem #fdfdfd
197+
}
198+
table th:last-child {
199+
border-radius: 0 10px 0 0;
200+
box-shadow: 1rem -2.1rem 0 .6rem #fdfdfd
201+
}
202+
table td {
203+
background-color: #ddd;
204+
padding: 10px;
205+
border-bottom: 1px solid #333;
206+
border-right: 1px solid #333;
207+
font-weight: bold;
208+
}
209+
table tr td:last-child {
210+
border-right: 0px solid #333
211+
}
212+
table tr:last-child td {
213+
border-bottom: 0px solid #333
214+
}
215+
table tr:last-child td:first-child {
216+
border-radius: 0 0 0 10px;
217+
}
218+
table tr:last-child td:last-child {
219+
border-radius: 0 0 10px 0;
220+
}
221+
</style>
222+
</head>
223+
<body>
224+
<table>
225+
<thead>
226+
<tr>
227+
<th>Op Name</th>
228+
"""
229+
)
230+
231+
for contiguous in [True, False]:
232+
for test_set in TEST_SETS:
233+
f.write(
234+
f"""
235+
<th>{test_set.name}<br/>({"contiguous" if contiguous else "non-contiguous"})</th>
236+
"""
237+
)
238+
239+
f.write(
240+
"""
241+
</tr>
242+
</thead>
243+
<tbody>
244+
"""
245+
)
246+
247+
248+
def write_footer(f):
249+
f.write(
250+
"""
251+
</tbody>
252+
</table>
253+
</body>
254+
</html>
255+
"""
256+
)
257+
258+
259+
# returns (fgcolor, bgcolor) for given success rate
260+
def success_colors(success_rate):
261+
if success_rate is None:
262+
return ("#333333", "#dddddd")
263+
elif success_rate == 1.0:
264+
# green
265+
return ("#dddddd", "#495b52")
266+
elif success_rate > 0.5:
267+
# yellow
268+
return ("#dddddd", "#a47146")
269+
else:
270+
# red
271+
return ("#dddddd", "#a04c46")
272+
273+
274+
def write_table_row(f, op, result):
275+
table_row = f"<tr><td>{op.full_name}</td>"
276+
for contiguous in [True, False]:
277+
for test_set in TEST_SETS:
278+
success_rate = None
279+
if result.status == OpInfoStatus.VALID_OPINFO:
280+
set_result = result.test_results[test_set.name]
281+
success_rate = set_result.success_rate(contiguous)
282+
if success_rate is None:
283+
success_rate_text = "N/A"
284+
elif success_rate == 1.0:
285+
success_rate_text = f"{success_rate:.2f}"
286+
else:
287+
success_rate_text = f"{success_rate:.2f}"
288+
success_rate_text += (
289+
f" (xfails: {set_result.num_xfails(contiguous)}, "
290+
f"skips: {set_result.num_skips(contiguous)}, "
291+
f"samples: {set_result.num_samples(contiguous)})"
292+
)
293+
elif result.status == OpInfoStatus.NO_OPINFO:
294+
success_rate_text = "N/A (no OpInfo)"
295+
elif result.status == OpInfoStatus.NO_FLOAT32_SUPPORT:
296+
success_rate_text = "N/A (no float32 support)"
297+
else:
298+
raise ValueError("invalid OpInfoStatus encountered")
299+
300+
fgcolor, bgcolor = success_colors(success_rate)
301+
style = f'style="color: {fgcolor}; background-color: {bgcolor}"'
302+
table_row += f'<td {style}">{success_rate_text}</td>\n'
303+
304+
table_row += "</tr>"
305+
f.write(table_row)
306+
307+
308+
# Returns first matched rule or None if none matched
309+
def match_first_rule(rules, sample, device):
310+
for rule in rules:
311+
if rule.sample_match_fn(device, sample):
312+
return rule
313+
return None
314+
315+
316+
def get_test_result_info(
317+
op, rules, device_type, dtype, requires_grad
318+
) -> TestResultInfo:
319+
device = torch.device(device_type)
320+
sample_infos = []
321+
op_rules = [rule for rule in rules if rule.op_match_fn(device_type, op)]
322+
for sample in op.sample_inputs(
323+
device=device_type, dtype=dtype, requires_grad=requires_grad
324+
):
325+
all_njts_contiguous = all(
326+
njt.is_contiguous()
327+
for njt in tree_flatten((sample.input, sample.args, sample.kwargs))[0]
328+
if isinstance(njt, NestedTensor)
329+
)
330+
331+
sample_infos.append(
332+
SampleInfo(
333+
matched_rule=match_first_rule(op_rules, sample, device),
334+
contiguous=all_njts_contiguous,
335+
)
336+
)
337+
338+
return TestResultInfo(
339+
device_type=device_type,
340+
dtype=dtype,
341+
sample_infos=sample_infos,
342+
)
343+
344+
345+
with open("njt_op_report.html", "w") as f:
346+
write_header(f)
347+
njt_op_db.sort(key=lambda op: op.full_name)
348+
for op in njt_op_db:
349+
# TODO: un-hardcode these
350+
device_type = "cuda"
351+
supported_dtypes = op.dtypesIfCUDA if device_type == "cuda" else op.dtypes
352+
dtype = torch.float32
353+
354+
if not op.supports_njt:
355+
result = OpTestResultInfo(status=OpInfoStatus.NO_OPINFO)
356+
elif dtype not in supported_dtypes:
357+
result = OpTestResultInfo(status=OpInfoStatus.NO_FLOAT32_SUPPORT)
358+
else:
359+
test_results = {
360+
test_set.name: get_test_result_info(
361+
op,
362+
test_set.skips_and_xfails,
363+
device_type,
364+
dtype,
365+
test_set.needs_requires_grad,
366+
)
367+
for test_set in TEST_SETS
368+
}
369+
result = OpTestResultInfo(
370+
status=OpInfoStatus.VALID_OPINFO,
371+
test_results=test_results,
372+
)
373+
374+
write_table_row(f, op, result)
375+
376+
write_footer(f)

0 commit comments

Comments
 (0)
0