8000 Update · pytorch/pytorch@f3035d4 · GitHub
[go: up one dir, main page]

Skip to content

Commit f3035d4

Browse files
committed
Update
[ghstack-poisoned]
2 parents 35945bf + 5ecf03c commit f3035d4

File tree

325 files changed

+2174
-2353
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

325 files changed

+2174
-2353
lines changed

.flake8

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ max-line-length = 120
88
# E501 is not flexible enough, we're using B950 instead
99
ignore =
1010
E203,E305,E402,E501,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,
11-
# type stub in .py files formatted by black
12-
E704,
1311
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
1412
# to line this up with executable bit
1513
EXE001,

.github/scripts/build_triton_wheel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ def main() -> None:
178178

179179
build_triton(
180180
build_rocm=args.build_rocm,
181-
commit_hash=(
182-
args.commit_hash if args.commit_hash else read_triton_pin(args.build_rocm)
183-
),
181+
commit_hash=args.commit_hash
182+
if args.commit_hash
183+
else read_triton_pin(args.build_rocm),
184184
version=args.triton_version,
185185
build_conda=args.build_conda,
186186
py_version=args.py_version,

.github/scripts/trymerge.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,9 +1419,9 @@ def find_matching_merge_rule(
14191419
pending_checks, failed_checks, _ = categorize_checks(
14201420
checks,
14211421
required_checks,
1422-
ok_failed_checks_threshold=(
1423-
IGNORABLE_FAILED_CHECKS_THESHOLD if rule.ignore_flaky_failures else 0
1424-
),
1422+
ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
1423+
if rule.ignore_flaky_failures
1424+
else 0,
14251425
)
14261426

14271427
# categorize_checks assumes all tests are required if required_checks is empty.
@@ -2202,9 +2202,9 @@ def merge(
22022202
checks,
22032203
required_checks
22042204
+ [x for x in checks.keys() if x not in required_checks],
2205-
ok_failed_checks_threshold=(
2206-
IGNORABLE_FAILED_CHECKS_THESHOLD if ignore_flaky_failures else 0
2207-
),
2205+
ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
2206+
if ignore_flaky_failures
2207+
else 0,
22082208
)
22092209
# HACK until GitHub will be better about surfacing those
22102210
startup_failures = filter_checks_with_lambda(

.lintrunner.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1694,7 +1694,7 @@ init_command = [
16941694
'tools/linter/adapters/pip_init.py',
16951695
'--dry-run={{DRYRUN}}',
16961696
'--no-black-binary',
1697-
'black==24.4.2',
1697+
'black==23.12.1',
16981698
'ufmt==2.7.0',
16991699
'usort==1.0.8.post1',
17001700
'isort==5.13.2',

benchmarks/distributed/rpc/rl/launcher.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ def main():
209209
x_axis_variables
210210
): # run benchmark for every x axis variable
211211
if len(x_axis_variables) > 1:
212-
# Set x axis variable for this benchmark iteration
213-
args[args["x_axis_name"]] = x_axis_variable
212+
args[
213+
args["x_axis_name"]
214+
] = x_axis_variable # set x axis variable for this benchmark iteration
214215
processes = []
215216
start_time = time.time()
216217
for rank in range(args["world_size"]):

benchmarks/dynamo/common.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,10 +1360,12 @@ def _generate_onnx_model_directory(
13601360
return model_path
13611361

13621362
@abc.abstractmethod
1363-
def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: ...
1363+
def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
1364+
...
13641365

13651366
@abc.abstractmethod
1366-
def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ...
1367+
def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
1368+
...
13671369

13681370
def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]:
13691371
pt_inputs = self.format_pt_inputs(pt_inputs)
@@ -2016,9 +2018,9 @@ def cast_to(dtype, model, inputs):
20162018
model = model.to(dtype)
20172019

20182020
inputs = tree_map(
2019-
lambda x: (
2020-
x.to(dtype) if isinstance(x, torch.Tensor) and x.is_floating_point() else x
2021-
),
2021+
lambda x: x.to(dtype)
2022+
if isinstance(x, torch.Tensor) and x.is_floating_point()
2023+
else x,
20222024
inputs,
20232025
)
20242026
return model, inputs
@@ -2450,11 +2452,9 @@ def deepcopy_and_maybe_parallelize(self, model):
24502452
model = FSDP(
24512453
model,
24522454
use_orig_params=True,
2453-
device_id=(
2454-
torch.cuda.current_device()
2455-
if self.args.devices[-1] == "cuda"
2456-
else None
2457-
),
2455+
device_id=torch.cuda.current_device()
2456+
if self.args.devices[-1] == "cuda"
2457+
else None,
24582458
mixed_precision=mp_policy,
24592459
limit_all_gathers=True,
24602460
auto_wrap_policy=self.get_fsdp_auto_wrap_policy(self.args.only),
@@ -2519,11 +2519,9 @@ def record_status(accuracy_status, dynamo_start_stats):
25192519
self.init_optimizer(name, current_device, model_fp64.parameters())
25202520
fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
25212521
fp64_outputs = tree_map(
2522-
lambda x: (
2523-
x.to(torch.float64)
2524-
if isinstance(x, torch.Tensor) and x.is_floating_point()
2525-
else x
2526-
),
2522+
lambda x: x.to(torch.float64)
2523+
if isinstance(x, torch.Tensor) and x.is_floating_point()
2524+
else x,
25272525
fp64_outputs,
25282526
)
25292527
except Exception:
@@ -2913,9 +2911,9 @@ def warmup(fn, model, example_inputs, mode, niters=5):
29132911
experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
29142912
experiment_kwargs["dynamo_stats"] = dynamo_stats
29152913
if self.args.profile_dynamo_cache_lookup:
2916-
experiment_kwargs["cache_lookup_latency"] = (
2917-
dynamo_cache_lookup_latency
2918-
)
2914+
experiment_kwargs[
2915+
"cache_lookup_latency"
2916+
] = dynamo_cache_lookup_latency
29192917

29202918
if experiment.func is coverage_experiment:
29212919
ok, total = Stats.reset_counters()

benchmarks/dynamo/microbenchmarks/analyze_templates.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
55
That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates.
66
"""
7-
87
import json
98

109
import click

benchmarks/instruction_counts/applications/ci.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Collect instruction counts for continuous integration."""
2-
32
import argparse
43
import hashlib
54
import json

benchmarks/instruction_counts/core/api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Key enums and structs used to handle data flow within the benchmark."""
2-
32
import dataclasses
43
import enum
54
import itertools as it

benchmarks/instruction_counts/core/expand.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
This is mostly string manipulation, with just a bit of importlib magic.
44
"""
5-
65
import importlib.abc
76
import importlib.util
87
import itertools as it

0 commit comments

Comments (0)