From 4748cfbeeb2854d8110cd9c05c368a857f415b13 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 28 Nov 2023 12:01:06 -0800 Subject: [PATCH 01/66] Add PyPI shield to README.md (#3) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 30fac53..b15b9b0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # TorchFix - a linter for PyTorch-using code with autofix support +[![PyPI](https://img.shields.io/pypi/v/torchfix.svg)](https://pypi.org/project/torchfix/) + TorchFix is a Python code static analysis tool - a linter with autofix capabilities - for users of PyTorch. It can be used to find and fix issues like usage of deprecated PyTorch functions and non-public symbols, and to adopt PyTorch best practices in general. From 03ea18c64e88ccf98f2f8c628eac7219252ddac1 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 26 Dec 2023 12:43:29 -0800 Subject: [PATCH 02/66] Add a rule for use_reentrant with checkpoint (#7) --- .../misc/checker/reentrant_checkpoint.py | 13 ++++++ .../misc/checker/reentrant_checkpoint.txt | 2 + .../misc/codemod/reentrant_checkpoint.py | 6 +++ .../misc/codemod/reentrant_checkpoint.py.out | 6 +++ torchfix/torchfix.py | 4 +- torchfix/visitors/misc/__init__.py | 41 +++++++++++++++++++ 6 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/misc/checker/reentrant_checkpoint.py create mode 100644 tests/fixtures/misc/checker/reentrant_checkpoint.txt create mode 100644 tests/fixtures/misc/codemod/reentrant_checkpoint.py create mode 100644 tests/fixtures/misc/codemod/reentrant_checkpoint.py.out diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.py b/tests/fixtures/misc/checker/reentrant_checkpoint.py new file mode 100644 index 0000000..938a41f --- /dev/null +++ b/tests/fixtures/misc/checker/reentrant_checkpoint.py @@ -0,0 +1,13 @@ +import torch +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + +import torch.utils.checkpoint +def fn(x, y): + return 
checkpoint(gn, torch.sin(x), y) + return checkpoint(gn, torch.sin(x), y, use_reentrant=False) + +from torch.utils.checkpoint import checkpoint +def fn(x, y): + return checkpoint(gn, torch.sin(x), y) + return checkpoint(gn, torch.sin(x), y, use_reentrant=True) diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.txt b/tests/fixtures/misc/checker/reentrant_checkpoint.txt new file mode 100644 index 0000000..af867d6 --- /dev/null +++ b/tests/fixtures/misc/checker/reentrant_checkpoint.txt @@ -0,0 +1,2 @@ +7:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`. +12:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`. diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py b/tests/fixtures/misc/codemod/reentrant_checkpoint.py new file mode 100644 index 0000000..3d0051d --- /dev/null +++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.py @@ -0,0 +1,6 @@ +import torch +from torch.utils.checkpoint import checkpoint +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) +def fn(x, y): + return checkpoint(gn, torch.sin(x), y) diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out b/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out new file mode 100644 index 0000000..57c69b7 --- /dev/null +++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out @@ -0,0 +1,6 @@ +import torch +from torch.utils.checkpoint import checkpoint +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) +def fn(x, y): + return checkpoint(gn, torch.sin(x), y, use_reentrant=False) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 1a47e20..c0e7da9 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,7 +11,8 @@ ) from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import 
TorchRequireGradVisitor +from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) + from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, @@ -33,6 +34,7 @@ def GET_ALL_VISITORS(): TorchVisionDeprecatedPretrainedVisitor(), TorchVisionDeprecatedToTensorVisitor(), TorchUnsafeLoadVisitor(), + TorchReentrantCheckpointVisitor(), ] diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ef83d91..6ce7c84 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -46,3 +46,44 @@ def visit_Assign(self, node): replacement=replacement, ) ) + + +class TorchReentrantCheckpointVisitor(TorchVisitor): + """ + Find and fix common misuse of reentrant checkpoints. + """ + + ERROR_CODE = "TOR003" + MESSAGE = ( + "Please pass `use_reentrant` explicitly to `checkpoint`. " + "To maintain old behavior, pass `use_reentrant=True`. " + "It is recommended to use `use_reentrant=False`." + ) + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name == "torch.utils.checkpoint.checkpoint": + use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1) + if use_reentrant_arg is None: + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + + # This codemod maybe unsafe correctness-wise + # if reentrant behavior is actually needed, + # so the changes need to be verified/tested. 
+ use_reentrant_arg = cst.ensure_type( + cst.parse_expression("f(use_reentrant=False)"), cst.Call + ).args[0] + replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) + + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=self.MESSAGE, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=replacement, + ) + ) From 3e15bd9c8d3202183adac588290125c016c084c5 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 11 Jan 2024 12:03:04 -0800 Subject: [PATCH 03/66] Add torch.nn.utils.weight_norm to deprecated symbols (#9) --- README.md | 22 +++++++++++++++++++ .../deprecated_symbols/checker/weight_norm.py | 2 ++ .../checker/weight_norm.txt | 1 + torchfix/deprecated_symbols.yaml | 5 +++++ 4 files changed, 30 insertions(+) create mode 100644 tests/fixtures/deprecated_symbols/checker/weight_norm.py create mode 100644 tests/fixtures/deprecated_symbols/checker/weight_norm.txt diff --git a/README.md b/README.md index b15b9b0..e443eb5 100644 --- a/README.md +++ b/README.md @@ -67,5 +67,27 @@ To get the LU factorization see `torch.lu`, which can be used with `torch.lu_sol `X = torch.solve(B, A).solution` should be replaced with `X = torch.linalg.solve(A, B)`. +### TOR101 Use of deprecated function + +#### torch.nn.utils.weight_norm + +This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` +which uses the modern parametrization API. The new ``weight_norm`` is compatible +with ``state_dict`` generated from old ``weight_norm``. + +Migration guide: + +* The magnitude (``weight_g``) and direction (``weight_v``) are now expressed + as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` + respectively. + +* To remove the weight normalization reparametrization, use + `torch.nn.utils.parametrize.remove_parametrizations`. + +* The weight is no longer recomputed once at module forward; instead, it will + be recomputed on every access. 
To restore the old behavior, use + `torch.nn.utils.parametrize.cached` before invoking the module + in question. + ## License TorchFix is BSD License licensed, as found in the LICENSE file. diff --git a/tests/fixtures/deprecated_symbols/checker/weight_norm.py b/tests/fixtures/deprecated_symbols/checker/weight_norm.py new file mode 100644 index 0000000..e4bb515 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/weight_norm.py @@ -0,0 +1,2 @@ +from torch import nn +m = nn.utils.weight_norm(nn.Linear(20, 40), name='weight') diff --git a/tests/fixtures/deprecated_symbols/checker/weight_norm.txt b/tests/fixtures/deprecated_symbols/checker/weight_norm.txt new file mode 100644 index 0000000..74e8c05 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/weight_norm.txt @@ -0,0 +1 @@ +2:5 TOR101 Use of deprecated function torch.nn.utils.weight_norm: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index 94d8b60..b2bf4c5 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -60,6 +60,11 @@ deprecate_pr: TBA remove_pr: +- name: torch.nn.utils.weight_norm + deprecate_pr: https://github.com/pytorch/pytorch/pull/103001 + remove_pr: + reference: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm + # functorch - name: functorch.vmap deprecate_pr: TBA From f4689df843894cc4741ed350a10e3f7efd3c1b73 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 18 Jan 2024 15:18:08 -0800 Subject: [PATCH 04/66] Update README.md to include info on all TOR0 rules (#10) * Update README.md to include info on all TOR0 rules * Address review comments --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index e443eb5..9c46955 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,18 @@ To get the LU factorization see `torch.lu`, which can be used with `torch.lu_sol `X = torch.solve(B, 
A).solution` should be replaced with `X = torch.linalg.solve(A, B)`. +### TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`? + +This is a common misspelling that can lead to silent performance issues. + +### TOR003 Please pass `use_reentrant` explicitly to `checkpoint` + +The default value of the `use_reentrant` parameter in `torch.utils.checkpoint` is being changed +from `True` to `False`. In the meantime, the value needs to be passed explicitly. + +See this [forum post](https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745) +for details. + ### TOR101 Use of deprecated function #### torch.nn.utils.weight_norm From 9872f7f83b3e976b4fdbe410679c7a53ced54f03 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 18 Jan 2024 15:44:34 -0800 Subject: [PATCH 05/66] Bump version to 0.3.0 (#11) Preparing 0.3.0 release. - Added rule TOR003 about explicitly passing `use_reentrant` explicitly to `torch.utils.checkpoint` - Added `torch.nn.utils.weight_norm` to the list of deprecated functions flagged by TOR101 - Updated README with TOR0 rules description --- torchfix/torchfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index c0e7da9..989b44e 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -19,7 +19,7 @@ ) from .visitors.security import TorchUnsafeLoadVisitor -__version__ = "0.2.1" +__version__ = "0.3.0" DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" From 9e2c8e6bf434c1cd008702ec1160a845b9f45eeb Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 24 Jan 2024 17:06:40 -0800 Subject: [PATCH 06/66] Resolve TODO after LibCST PR 994 (#12) --- torchfix/__main__.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 0c3a823..80fb800 100644 --- a/torchfix/__main__.py 
+++ b/torchfix/__main__.py @@ -4,7 +4,6 @@ import contextlib import sys import io -import os from .torchfix import TorchCodemod, TorchCodemodConfig from .common import CYAN, ENDC @@ -55,13 +54,11 @@ def main() -> None: MARKER = "torch" # this will catch import torch or functorch torch_files = [] for file in files: - # TODO: remove the check when https://github.com/Instagram/LibCST/pull/994 lands - if os.path.isfile(file): # `codemod.gather_files` can return dirs with ".py" - with open(file, errors="replace") as f: - for line in f: - if MARKER in line: - torch_files.append(file) - break + with open(file, errors="replace") as f: + for line in f: + if MARKER in line: + torch_files.append(file) + break config = TorchCodemodConfig() config.select = args.select From a776f3e3f192f4e4f43db401a0ecc238a8ac3d27 Mon Sep 17 00:00:00 2001 From: Eli Uriegas <1700823+seemethere@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:27:24 -0800 Subject: [PATCH 07/66] add torchfix --version (#14) * add torchfix --version Also changes the default command to not require an argument and to default to your current working directory Signed-off-by: Eli Uriegas * get rid of default for path, print usage if left blank Signed-off-by: Eli Uriegas --------- Signed-off-by: Eli Uriegas --- torchfix/__main__.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 80fb800..584104b 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -5,7 +5,7 @@ import sys import io -from .torchfix import TorchCodemod, TorchCodemodConfig +from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion from .common import CYAN, ENDC @@ -14,7 +14,7 @@ def main() -> None: parser.add_argument( "path", - nargs="+", + nargs="*", help=("Path to check/fix. 
Can be a directory, a file, or multiple of either."), ) parser.add_argument( @@ -36,6 +36,11 @@ def main() -> None: "ALL", ], ) + parser.add_argument( + "--version", + action="store_true", + help="Print current version.", + ) # XXX TODO: Get rid of this! # Silence "Failed to determine module name" @@ -47,6 +52,15 @@ def main() -> None: args = parser.parse_args() + if args.version: + # TODO: Perhaps add commit hash here if we can + print(f"{TorchFixVersion}") + sys.exit(0) + + if not args.path: + parser.print_usage() + sys.exit(1) + files = codemod.gather_files(args.path) # Filter out files that don't have "torch" string in them. From a8a69e9b0807b80357bc2058dd5b2b81577ed61d Mon Sep 17 00:00:00 2001 From: Eli Uriegas <1700823+seemethere@users.noreply.github.com> Date: Fri, 26 Jan 2024 19:16:46 -0800 Subject: [PATCH 08/66] use built-in version flag, revert to default behavior (#15) Signed-off-by: Eli Uriegas --- torchfix/__main__.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 584104b..af94868 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -14,7 +14,7 @@ def main() -> None: parser.add_argument( "path", - nargs="*", + nargs="+", help=("Path to check/fix. Can be a directory, a file, or multiple of either."), ) parser.add_argument( @@ -38,8 +38,8 @@ def main() -> None: ) parser.add_argument( "--version", - action="store_true", - help="Print current version.", + action="version", + version=f"{TorchFixVersion}" ) # XXX TODO: Get rid of this! 
@@ -52,11 +52,6 @@ def main() -> None: args = parser.parse_args() - if args.version: - # TODO: Perhaps add commit hash here if we can - print(f"{TorchFixVersion}") - sys.exit(0) - if not args.path: parser.print_usage() sys.exit(1) From a428069616c1ed1469cef887b2970c4306ee75ff Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:53:25 -0800 Subject: [PATCH 09/66] Update mypy to 1.7.0 (#18) To be in sync with PyTorch --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 225a8fe..134840c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,4 @@ flake8==6.0.0 pytest==7.2.0 libcst==1.1.0 types-PyYAML==6.0.7 -mypy==1.4.1 +mypy==1.7.0 From 590d51511567ccfae3df30c6165aad23e3639e08 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Wed, 31 Jan 2024 13:22:11 -0800 Subject: [PATCH 10/66] Use `dup2` to redirect stderr to null on MacOS (#19) * Use `dup2` to redirect stderr to null on MacOS Not sure why `contextlib.redirect_stderr(io.StringIO())` does not work, but `dup2(open("/dev/null", O_WRONLY), 2)` do work on MacOS (and on Linux as well) * Rewrite using contextmanager * Make mypy happy --- torchfix/__main__.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index af94868..a2f0130 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -2,6 +2,7 @@ import libcst.codemod as codemod import contextlib +import ctypes import sys import io @@ -9,6 +10,29 @@ from .common import CYAN, ENDC +# Should get rid of this code eventually. 
+@contextlib.contextmanager +def StderrSilencer(redirect: bool = True): + if not redirect: + yield + elif sys.platform != "darwin": + with contextlib.redirect_stderr(io.StringIO()): + yield + else: + # redirect_stderr does not work for some reason + # Workaround it by using good old dup2 to redirect + # stderr to /dev/null + libc = ctypes.CDLL("libc.dylib") + orig_stderr = libc.dup(2) + with open("/dev/null", "w") as f: + libc.dup2(f.fileno(), 2) + try: + yield + finally: + libc.dup2(orig_stderr, 2) + libc.close(orig_stderr) + + def main() -> None: parser = argparse.ArgumentParser() @@ -74,12 +98,7 @@ def main() -> None: command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: - if not args.show_stderr: - context = contextlib.redirect_stderr(io.StringIO()) - else: - # Should get rid of this code eventually. - context = contextlib.nullcontext() # type: ignore - with context: + with StderrSilencer(not args.show_stderr): result = codemod.parallel_exec_transform_with_prettyprint( command_instance, torch_files, From 92136feac04a2f383d9c459afc02e1a2964ccd53 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 31 Jan 2024 15:32:36 -0800 Subject: [PATCH 11/66] Require at least Python 3.9 (#20) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 244c9aa..87bbb70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [project] name = "TorchFix" +requires-python = ">=3.9" description = "TorchFix - a linter for PyTorch-using code with autofix support" readme = "README.md" license = {file = "LICENSE"} From f219838885ca7612fdaaa2bfe8809afe9843b16c Mon Sep 17 00:00:00 2001 From: Jeffrey Wan Date: Wed, 31 Jan 2024 19:45:05 -0500 Subject: [PATCH 12/66] Update --select arg to accept specific rules (#16) --- tests/test_torchfix.py | 20 ++++- torchfix/__main__.py | 20 +++-- torchfix/torchfix.py | 123 ++++++++++++++++++++++---- torchfix/visitors/vision/to_tensor.py | 2 +- 4 files 
changed, 139 insertions(+), 26 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index b9f0be0..cd9b74c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -3,7 +3,11 @@ TorchChecker, TorchCodemod, TorchCodemodConfig, + DISABLED_BY_DEFAULT, + expand_error_codes, GET_ALL_VISITORS, + GET_ALL_ERROR_CODES, + process_error_code_str, ) import logging import libcst.codemod as codemod @@ -20,7 +24,7 @@ def _checker_results(s): def _codemod_results(source_path): with open(source_path) as source: code = source.read() - config = TorchCodemodConfig(select="ALL") + config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) return new_module.code @@ -60,3 +64,17 @@ def test_errorcodes_distinct(): for e in error_code if isinstance(error_code, list) else [error_code]: assert e not in seen seen.add(e) + + +def test_parse_error_code_str(): + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", GET_ALL_ERROR_CODES()), + ("ALL,TOR102", GET_ALL_ERROR_CODES()), + ("TOR102", {"TOR102"}), + ("TOR102,TOR101", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101"}), + (None, GET_ALL_ERROR_CODES() - exclude_set), + ] + for case, expected in cases: + assert expected == process_error_code_str(case) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index a2f0130..5df0cf9 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -6,7 +6,14 @@ import sys import io -from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion +from .torchfix import ( + TorchCodemod, + TorchCodemodConfig, + __version__ as TorchFixVersion, + DISABLED_BY_DEFAULT, + GET_ALL_ERROR_CODES, + process_error_code_str, +) from .common import CYAN, ENDC @@ -55,10 +62,11 @@ def main() -> None: ) parser.add_argument( "--select", - help="ALL to enable rules disabled by default", - 
choices=[ - "ALL", - ], + help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " + f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + type=str, + default=None, ) parser.add_argument( "--version", @@ -94,7 +102,7 @@ def main() -> None: break config = TorchCodemodConfig() - config.select = args.select + config.select = list(process_error_code_str(args.select)) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 989b44e..d1d648d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +import functools from pathlib import Path -from typing import Optional +from typing import Optional, List import libcst as cst import libcst.codemod as codemod @@ -25,17 +26,100 @@ DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +ALL_VISITOR_CLS = [ + TorchDeprecatedSymbolsVisitor, + TorchRequireGradVisitor, + TorchSynchronizedDataLoaderVisitor, + TorchVisionDeprecatedPretrainedVisitor, + TorchVisionDeprecatedToTensorVisitor, + TorchUnsafeLoadVisitor, + TorchReentrantCheckpointVisitor, +] + + +@functools.cache +def GET_ALL_ERROR_CODES(): + codes = set() + for cls in ALL_VISITOR_CLS: + if isinstance(cls.ERROR_CODE, list): + codes |= set(cls.ERROR_CODE) + else: + codes.add(cls.ERROR_CODE) + return codes + + +@functools.cache +def expand_error_codes(codes): + out_codes = set() + for c_a in codes: + for c_b in GET_ALL_ERROR_CODES(): + if c_b.startswith(c_a): + out_codes.add(c_b) + return out_codes + + +def construct_visitor(cls): + if cls is TorchDeprecatedSymbolsVisitor: + return cls(DEPRECATED_CONFIG_PATH) + else: + return cls() + def GET_ALL_VISITORS(): - return [ - TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH), - TorchRequireGradVisitor(), - TorchSynchronizedDataLoaderVisitor(), - TorchVisionDeprecatedPretrainedVisitor(), - 
TorchVisionDeprecatedToTensorVisitor(), - TorchUnsafeLoadVisitor(), - TorchReentrantCheckpointVisitor(), - ] + out = [] + for v in ALL_VISITOR_CLS: + out.append(construct_visitor(v)) + return out + + +def get_visitors_with_error_codes(error_codes): + visitor_classes = set() + for error_code in error_codes: + # Assume the error codes have been expanded so each error code can + # only correspond to one visitor. + found = False + for visitor_cls in ALL_VISITOR_CLS: + if isinstance(visitor_cls.ERROR_CODE, list): + if error_code in visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + else: + if error_code == visitor_cls.ERROR_CODE: + visitor_classes.add(visitor_cls) + found = True + break + if not found: + raise AssertionError(f"Unknown error code: {error_code}") + out = [] + for cls in visitor_classes: + out.append(construct_visitor(cls)) + return out + + +def process_error_code_str(code_str): + # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. + # We deduplicate them here. + + # Default when --select is not provided. 
+ if code_str is None: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + return GET_ALL_ERROR_CODES() - exclude_set + + raw_codes = [s.strip() for s in code_str.split(",")] + + # Validate error codes + for c in raw_codes: + if c == "ALL": + continue + if len(expand_error_codes((c,))) == 0: + raise ValueError(f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}") + + if "ALL" in raw_codes: + return GET_ALL_ERROR_CODES() + + return expand_error_codes(tuple(raw_codes)) # Flake8 plugin @@ -78,7 +162,7 @@ def add_options(optmanager): # Standalone torchfix command @dataclass class TorchCodemodConfig: - select: Optional[str] = None + select: Optional[List[str]] = None class TorchCodemod(codemod.Codemod): @@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: # in that case we would need to use `wrapped_module.module` # instead of `module`. wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True) + if self.config is None or self.config.select is None: + raise AssertionError("Expected self.config.select to be set") + visitors = get_visitors_with_error_codes(self.config.select) - visitors = GET_ALL_VISITORS() violations = [] needed_imports = [] wrapped_module.visit_batched(visitors) @@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: replacement_map = {} assert self.context.filename is not None for violation in violations: - skip_violation = False - if self.config is None or self.config.select != "ALL": - for disabled_code in DISABLED_BY_DEFAULT: - if violation.error_code.startswith(disabled_code): - skip_violation = True - break + # Still need to skip violations here, since a single visitor can + # correspond to multiple different types of violations. 
+ skip_violation = True + for code in self.config.select: + if violation.error_code.startswith(code): + skip_violation = False + break if skip_violation: continue diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 6886c41..ab15827 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -43,7 +43,7 @@ def visit_ImportFrom(self, node): def visit_Attribute(self, node): qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node) - if not len(qualified_names) == 1: + if len(qualified_names) != 1: return self._maybe_add_violation(qualified_names.pop().name, node) From bfe27bc9281604f9bfbfeb13c83ebd07f1aa7f12 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 5 Feb 2024 10:36:19 -0800 Subject: [PATCH 13/66] TorchScopedLibraryVisitor (#22) --- .../internal/checker/scoped_library.py | 3 ++ .../internal/checker/scoped_library.txt | 1 + torchfix/torchfix.py | 5 ++- torchfix/visitors/internal/__init__.py | 33 +++++++++++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/internal/checker/scoped_library.py create mode 100644 tests/fixtures/internal/checker/scoped_library.txt create mode 100644 torchfix/visitors/internal/__init__.py diff --git a/tests/fixtures/internal/checker/scoped_library.py b/tests/fixtures/internal/checker/scoped_library.py new file mode 100644 index 0000000..54d5316 --- /dev/null +++ b/tests/fixtures/internal/checker/scoped_library.py @@ -0,0 +1,3 @@ +import torch +from torch.library import Library, impl, fallthrough_kernel +my_lib1 = Library("aten", "IMPL") diff --git a/tests/fixtures/internal/checker/scoped_library.txt b/tests/fixtures/internal/checker/scoped_library.txt new file mode 100644 index 0000000..1f1e7f8 --- /dev/null +++ b/tests/fixtures/internal/checker/scoped_library.txt @@ -0,0 +1 @@ +3:11 TOR901 Use `torch.library._scoped_library` instead of `torch.library.Library` in PyTorch tests files. 
See https://github.com/pytorch/pytorch/pull/118318 for details. diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index d1d648d..e6d01d1 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,6 +11,8 @@ _UpdateFunctorchImports, ) +from .visitors.internal import TorchScopedLibraryVisitor + from .visitors.performance import TorchSynchronizedDataLoaderVisitor from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) @@ -24,11 +26,12 @@ DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" -DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"] ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, TorchRequireGradVisitor, + TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py new file mode 100644 index 0000000..424e1f2 --- /dev/null +++ b/torchfix/visitors/internal/__init__.py @@ -0,0 +1,33 @@ +import libcst as cst +from ...common import TorchVisitor, LintViolation + + +class TorchScopedLibraryVisitor(TorchVisitor): + """ + Suggest `torch.library._scoped_library` for PyTorch tests. + """ + + ERROR_CODE = "TOR901" + MESSAGE = ( + "Use `torch.library._scoped_library` instead of `torch.library.Library` " + "in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 " + "for details." 
+ ) + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name == "torch.library.Library": + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=self.MESSAGE, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=None, + ) + ) From 428b1e1746d4ce3b0df299c37ee1fa350391b9e5 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 6 Feb 2024 10:11:18 -0800 Subject: [PATCH 14/66] Fix bug with function name replacement (#23) --- .../deprecated_symbols/codemod/ger-outer.py | 3 ++ .../codemod/ger-outer.py.out | 3 ++ torchfix/common.py | 34 ++++++++++++++----- .../visitors/deprecated_symbols/__init__.py | 6 ++-- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py index c5e64c4..6fce087 100644 --- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py +++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py @@ -1,6 +1,9 @@ import torch +from torch import ger deprecated = torch.norm() sinusoid_inp = torch.ger(pos_seq, inv_freq) other = something.ger(pos_seq, inv_freq) deprecated = torch.norm() one_more = torch.ger(pos_seq, inv_freq) + +just_name = ger(pos_seq, inv_freq) diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out index 45f3d84..3303ed0 100644 --- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out +++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out @@ -1,6 +1,9 @@ import torch +from torch import outer, ger deprecated = torch.norm() sinusoid_inp = torch.outer(pos_seq, inv_freq) other = something.ger(pos_seq, inv_freq) deprecated = torch.norm() one_more = torch.outer(pos_seq, inv_freq) + +just_name = 
outer(pos_seq, inv_freq) diff --git a/torchfix/common.py b/torchfix/common.py index 52f2f52..b302346 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider from libcst.codemod.visitors import ImportItem -from typing import Optional, List, Set, Union +from typing import Optional, List, Set, Tuple, Union from abc import ABC IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() @@ -83,19 +83,34 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: def call_with_name_changes( node: cst.Call, old_qualified_name: str, new_qualified_name: str -) -> Optional[cst.Call]: +) -> Optional[Tuple[cst.Call, Set[ImportItem]]]: """ - Return new `Call` node with name changes. + Return an optional tuple: + new `Call` node with name changes + and a set of newly needed imports. """ old_begin, _, old_last = old_qualified_name.rpartition(".") new_begin, _, new_last = new_qualified_name.rpartition(".") + needed_imports: Set[ImportItem] = set() # If the only difference is the last name part. if old_begin == new_begin: - replacement = node.with_deep_changes( - old_node=cst.ensure_type(node.func, cst.Attribute).attr, - value=new_last, - ) + if isinstance(node.func, cst.Attribute): + replacement = node.with_deep_changes( + old_node=node.func.attr, + value=new_last, + ) + elif isinstance(node.func, cst.Name): + replacement = node.with_deep_changes( + old_node=node.func, + value=new_last, + ) + needed_imports.add( + ImportItem( + module_name=new_begin, + obj_name=new_last, + ) + ) # If the last name part is the same and # originally called without a dot: don't change the call site, @@ -106,7 +121,10 @@ def call_with_name_changes( # Replace with new_qualified_name. 
else: replacement = node.with_changes(func=cst.parse_expression(new_qualified_name)) - return replacement + if replacement is None: + return None + else: + return replacement, needed_imports def deep_multi_replace(tree, replacement_map): diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index fed7032..93a9082 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -49,10 +49,12 @@ def _call_replacement( qualified_name, {} ).get("replacement", "") if function_name_replacement: - replacement = call_with_name_changes( + replacement_and_imports = call_with_name_changes( node, qualified_name, function_name_replacement ) - + if replacement_and_imports is not None: + replacement, imports = replacement_and_imports + self.needed_imports.update(imports) return replacement def visit_Call(self, node): From 35f24881685df0345ac92062027ae88ff6d152de Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 6 Feb 2024 23:04:03 -0800 Subject: [PATCH 15/66] Bump version to 0.4.0 (#25) Preparing 0.4.0 release. 
--- torchfix/torchfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index e6d01d1..a26097a 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -22,7 +22,7 @@ ) from .visitors.security import TorchUnsafeLoadVisitor -__version__ = "0.3.0" +__version__ = "0.4.0" DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" From 8846f5c303fc913604ad3ec678eac45e6e37d602 Mon Sep 17 00:00:00 2001 From: Suwen Ge Date: Tue, 5 Mar 2024 15:13:37 -0800 Subject: [PATCH 16/66] [Issue 7] Update import torchvision.models as models (#26) * [Issue 7] Update import torchvision.models as models * Move torchvision.models visitor to vision dir * Move torchvision.models visitor to vision dir --- .../fixtures/vision/checker/models_import.py | 5 +++ .../fixtures/vision/checker/models_import.txt | 1 + torchfix/torchfix.py | 2 + torchfix/visitors/vision/__init__.py | 1 + torchfix/visitors/vision/models_import.py | 40 +++++++++++++++++++ 5 files changed, 49 insertions(+) create mode 100644 tests/fixtures/vision/checker/models_import.py create mode 100644 tests/fixtures/vision/checker/models_import.txt create mode 100644 torchfix/visitors/vision/models_import.py diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/models_import.py new file mode 100644 index 0000000..8eae98e --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.py @@ -0,0 +1,5 @@ +import torchvision.models as models +import torchvision.models as cnn +from torchvision.models import resnet50, resnet101 +import torchvision.models +from torchvision.models import * diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt new file mode 100644 index 0000000..864cf35 --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.txt @@ -0,0 +1 @@ +1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from 
torchvision import models'. diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a26097a..a38d81d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -19,6 +19,7 @@ from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, ) from .visitors.security import TorchUnsafeLoadVisitor @@ -35,6 +36,7 @@ TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, ] diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py index 7adcc19..9bc944e 100644 --- a/torchfix/visitors/vision/__init__.py +++ b/torchfix/visitors/vision/__init__.py @@ -1,2 +1,3 @@ from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 +from .models_import import TorchVisionModelsImportVisitor # noqa: F401 diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py new file mode 100644 index 0000000..ba5a325 --- /dev/null +++ b/torchfix/visitors/vision/models_import.py @@ -0,0 +1,40 @@ +import libcst as cst + +from ...common import LintViolation, TorchVisitor + + +class TorchVisionModelsImportVisitor(TorchVisitor): + ERROR_CODE = "TOR203" + + def visit_Import(self, node: cst.Import) -> None: + for imported_item in node.names: + if isinstance(imported_item.name, cst.Attribute): + if ( + isinstance(imported_item.name.value, cst.Name) + and imported_item.name.value.value == "torchvision" + and isinstance(imported_item.name.attr, cst.Name) + and imported_item.name.attr.value == "models" + and imported_item.asname is not None + and isinstance(imported_item.asname.name, cst.Name) + and imported_item.asname.name.value == "models" + ): + position = self.get_metadata( + 
cst.metadata.WhitespaceInclusivePositionProvider, node + ) + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], + ) + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=( + "Consider replacing 'import torchvision.models as" + " models' with 'from torchvision import models'." + ), + line=position.start.line, + column=position.start.column, + node=node, + replacement=replacement + ) + ) From 890b5cc5dd4a6e70b06d969c6d9e94799cd69ad0 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 11 Mar 2024 14:41:31 -0700 Subject: [PATCH 17/66] Raise original exception when test hits codemod.TransformFailure (#28) * Raise original exception when test hits codemod.TransformFailure Otherwise it's very hard to tell why the test failed. --- tests/test_torchfix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index cd9b74c..6e7e0c6 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -27,6 +27,8 @@ def _codemod_results(source_path): config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) + if isinstance(new_module, codemod.TransformFailure): + raise new_module.error return new_module.code From 5e51d0192007561f05f29eab938a75f3935ac10d Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 11 Mar 2024 21:02:34 -0700 Subject: [PATCH 18/66] Fix for TorchVisionModelsImportVisitor (#29) * Fix for TorchVisionModelsImportVisitor --- .../fixtures/vision/checker/models_import.py | 1 + .../fixtures/vision/checker/models_import.txt | 1 + .../fixtures/vision/codemod/models_import.py | 5 +++++ .../vision/codemod/models_import.py.out | 5 +++++ torchfix/visitors/vision/models_import.py | 22 ++++++++++++------- 5 files changed, 26 insertions(+), 8 deletions(-) create mode 100644 
tests/fixtures/vision/codemod/models_import.py create mode 100644 tests/fixtures/vision/codemod/models_import.py.out diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/models_import.py index 8eae98e..3a16490 100644 --- a/tests/fixtures/vision/checker/models_import.py +++ b/tests/fixtures/vision/checker/models_import.py @@ -3,3 +3,4 @@ from torchvision.models import resnet50, resnet101 import torchvision.models from torchvision.models import * +import torchvision.models as models, torch diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt index 864cf35..7a517da 100644 --- a/tests/fixtures/vision/checker/models_import.txt +++ b/tests/fixtures/vision/checker/models_import.txt @@ -1 +1,2 @@ 1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. +6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. 
diff --git a/tests/fixtures/vision/codemod/models_import.py b/tests/fixtures/vision/codemod/models_import.py new file mode 100644 index 0000000..6b75141 --- /dev/null +++ b/tests/fixtures/vision/codemod/models_import.py @@ -0,0 +1,5 @@ +import torchvision.models as models +import torchvision.models as cnn + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/tests/fixtures/vision/codemod/models_import.py.out b/tests/fixtures/vision/codemod/models_import.py.out new file mode 100644 index 0000000..53269c1 --- /dev/null +++ b/tests/fixtures/vision/codemod/models_import.py.out @@ -0,0 +1,5 @@ +from torchvision import models +import torchvision.models as cnn + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index ba5a325..7ccbebb 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -5,10 +5,16 @@ class TorchVisionModelsImportVisitor(TorchVisitor): ERROR_CODE = "TOR203" + MESSAGE = ( + "Consider replacing 'import torchvision.models as models' " + "with 'from torchvision import models'." 
+ ) def visit_Import(self, node: cst.Import) -> None: + replacement = None for imported_item in node.names: if isinstance(imported_item.name, cst.Attribute): + # TODO refactor using libcst.matchers.matches if ( isinstance(imported_item.name.value, cst.Name) and imported_item.name.value.value == "torchvision" @@ -21,20 +27,20 @@ def visit_Import(self, node: cst.Import) -> None: position = self.get_metadata( cst.metadata.WhitespaceInclusivePositionProvider, node ) - replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], + ) self.violations.append( LintViolation( error_code=self.ERROR_CODE, - message=( - "Consider replacing 'import torchvision.models as" - " models' with 'from torchvision import models'." - ), + message=self.MESSAGE, line=position.start.line, column=position.start.column, node=node, replacement=replacement ) ) + break From 40e021a4e7048267c33d0d5ba404ca1287c5e148 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 13 Mar 2024 11:07:22 -0700 Subject: [PATCH 19/66] Use Alerts markdown in Readme (#30) --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9c46955..1b825e1 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ reporting issues. TorchFix can be used as a Flake8 plugin (linting only) or as a standalone program (with autofix available for a subset of the lint violations). -Currently TorchFix is in a **beta version** stage, so there are still a lot of rough +> [!WARNING] +> Currently TorchFix is in a **beta version** stage, so there are still a lot of rough edges and many things can and will change. 
## Installation @@ -36,7 +37,8 @@ TorchFix can also be run as a standalone program: `torchfix .` Add `--fix` parameter to try to autofix some of the issues (the files will be overwritten!) To see some additional debug info, add `--show-stderr` parameter. -Please keep in mind that autofix is a best-effort mechanism. Given the dynamic nature of Python, +> [!CAUTION] +> Please keep in mind that autofix is a best-effort mechanism. Given the dynamic nature of Python, and especially the beta version status of TorchFix, it's very difficult to have certainty when making changes to code, even for the seemingly trivial fixes. @@ -83,9 +85,9 @@ for details. #### torch.nn.utils.weight_norm -This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` -which uses the modern parametrization API. The new ``weight_norm`` is compatible -with ``state_dict`` generated from old ``weight_norm``. +This function is deprecated. Use `torch.nn.utils.parametrizations.weight_norm` +which uses the modern parametrization API. The new `weight_norm` is compatible +with `state_dict` generated from old `weight_norm`. Migration guide: From 42bb4a8e3e233ed0d9394526af55474fae4803ae Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 14 Mar 2024 18:21:05 -0700 Subject: [PATCH 20/66] Add TorchNonPublicAliasVisitor (#33) Add a rule to suggest using public aliases for non-public functions that have one. This PR implements a lint-only rule with a suggestion for `default_collate` and `default_convert`. I will add an actual codemod for this in a different PR. Also this PR introduces a utility method `add_violation` to reduce code duplication. I'll refactor other rules to use `add_violation` in a different PR.
--- .../checker/default_collate_convert.py | 11 ++++ .../checker/default_collate_convert.txt | 8 +++ tests/test_torchfix.py | 2 +- torchfix/common.py | 21 ++++++++ torchfix/torchfix.py | 2 + torchfix/visitors/nonpublic/__init__.py | 51 +++++++++++++++++++ 6 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/nonpublic/checker/default_collate_convert.py create mode 100644 tests/fixtures/nonpublic/checker/default_collate_convert.txt create mode 100644 torchfix/visitors/nonpublic/__init__.py diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.py b/tests/fixtures/nonpublic/checker/default_collate_convert.py new file mode 100644 index 0000000..c7b7c65 --- /dev/null +++ b/tests/fixtures/nonpublic/checker/default_collate_convert.py @@ -0,0 +1,11 @@ +from torch.utils.data import _utils +batch = _utils.collate.default_collate(batch) + +from torch.utils.data._utils.collate import default_collate +inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx) + +from torch.utils.data._utils.collate import default_convert +values = default_convert(values) + +import torch +values = torch.utils.data._utils.collate.default_convert(values) diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.txt b/tests/fixtures/nonpublic/checker/default_collate_convert.txt new file mode 100644 index 0000000..edfc9a9 --- /dev/null +++ b/tests/fixtures/nonpublic/checker/default_collate_convert.txt @@ -0,0 +1,8 @@ +2:9 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +4:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:29 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:54 TOR104 Use of 
non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:79 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +7:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead +8:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead +11:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 6e7e0c6..7b9a051 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -75,7 +75,7 @@ def test_parse_error_code_str(): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}), (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: diff --git a/torchfix/common.py b/torchfix/common.py index b302346..7fdd00a 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -61,6 +61,27 @@ def get_specific_arg( return arg return None + def add_violation( + self, + node: cst.CSTNode, + error_code: str, + message: str, + replacement: Optional[cst.CSTNode] = None, + ) -> None: + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + self.violations.append( + LintViolation( + error_code=error_code, + message=message, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=replacement, + ) + ) + def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: # Guard against 
situations like `vmap(a)(b)`: # diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a38d81d..21a8994 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -15,6 +15,7 @@ from .visitors.performance import TorchSynchronizedDataLoaderVisitor from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) +from .visitors.nonpublic import TorchNonPublicAliasVisitor from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -39,6 +40,7 @@ TorchVisionModelsImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, + TorchNonPublicAliasVisitor, ] diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py new file mode 100644 index 0000000..0b8318f --- /dev/null +++ b/torchfix/visitors/nonpublic/__init__.py @@ -0,0 +1,51 @@ +from typing import Sequence + +import libcst as cst +from ...common import TorchVisitor + + +class TorchNonPublicAliasVisitor(TorchVisitor): + """ + Suggest to use public APIs instead of non-public aliases. 
+ + Currently implemented for + torch.utils.data._utils.collate.default_collate and + torch.utils.data._utils.collate.default_convert, + see https://github.com/pytorch/pytorch/pull/69862/files + """ + + ERROR_CODE = ["TOR104", "TOR105"] + + # fmt: off + ALIASES = { + "torch.utils.data._utils.collate.default_collate": "torch.utils.data.dataloader.default_collate", # noqa: E501 + "torch.utils.data._utils.collate.default_convert": "torch.utils.data.dataloader.default_convert", # noqa: E501 + } + # fmt: on + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name is None: + return + + if qualified_name in self.ALIASES: + public_name = self.ALIASES[qualified_name] + error_code = self.ERROR_CODE[0] + message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + self.add_violation(node, error_code=error_code, message=message) + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + if node.module is None: + return + + module = cst.helpers.get_full_name_for_node(node.module) + if not isinstance(node.names, Sequence): + return + + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + if qualified_name in self.ALIASES: + public_name = self.ALIASES[qualified_name] + error_code = self.ERROR_CODE[1] + message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + self.add_violation(node, error_code=error_code, message=message) From e7826707cb7718f3c7d29219fda063aa4f5c1ad6 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 14 Mar 2024 19:20:13 -0700 Subject: [PATCH 21/66] Add rules for importing deprecated and removed symbols (#32) * Add rules for importing deprecated symbols * Add test for functorch * Appease mypy --- .../checker/deprecated_qr.txt | 2 + .../deprecated_symbols/checker/functorch.py | 2 + .../deprecated_symbols/checker/functorch.txt | 2 + .../checker/removed_symeig.txt | 1 + 
tests/test_torchfix.py | 2 +- .../visitors/deprecated_symbols/__init__.py | 40 ++++++++++++++++++- 6 files changed, 46 insertions(+), 3 deletions(-) diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt index 9768b1f..3da6ee5 100644 --- a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt +++ b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt @@ -1,4 +1,6 @@ 2:7 TOR101 Use of deprecated function torch.qr 6:7 TOR101 Use of deprecated function torch.qr +9:1 TOR103 Import of deprecated function torch.qr 10:7 TOR101 Use of deprecated function torch.qr +13:1 TOR103 Import of deprecated function torch.qr 16:7 TOR101 Use of deprecated function torch.qr diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.py b/tests/fixtures/deprecated_symbols/checker/functorch.py index f072cbb..044240f 100644 --- a/tests/fixtures/deprecated_symbols/checker/functorch.py +++ b/tests/fixtures/deprecated_symbols/checker/functorch.py @@ -2,3 +2,5 @@ # Check that we get only one warning for the line functorch.vmap(tdmodule, (None, 0))(td, params) + +from functorch import vmap, jacrev diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.txt b/tests/fixtures/deprecated_symbols/checker/functorch.txt index 336c7ae..e4f802c 100644 --- a/tests/fixtures/deprecated_symbols/checker/functorch.txt +++ b/tests/fixtures/deprecated_symbols/checker/functorch.txt @@ -1 +1,3 @@ 4:1 TOR101 Use of deprecated function functorch.vmap +6:1 TOR103 Import of deprecated function functorch.vmap +6:1 TOR103 Import of deprecated function functorch.jacrev diff --git a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt index 06b13e7..7610c28 100644 --- a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt +++ b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt @@ -1,2 +1,3 @@ +2:1 TOR004 Import of removed 
function torch.symeig 4:8 TOR001 Use of removed function torch.symeig 5:8 TOR001 Use of removed function torch.symeig diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 7b9a051..d699ea7 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -75,7 +75,7 @@ def test_parse_error_code_str(): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 93a9082..beb0657 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -16,7 +16,7 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): - ERROR_CODE = ["TOR001", "TOR101"] + ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"] def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): @@ -57,7 +57,43 @@ def _call_replacement( self.needed_imports.update(imports) return replacement - def visit_Call(self, node): + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + if node.module is None: + return + + module = cst.helpers.get_full_name_for_node(node.module) + if isinstance(node.names, Sequence): + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + if qualified_name in self.deprecated_config: + if self.deprecated_config[qualified_name]["remove_pr"] is None: + error_code = self.ERROR_CODE[3] + message = f"Import of deprecated function {qualified_name}" + else: + error_code = self.ERROR_CODE[2] + message = f"Import of removed function {qualified_name}" + + reference = 
self.deprecated_config[qualified_name].get( + "reference" + ) + if reference is not None: + message = f"{message}: {reference}" + + self.violations.append( + LintViolation( + error_code=error_code, + message=message, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=None, + ) + ) + + def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) if qualified_name is None: return From 9b5adefdf46069d3da617502bc17c006f3eb5880 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 15 Mar 2024 14:28:46 -0700 Subject: [PATCH 22/66] Refactor to use add_violation (#34) --- .../visitors/deprecated_symbols/__init__.py | 33 ++------------- .../deprecated_symbols/chain_matmul.py | 2 +- .../visitors/deprecated_symbols/cholesky.py | 2 +- torchfix/visitors/deprecated_symbols/qr.py | 2 +- torchfix/visitors/deprecated_symbols/range.py | 1 + torchfix/visitors/internal/__init__.py | 18 +-------- torchfix/visitors/misc/__init__.py | 40 +++++-------------- torchfix/visitors/performance/__init__.py | 18 ++------- torchfix/visitors/security/__init__.py | 21 +++------- torchfix/visitors/vision/models_import.py | 23 ++++------- torchfix/visitors/vision/pretrained.py | 19 +++------ torchfix/visitors/vision/to_tensor.py | 25 ++++-------- 12 files changed, 51 insertions(+), 153 deletions(-) diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index beb0657..3450777 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -6,7 +6,6 @@ from ...common import ( TorchVisitor, call_with_name_changes, - LintViolation, ) from .range import call_replacement_range @@ -65,9 +64,6 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: if isinstance(node.names, Sequence): for name in node.names: qualified_name = f"{module}.{name.name.value}" - position_metadata = self.get_metadata( - 
cst.metadata.WhitespaceInclusivePositionProvider, node - ) if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERROR_CODE[3] @@ -76,22 +72,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: error_code = self.ERROR_CODE[2] message = f"Import of removed function {qualified_name}" - reference = self.deprecated_config[qualified_name].get( - "reference" - ) + reference = self.deprecated_config[qualified_name].get("reference") if reference is not None: message = f"{message}: {reference}" - self.violations.append( - LintViolation( - error_code=error_code, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=error_code, message=message) def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) @@ -99,9 +84,6 @@ def visit_Call(self, node) -> None: return if qualified_name in self.deprecated_config: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERROR_CODE[1] message = f"Use of deprecated function {qualified_name}" @@ -114,15 +96,8 @@ def visit_Call(self, node) -> None: if reference is not None: message = f"{message}: {reference}" - self.violations.append( - LintViolation( - error_code=error_code, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, error_code=error_code, message=message, replacement=replacement ) diff --git a/torchfix/visitors/deprecated_symbols/chain_matmul.py b/torchfix/visitors/deprecated_symbols/chain_matmul.py index 3eab730..ca546c3 100644 --- a/torchfix/visitors/deprecated_symbols/chain_matmul.py +++ b/torchfix/visitors/deprecated_symbols/chain_matmul.py @@ 
-20,7 +20,7 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode: replacement_args = [matrices_arg] else: replacement_args = [matrices_arg, out_arg] - module_name = get_module_name(node, 'torch') + module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)") replacement = replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/cholesky.py b/torchfix/visitors/deprecated_symbols/cholesky.py index cec5e71..c44c831 100644 --- a/torchfix/visitors/deprecated_symbols/cholesky.py +++ b/torchfix/visitors/deprecated_symbols/cholesky.py @@ -1,5 +1,5 @@ import libcst as cst -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode: diff --git a/torchfix/visitors/deprecated_symbols/qr.py b/torchfix/visitors/deprecated_symbols/qr.py index f1d96df..9fc4874 100644 --- a/torchfix/visitors/deprecated_symbols/qr.py +++ b/torchfix/visitors/deprecated_symbols/qr.py @@ -1,6 +1,6 @@ import libcst as cst from typing import Optional -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]: diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 97fec69..26f0a4f 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -7,6 +7,7 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]: """Replace `range` with `arange`. Add `step` to the `end` argument as `arange` has the interval `[start, end)`. """ + # `torch.range` documented signature is not a valid Python signature, # so it's hard to generalize this. 
def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py index 424e1f2..14389b3 100644 --- a/torchfix/visitors/internal/__init__.py +++ b/torchfix/visitors/internal/__init__.py @@ -1,5 +1,4 @@ -import libcst as cst -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchScopedLibraryVisitor(TorchVisitor): @@ -17,17 +16,4 @@ class TorchScopedLibraryVisitor(TorchVisitor): def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) if qualified_name == "torch.library.Library": - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 6ce7c84..ef60b3e 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -2,7 +2,7 @@ import libcst.matchers as m -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchRequireGradVisitor(TorchVisitor): @@ -31,20 +31,11 @@ def visit_Assign(self, node): replacement = node.with_deep_changes( old_node=node.targets[0].target.attr, value="requires_grad" ) - - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + 
replacement=replacement, ) @@ -65,10 +56,6 @@ def visit_Call(self, node): if qualified_name == "torch.utils.checkpoint.checkpoint": use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1) if use_reentrant_arg is None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - # This codemod maybe unsafe correctness-wise # if reentrant behavior is actually needed, # so the changes need to be verified/tested. @@ -76,14 +63,9 @@ def visit_Call(self, node): cst.parse_expression("f(use_reentrant=False)"), cst.Call ).args[0] replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index f838fbe..427eb78 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -1,8 +1,7 @@ -import libcst as cst import libcst.matchers as m -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchSynchronizedDataLoaderVisitor(TorchVisitor): @@ -25,17 +24,6 @@ def visit_Call(self, node): if num_workers_arg is None or m.matches( num_workers_arg.value, m.Integer(value="0") ): - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=None, - ) + self.add_violation( + node, error_code=self.ERROR_CODE, message=self.MESSAGE ) diff --git 
a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index 010c5f4..5dfdf6e 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -1,5 +1,5 @@ import libcst as cst -from ...common import TorchVisitor, LintViolation +from ...common import TorchVisitor class TorchUnsafeLoadVisitor(TorchVisitor): @@ -21,10 +21,6 @@ def visit_Call(self, node): if qualified_name == "torch.load": weights_only_arg = self.get_specific_arg(node, "weights_only", -1) if weights_only_arg is None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - # Add `weights_only=True` if there is no `pickle_module`. # (do not add `weights_only=False` with `pickle_module`, as it # needs to be an explicit choice). @@ -42,14 +38,9 @@ def visit_Call(self, node): replacement = node.with_changes( args=node.args + (weights_only_arg,) ) - - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index 7ccbebb..928f2c1 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,6 +1,6 @@ import libcst as cst -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionModelsImportVisitor(TorchVisitor): @@ -24,23 +24,16 @@ def visit_Import(self, node: cst.Import) -> None: and isinstance(imported_item.asname.name, cst.Name) and imported_item.asname.name.value == "models" ): - position = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) # Replace only if the import statement has no other names if 
len(node.names) == 1: replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position.start.line, - column=position.start.column, - node=node, - replacement=replacement + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, ) break diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index 6e17048..99dd845 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): @@ -248,16 +248,9 @@ def _new_arg_and_import( node.with_changes(args=replacement_args) if has_replacement else None ) if message is not None: - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=message, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, - ) + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=message, + replacement=replacement, ) diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index ab15827..3395dd9 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -1,32 +1,21 @@ from collections.abc import Sequence import libcst as cst -from ...common import LintViolation, TorchVisitor +from ...common import TorchVisitor class TorchVisionDeprecatedToTensorVisitor(TorchVisitor): ERROR_CODE = 
"TOR202" + MESSAGE = ( + "The transform `v2.ToTensor()` is deprecated and will be removed " + "in a future release. Instead, please use " + "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 + ) def _maybe_add_violation(self, qualified_name, node): if qualified_name != "torchvision.transforms.v2.ToTensor": return - position = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node - ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=( - "The transform `v2.ToTensor()` is deprecated and will be removed " - "in a future release. Instead, please use " - "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 - ), - line=position.start.line, - column=position.start.column, - node=node, - replacement=None, - ) - ) + self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) def visit_ImportFrom(self, node): module_path = cst.helpers.get_absolute_module_from_package_for_import( From 7cc047d5232de065c3f17d768ba58e70e3416c3c Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 18 Mar 2024 11:54:04 -0700 Subject: [PATCH 23/66] Update to checkout action v4 (#35) To get rid of the warning > Node.js 16 actions are deprecated. Please update the following actions to use Node.js 20: actions/checkout@v3. 
For more information see: https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/ --- .github/workflows/test-torchfix.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml index 2047373..a80c845 100644 --- a/.github/workflows/test-torchfix.yml +++ b/.github/workflows/test-torchfix.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install requirements run: | pip3 install -r requirements-dev.txt From c909236c24c70ddee1d0f06bddaf3c1ec73fcdaa Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 18 Mar 2024 18:41:35 -0700 Subject: [PATCH 24/66] Add codemod for TorchNonPublicAliasVisitor (#36) * Add codemod for TorchNonPublicAliasVisitor * Format * noqa --- .../codemod/default_collate_convert.py | 14 ++++++ .../codemod/default_collate_convert.py.out | 14 ++++++ torchfix/visitors/nonpublic/__init__.py | 48 ++++++++++++++++++- 3 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/nonpublic/codemod/default_collate_convert.py create mode 100644 tests/fixtures/nonpublic/codemod/default_collate_convert.py.out diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.py b/tests/fixtures/nonpublic/codemod/default_collate_convert.py new file mode 100644 index 0000000..b61bad6 --- /dev/null +++ b/tests/fixtures/nonpublic/codemod/default_collate_convert.py @@ -0,0 +1,14 @@ +from torch.utils.data import _utils # will not be removed as it could be used for something besides default_collate +batch = _utils.collate.default_collate(batch) + +from torch.utils.data._utils import collate # also will not be removed +batch = collate.default_collate(batch) + +from torch.utils.data._utils.collate import default_collate +inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx) + +from 
torch.utils.data._utils.collate import default_convert +values = default_convert(values) + +import torch +values = torch.utils.data._utils.collate.default_convert(values) diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.py.out b/tests/fixtures/nonpublic/codemod/default_collate_convert.py.out new file mode 100644 index 0000000..11bfb53 --- /dev/null +++ b/tests/fixtures/nonpublic/codemod/default_collate_convert.py.out @@ -0,0 +1,14 @@ +from torch.utils.data import dataloader, _utils # will not be removed as it could be used for something besides default_collate +batch = dataloader.default_collate(batch) + +from torch.utils.data._utils import collate # also will not be removed +batch = dataloader.default_collate(batch) + +from torch.utils.data.dataloader import default_collate +inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx) + +from torch.utils.data.dataloader import default_convert +values = default_convert(values) + +import torch +values = torch.utils.data.dataloader.default_convert(values) diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 0b8318f..7240e0f 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -1,6 +1,9 @@ +from os.path import commonprefix from typing import Sequence import libcst as cst +from libcst.codemod.visitors import ImportItem + from ...common import TorchVisitor @@ -32,7 +35,32 @@ def visit_Call(self, node): public_name = self.ALIASES[qualified_name] error_code = self.ERROR_CODE[0] message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 - self.add_violation(node, error_code=error_code, message=message) + + call_name = cst.helpers.get_full_name_for_node(node) + replacement = None + if not public_name.endswith(call_name): + # We need to change the call name as it's not in the public name. 
+ # Get the new call name on the same hierarchical level. + new_call_name = public_name.removeprefix( + commonprefix([qualified_name.removesuffix(call_name), public_name]) + ) + new_module_name = public_name.removesuffix(new_call_name).removesuffix( + "." + ) + if new_module_name: + self.needed_imports.add( + ImportItem( + module_name=new_module_name, + obj_name=new_call_name.split(".")[0], + ) + ) + replacement = node.with_changes( + func=cst.parse_expression(new_call_name) + ) + + self.add_violation( + node, error_code=error_code, message=message, replacement=replacement + ) def visit_ImportFrom(self, node: cst.ImportFrom) -> None: if node.module is None: @@ -48,4 +76,20 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: public_name = self.ALIASES[qualified_name] error_code = self.ERROR_CODE[1] message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 - self.add_violation(node, error_code=error_code, message=message) + + new_module = ".".join(public_name.split(".")[:-1]) + new_name = public_name.split(".")[-1] + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501 + names=[cst.ImportAlias(name=cst.Name(new_name))], + ) + else: + replacement = None + self.add_violation( + node, + error_code=error_code, + message=message, + replacement=replacement, + ) From 33870709f69c5fd1e29e2637fc516fe327d598e5 Mon Sep 17 00:00:00 2001 From: Francesca Wang Date: Wed, 3 Apr 2024 14:56:17 -0700 Subject: [PATCH 25/66] Refactored each import item from node using libcst.matchers.matches (#37) * refactored each import item from node using libcst.matchers.matches * add proper import statement * remove comment * fix code style with prettier indent * prettify the code by removing more indent * prettify the code --------- Co-authored-by: Francesca Wang --- 
torchfix/visitors/vision/models_import.py | 41 ++++++++++------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index 928f2c1..de75d5a 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,4 +1,5 @@ import libcst as cst +import libcst.matchers as m from ...common import TorchVisitor @@ -13,27 +14,21 @@ class TorchVisionModelsImportVisitor(TorchVisitor): def visit_Import(self, node: cst.Import) -> None: replacement = None for imported_item in node.names: - if isinstance(imported_item.name, cst.Attribute): - # TODO refactor using libcst.matchers.matches - if ( - isinstance(imported_item.name.value, cst.Name) - and imported_item.name.value.value == "torchvision" - and isinstance(imported_item.name.attr, cst.Name) - and imported_item.name.attr.value == "models" - and imported_item.asname is not None - and isinstance(imported_item.asname.name, cst.Name) - and imported_item.asname.name.value == "models" - ): - # Replace only if the import statement has no other names - if len(node.names) == 1: - replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) - self.add_violation( - node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, - replacement=replacement, + if m.matches(imported_item, m.ImportAlias( + name=m.Attribute(value=m.Name("torchvision"), + attr=m.Name("models")), + asname=m.AsName(name=m.Name("models")) + )): + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], ) - break + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, + ) + break From 04bab2c9d629b4651460c4daf8a8cc7f9b1c3e27 Mon Sep 17 00:00:00 2001 From: Amethyst Reese 
Date: Wed, 17 Apr 2024 16:47:27 -0700 Subject: [PATCH 26/66] Fix deprecated symbol loading in zipped deployments (#39) torchfix currently uses standard filesystem methods to find and load the deprecated symbols data from disk. When running flake8 as zipped deployments, this fails because the data is not accessible as a standard filesystem path. This replaces the filesystem usage with stdlib `pkgutil.get_data()` [1] that is capable of resolving the data file within the zip deployment, and loading that data using the correct internal mechanisms. 1: https://docs.python.org/3/library/pkgutil.html#pkgutil.get_data --- torchfix/torchfix.py | 2 +- torchfix/visitors/deprecated_symbols/__init__.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 21a8994..4a1de3c 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -26,7 +26,7 @@ __version__ = "0.4.0" -DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" +DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml" DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"] diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 3450777..9949242 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,4 +1,5 @@ import libcst as cst +import pkgutil import yaml from typing import Optional from collections.abc import Sequence @@ -21,9 +22,9 @@ def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): deprecated_config = {} if path is not None: - with open(path) as f: - for item in yaml.load(f, yaml.SafeLoader): - deprecated_config[item["name"]] = item + data = pkgutil.get_data("torchfix", path) + for item in yaml.load(data, yaml.SafeLoader): + deprecated_config[item["name"]] = item return deprecated_config super().__init__() From 55da3d0d0065392c1b74af7e75820a9a3967b4ba Mon Sep
17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 18 Apr 2024 10:36:25 -0700 Subject: [PATCH 27/66] Bump version to 0.5.0 (#40) Preparing 0.5.0 release. - Added rule TOR203 to replace 'import torchvision.models as models' with 'from torchvision import models' - Added rules TOR104 and TOR105 for calling and importing non-public PyTorch functions that have known public aliases - Added rules TOR004 and TOR103 for importing removed and deprecated functions (in addition to the existing rules for calling those functions) - Fixed loading for deprecated symbols config in zipped deployments - Done several smaller refactorings and bug fixes --- torchfix/torchfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 4a1de3c..dda4057 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -24,7 +24,7 @@ ) from .visitors.security import TorchUnsafeLoadVisitor -__version__ = "0.4.0" +__version__ = "0.5.0" DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml" From 86579eb88df9d5fab0a9d71034759170bfc65bb2 Mon Sep 17 00:00:00 2001 From: clee2000 <44682903+clee2000@users.noreply.github.com> Date: Mon, 22 Apr 2024 11:40:21 -0700 Subject: [PATCH 28/66] Add deprecation warning for `torch.backends.cuda.sdp_kernel` (#43) * only deprecation warning * typo * move to correct section in readme --- README.md | 7 +++++++ .../deprecated_symbols/checker/sdp_kernel.py | 12 ++++++++++++ .../deprecated_symbols/checker/sdp_kernel.txt | 4 ++++ torchfix/deprecated_symbols.yaml | 5 +++++ 4 files changed, 28 insertions(+) create mode 100644 tests/fixtures/deprecated_symbols/checker/sdp_kernel.py create mode 100644 tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt diff --git a/README.md b/README.md index 1b825e1..9fbb7bd 100644 --- a/README.md +++ b/README.md @@ -103,5 +103,12 @@ Migration guide: `torch.nn.utils.parametrize.cached` before invoking the module in question. 
+#### torch.backends.cuda.sdp_kernel + +This function is deprecated. Use the `torch.nn.attention.sdpa_kernel` context manager instead. + +Migration guide: +Each boolean input parameter (defaulting to true unless specified) of `sdp_kernel` corresponds to a `SDPBackened`. If the input parameter is true, the corresponding backend should be added to the input list of `sdpa_kernel`. + ## License TorchFix is BSD License licensed, as found in the LICENSE file. diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py new file mode 100644 index 0000000..06d14a8 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py @@ -0,0 +1,12 @@ +import torch +from torch.backends import cuda +from torch.backends.cuda import sdp_kernel + +with torch.backends.cuda.sdp_kernel() as context: + pass + +with cuda.sdp_kernel() as context: + pass + +with sdp_kernel() as context: + pass diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt new file mode 100644 index 0000000..d18f1ee --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt @@ -0,0 +1,4 @@ +3:1 TOR103 Import of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +5:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +8:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +11:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index b2bf4c5..9cce56d 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -65,6 +65,11 @@ 
remove_pr: reference: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm +- name: torch.backends.cuda.sdp_kernel + deprecate_pr: https://github.com/pytorch/pytorch/pull/114689 + remove_pr: + reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel + # functorch - name: functorch.vmap deprecate_pr: TBA From b2d55f8b91ca951ee512c452dc9676046b952b44 Mon Sep 17 00:00:00 2001 From: Eli Uriegas <1700823+seemethere@users.noreply.github.com> Date: Mon, 22 Apr 2024 12:16:43 -0700 Subject: [PATCH 29/66] torchfix: Refactor ERROR_CODE to be consistent (#46) --- tests/test_torchfix.py | 4 +-- torchfix/__main__.py | 10 ++---- torchfix/common.py | 22 +++++++++--- torchfix/torchfix.py | 27 ++++++-------- .../visitors/deprecated_symbols/__init__.py | 25 ++++++++----- torchfix/visitors/internal/__init__.py | 26 +++++++++----- torchfix/visitors/misc/__init__.py | 35 +++++++++++-------- torchfix/visitors/nonpublic/__init__.py | 31 ++++++++++++---- torchfix/visitors/performance/__init__.py | 23 +++++++----- torchfix/visitors/security/__init__.py | 25 +++++++------ torchfix/visitors/vision/models_import.py | 34 +++++++++++------- torchfix/visitors/vision/pretrained.py | 19 +++++++--- torchfix/visitors/vision/to_tensor.py | 20 ++++++----- 13 files changed, 187 insertions(+), 114 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index d699ea7..5f5dff9 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -62,8 +62,8 @@ def test_errorcodes_distinct(): seen = set() for visitor in visitors: LOGGER.info("Checking error code for %s", visitor.__class__.__name__) - error_code = visitor.ERROR_CODE - for e in error_code if isinstance(error_code, list) else [error_code]: + errors = visitor.ERRORS + for e in errors: assert e not in seen seen.add(e) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 5df0cf9..b8413bf 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -63,16 +63,12 @@ def main() -> None: 
parser.add_argument( "--select", help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. " - f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " - f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", + f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. " + f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.", type=str, default=None, ) - parser.add_argument( - "--version", - action="version", - version=f"{TorchFixVersion}" - ) + parser.add_argument("--version", action="version", version=f"{TorchFixVersion}") # XXX TODO: Get rid of this! # Silence "Failed to determine module name" diff --git a/torchfix/common.py b/torchfix/common.py index 7fdd00a..db2fb8c 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -1,10 +1,11 @@ -from dataclasses import dataclass import sys +from abc import ABC +from dataclasses import dataclass +from typing import List, Optional, Set, Tuple + import libcst as cst -from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider from libcst.codemod.visitors import ImportItem -from typing import Optional, List, Set, Tuple, Union -from abc import ABC +from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() CYAN = "\033[96m" if IS_TTY else "" @@ -34,13 +35,24 @@ def codemod_result(self) -> str: return f"{position} {error_code}{fixable} {self.message}" +@dataclass(frozen=True) +class TorchError: + """Defines an error along with an explanation""" + + error_code: str + message_template: str + + def message(self, **kwargs): + return self.message_template.format(**kwargs) + + class TorchVisitor(cst.BatchableCSTVisitor, ABC): METADATA_DEPENDENCIES = ( QualifiedNameProvider, WhitespaceInclusivePositionProvider, ) - ERROR_CODE: Union[str, List[str]] + ERRORS: List[TorchError] def __init__(self) -> None: self.violations: List[LintViolation] = [] diff --git 
a/torchfix/torchfix.py b/torchfix/torchfix.py index dda4057..85c3943 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -14,7 +14,7 @@ from .visitors.internal import TorchScopedLibraryVisitor from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) +from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor from .visitors.nonpublic import TorchNonPublicAliasVisitor from .visitors.vision import ( @@ -48,10 +48,7 @@ def GET_ALL_ERROR_CODES(): codes = set() for cls in ALL_VISITOR_CLS: - if isinstance(cls.ERROR_CODE, list): - codes |= set(cls.ERROR_CODE) - else: - codes.add(cls.ERROR_CODE) + codes |= set(error.error_code for error in cls.ERRORS) return codes @@ -86,16 +83,10 @@ def get_visitors_with_error_codes(error_codes): # only correspond to one visitor. found = False for visitor_cls in ALL_VISITOR_CLS: - if isinstance(visitor_cls.ERROR_CODE, list): - if error_code in visitor_cls.ERROR_CODE: - visitor_classes.add(visitor_cls) - found = True - break - else: - if error_code == visitor_cls.ERROR_CODE: - visitor_classes.add(visitor_cls) - found = True - break + if error_code in list(err.error_code for err in visitor_cls.ERRORS): + visitor_classes.add(visitor_cls) + found = True + break if not found: raise AssertionError(f"Unknown error code: {error_code}") out = [] @@ -120,8 +111,10 @@ def process_error_code_str(code_str): if c == "ALL": continue if len(expand_error_codes((c,))) == 0: - raise ValueError(f"Invalid error code: {c}, available error " - f"codes: {list(GET_ALL_ERROR_CODES())}") + raise ValueError( + f"Invalid error code: {c}, available error " + f"codes: {list(GET_ALL_ERROR_CODES())}" + ) if "ALL" in raw_codes: return GET_ALL_ERROR_CODES() diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 9949242..f1cf61f 100644 --- 
a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,11 +1,12 @@ import libcst as cst import pkgutil import yaml -from typing import Optional +from typing import Optional, List from collections.abc import Sequence from ...common import ( TorchVisitor, + TorchError, call_with_name_changes, ) @@ -16,7 +17,12 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): - ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"] + ERRORS: List[TorchError] = [ + TorchError("TOR001", "Use of removed function {qualified_name}"), + TorchError("TOR101", "Import of deprecated function {qualified_name}"), + TorchError("TOR004", "Import of removed function {qualified_name}"), + TorchError("TOR103", "Import of deprecated function {qualified_name}"), + ] def __init__(self, deprecated_config_path=None): def read_deprecated_config(path=None): @@ -67,11 +73,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: qualified_name = f"{module}.{name.name.value}" if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: - error_code = self.ERROR_CODE[3] - message = f"Import of deprecated function {qualified_name}" + error_code = self.ERRORS[3].error_code + message = self.ERRORS[3].message(qualified_name=qualified_name) else: - error_code = self.ERROR_CODE[2] - message = f"Import of removed function {qualified_name}" + error_code = self.ERRORS[2].error_code + message = self.ERRORS[2].message(qualified_name=qualified_name) reference = self.deprecated_config[qualified_name].get("reference") if reference is not None: @@ -86,11 +92,12 @@ def visit_Call(self, node) -> None: if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: - error_code = self.ERROR_CODE[1] + error_code = self.ERRORS[1].error_code + message = self.ERRORS[1].message(qualified_name=qualified_name) message = f"Use of deprecated function {qualified_name}" else: - 
error_code = self.ERROR_CODE[0] - message = f"Use of removed function {qualified_name}" + error_code = self.ERRORS[0].error_code + message = self.ERRORS[0].message(qualified_name=qualified_name) replacement = self._call_replacement(node, qualified_name) reference = self.deprecated_config[qualified_name].get("reference") diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py index 14389b3..908527a 100644 --- a/torchfix/visitors/internal/__init__.py +++ b/torchfix/visitors/internal/__init__.py @@ -1,4 +1,4 @@ -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchScopedLibraryVisitor(TorchVisitor): @@ -6,14 +6,24 @@ class TorchScopedLibraryVisitor(TorchVisitor): Suggest `torch.library._scoped_library` for PyTorch tests. """ - ERROR_CODE = "TOR901" - MESSAGE = ( - "Use `torch.library._scoped_library` instead of `torch.library.Library` " - "in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 " - "for details." - ) + ERRORS = [ + TorchError( + "TOR901", + ( + "Use `torch.library._scoped_library` " + "instead of `torch.library.Library` " + "in PyTorch tests files. " + "See https://github.com/pytorch/pytorch/pull/118318 " + "for details." 
+ ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) if qualified_name == "torch.library.Library": - self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + ) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ef60b3e..a8ee248 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -1,8 +1,7 @@ import libcst as cst import libcst.matchers as m - -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchRequireGradVisitor(TorchVisitor): @@ -10,8 +9,12 @@ class TorchRequireGradVisitor(TorchVisitor): Find and fix common misspelling `require_grad` (instead of `requires_grad`). """ - ERROR_CODE = "TOR002" - MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?" + ERRORS = [ + TorchError( + "TOR002", + "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?", + ) + ] def visit_Assign(self, node): # Look for any assignment with `require_grad` attribute on the left. @@ -33,8 +36,8 @@ def visit_Assign(self, node): ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) @@ -44,12 +47,16 @@ class TorchReentrantCheckpointVisitor(TorchVisitor): Find and fix common misuse of reentrant checkpoints. """ - ERROR_CODE = "TOR003" - MESSAGE = ( - "Please pass `use_reentrant` explicitly to `checkpoint`. " - "To maintain old behavior, pass `use_reentrant=True`. " - "It is recommended to use `use_reentrant=False`." - ) + ERRORS = [ + TorchError( + "TOR003", + ( + "Please pass `use_reentrant` explicitly to `checkpoint`. " + "To maintain old behavior, pass `use_reentrant=True`. " + "It is recommended to use `use_reentrant=False`." 
+ ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -65,7 +72,7 @@ def visit_Call(self, node): replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 7240e0f..575ad9d 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -1,10 +1,10 @@ from os.path import commonprefix -from typing import Sequence +from typing import Sequence, List import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchNonPublicAliasVisitor(TorchVisitor): @@ -17,7 +17,20 @@ class TorchNonPublicAliasVisitor(TorchVisitor): see https://github.com/pytorch/pytorch/pull/69862/files """ - ERROR_CODE = ["TOR104", "TOR105"] + ERRORS: List[TorchError] = [ + TorchError( + "TOR104", ( + "Use of non-public function `{qualified_name}`, " + "please use `{public_name}` instead" + ), + ), + TorchError( + "TOR105", ( + "Import of non-public function `{qualified_name}`, " + "please use `{public_name}` instead" + ), + ), + ] # fmt: off ALIASES = { @@ -33,8 +46,10 @@ def visit_Call(self, node): if qualified_name in self.ALIASES: public_name = self.ALIASES[qualified_name] - error_code = self.ERROR_CODE[0] - message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + error_code = self.ERRORS[0].error_code + message = self.ERRORS[0].message( + qualified_name=qualified_name, public_name=public_name + ) call_name = cst.helpers.get_full_name_for_node(node) replacement = None @@ -74,8 +89,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: qualified_name = 
f"{module}.{name.name.value}" if qualified_name in self.ALIASES: public_name = self.ALIASES[qualified_name] - error_code = self.ERROR_CODE[1] - message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + error_code = self.ERRORS[1].error_code + message = self.ERRORS[1].message( + qualified_name=qualified_name, public_name=public_name + ) new_module = ".".join(public_name.split(".")[:-1]) new_name = public_name.split(".")[-1] diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 427eb78..249df4c 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -1,7 +1,6 @@ import libcst.matchers as m - -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchSynchronizedDataLoaderVisitor(TorchVisitor): @@ -10,12 +9,16 @@ class TorchSynchronizedDataLoaderVisitor(TorchVisitor): https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py """ - ERROR_CODE = "TOR401" - MESSAGE = ( - "Detected DataLoader running with synchronized implementation. " - "Please enable asynchronous dataloading by setting num_workers > 0 when " - "initializing DataLoader." - ) + ERRORS = [ + TorchError( + "TOR401", + ( + "Detected DataLoader running with synchronized implementation." + " Please enable asynchronous dataloading by setting " + "num_workers > 0 when initializing DataLoader." 
+ ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -25,5 +28,7 @@ def visit_Call(self, node): num_workers_arg.value, m.Integer(value="0") ): self.add_violation( - node, error_code=self.ERROR_CODE, message=self.MESSAGE + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), ) diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index 5dfdf6e..775bed9 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -1,5 +1,6 @@ import libcst as cst -from ...common import TorchVisitor + +from ...common import TorchError, TorchVisitor class TorchUnsafeLoadVisitor(TorchVisitor): @@ -8,13 +9,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor): See https://github.com/pytorch/pytorch/issues/31875. """ - ERROR_CODE = "TOR102" - MESSAGE = ( - "`torch.load` without `weights_only` parameter is unsafe. " - "Explicitly set `weights_only` to False only if you trust the data you load " - "and full pickle functionality is needed, otherwise set " - "`weights_only=True`." - ) + ERRORS = [ + TorchError( + "TOR102", + ( + "`torch.load` without `weights_only` parameter is unsafe. " + "Explicitly set `weights_only` to False only if you trust " + "the data you load " "and full pickle functionality is needed," + " otherwise set `weights_only=True`." 
+ ), + ) + ] def visit_Call(self, node): qualified_name = self.get_qualified_name_for_call(node) @@ -40,7 +45,7 @@ def visit_Call(self, node): ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index de75d5a..f3b0797 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,24 +1,32 @@ import libcst as cst import libcst.matchers as m -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchVisionModelsImportVisitor(TorchVisitor): - ERROR_CODE = "TOR203" - MESSAGE = ( - "Consider replacing 'import torchvision.models as models' " - "with 'from torchvision import models'." - ) + ERRORS = [ + TorchError( + "TOR203", + ( + "Consider replacing 'import torchvision.models as models' " + "with 'from torchvision import models'." 
+ ), + ) + ] def visit_Import(self, node: cst.Import) -> None: replacement = None for imported_item in node.names: - if m.matches(imported_item, m.ImportAlias( - name=m.Attribute(value=m.Name("torchvision"), - attr=m.Name("models")), - asname=m.AsName(name=m.Name("models")) - )): + if m.matches( + imported_item, + m.ImportAlias( + name=m.Attribute( + value=m.Name("torchvision"), attr=m.Name("models") + ), + asname=m.AsName(name=m.Name("models")), + ), + ): # Replace only if the import statement has no other names if len(node.names) == 1: replacement = cst.ImportFrom( @@ -27,8 +35,8 @@ def visit_Import(self, node: cst.Import) -> None: ) self.add_violation( node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), replacement=replacement, ) break diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index 99dd845..af52a0f 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -3,7 +3,7 @@ import libcst as cst from libcst.codemod.visitors import ImportItem -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): @@ -16,7 +16,12 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): otherwise only lint violation is emitted. """ - ERROR_CODE = "TOR201" + ERRORS = [ + TorchError( + "TOR201", + "Parameter `{old_arg_name}` is deprecated, please use `{new_arg_name}` instead.", + ) + ] # flake8: noqa: E105 # fmt: off @@ -215,13 +220,17 @@ def _new_arg_and_import( message = None pretrained_arg = self.get_specific_arg(node, "pretrained", 0) if pretrained_arg is not None: - message = "Parameter `pretrained` is deprecated, please use `weights` instead." 
+ message = self.ERRORS[0].message( + old_arg_name="pretrained", new_arg_name="weights" + ) pretrained_backbone_arg = self.get_specific_arg( node, "pretrained_backbone", 1 ) if pretrained_backbone_arg is not None: - message = "Parameter `pretrained_backbone` is deprecated, please use `weights_backbone` instead." + message = self.ERRORS[0].message( + old_arg_name="pretrained_backbone", new_arg_name="weights_backbone" + ) replacement_args = list(node.args) @@ -250,7 +259,7 @@ def _new_arg_and_import( if message is not None: self.add_violation( node, - error_code=self.ERROR_CODE, + error_code=self.ERRORS[0].error_code, message=message, replacement=replacement, ) diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 3395dd9..791a9e5 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -1,21 +1,25 @@ from collections.abc import Sequence + import libcst as cst -from ...common import TorchVisitor +from ...common import TorchError, TorchVisitor + +MESSAGE = ( + "The transform `v2.ToTensor()` is deprecated and will be removed " + "in a future release. Instead, please use " + "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." # noqa: E501 +) class TorchVisionDeprecatedToTensorVisitor(TorchVisitor): - ERROR_CODE = "TOR202" - MESSAGE = ( - "The transform `v2.ToTensor()` is deprecated and will be removed " - "in a future release. Instead, please use " - "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." 
# noqa: E501 - ) + ERRORS = [TorchError("TOR202", MESSAGE)] def _maybe_add_violation(self, qualified_name, node): if qualified_name != "torchvision.transforms.v2.ToTensor": return - self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE) + self.add_violation( + node, error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message() + ) def visit_ImportFrom(self, node): module_path = cst.helpers.get_absolute_module_from_package_for_import( From af37f69cffa84ccae903bb601c6ca3bb522dd6c7 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev <108101595+izaitsevfb@users.noreply.github.com> Date: Mon, 22 Apr 2024 13:15:53 -0700 Subject: [PATCH 30/66] [refactoring] Extract helper method `has_specific_arg` (#49) fix #8, extract helper method `has_specific_arg` that checks for the call argument presence, and simplify all relevant call sites --- torchfix/common.py | 17 ++++++++- torchfix/visitors/misc/__init__.py | 33 +++++++++-------- torchfix/visitors/security/__init__.py | 49 ++++++++++++-------------- 3 files changed, 55 insertions(+), 44 deletions(-) diff --git a/torchfix/common.py b/torchfix/common.py index db2fb8c..36d144a 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -62,7 +62,10 @@ def __init__(self) -> None: def get_specific_arg( node: cst.Call, arg_name: str, arg_pos: int ) -> Optional[cst.Arg]: - # `arg_pos` is zero-based. + """ + :param arg_pos: `arg_pos` is zero-based. -1 means it's a keyword argument. + :note: consider using `has_specific_arg` if you only need to check for presence. + """ curr_pos = 0 for arg in node.args: if arg.keyword is None: @@ -73,6 +76,18 @@ def get_specific_arg( return arg return None + @staticmethod + def has_specific_arg( + node: cst.Call, arg_name: str, position: Optional[int] = None + ) -> bool: + """ + Check if the specific argument is present in a call. 
+ """ + return TorchVisitor.get_specific_arg( + node, arg_name, + position if position is not None else -1 + ) is not None + def add_violation( self, node: cst.CSTNode, diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index a8ee248..08601c0 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -59,20 +59,19 @@ class TorchReentrantCheckpointVisitor(TorchVisitor): ] def visit_Call(self, node): - qualified_name = self.get_qualified_name_for_call(node) - if qualified_name == "torch.utils.checkpoint.checkpoint": - use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1) - if use_reentrant_arg is None: - # This codemod maybe unsafe correctness-wise - # if reentrant behavior is actually needed, - # so the changes need to be verified/tested. - use_reentrant_arg = cst.ensure_type( - cst.parse_expression("f(use_reentrant=False)"), cst.Call - ).args[0] - replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=replacement, - ) + if (self.get_qualified_name_for_call(node) == + "torch.utils.checkpoint.checkpoint" and + not self.has_specific_arg(node, "use_reentrant")): + # This codemod maybe unsafe correctness-wise + # if reentrant behavior is actually needed, + # so the changes need to be verified/tested. 
+ use_reentrant_arg = cst.ensure_type( + cst.parse_expression("f(use_reentrant=False)"), cst.Call + ).args[0] + replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=replacement, + ) diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index 775bed9..c53c43f 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -22,30 +22,27 @@ class TorchUnsafeLoadVisitor(TorchVisitor): ] def visit_Call(self, node): - qualified_name = self.get_qualified_name_for_call(node) - if qualified_name == "torch.load": - weights_only_arg = self.get_specific_arg(node, "weights_only", -1) - if weights_only_arg is None: - # Add `weights_only=True` if there is no `pickle_module`. - # (do not add `weights_only=False` with `pickle_module`, as it - # needs to be an explicit choice). - # - # This codemod is somewhat unsafe correctness-wise - # because full pickling functionality may still be needed - # even without `pickle_module`, - # so the changes need to be verified/tested. - replacement = None - pickle_module_arg = self.get_specific_arg(node, "pickle_module", 2) - if pickle_module_arg is None: - weights_only_arg = cst.ensure_type( - cst.parse_expression("f(weights_only=True)"), cst.Call - ).args[0] - replacement = node.with_changes( - args=node.args + (weights_only_arg,) - ) - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=replacement, + if self.get_qualified_name_for_call(node) == "torch.load" and \ + not self.has_specific_arg(node, "weights_only"): + # Add `weights_only=True` if there is no `pickle_module`. + # (do not add `weights_only=False` with `pickle_module`, as it + # needs to be an explicit choice). 
+ # + # This codemod is somewhat unsafe correctness-wise + # because full pickling functionality may still be needed + # even without `pickle_module`, + # so the changes need to be verified/tested. + replacement = None + if not self.has_specific_arg(node, "pickle_module", 2): + weights_only_arg = cst.ensure_type( + cst.parse_expression("f(weights_only=True)"), cst.Call + ).args[0] + replacement = node.with_changes( + args=node.args + (weights_only_arg,) ) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=replacement, + ) From a71baf17472193ab7907c56dfc31685ad5e48bdb Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Mon, 22 Apr 2024 16:32:12 -0500 Subject: [PATCH 31/66] Add rule for deprecated _register_pytree_node (#44) * Add rule for deprecated _register_pytree_node * fix test * Add codemod test --- .../checker/deprecated_register_pytree_node.py | 9 +++++++++ .../checker/deprecated_register_pytree_node.txt | 4 ++++ .../deprecated_symbols/codemod/register_pytree_node.py | 9 +++++++++ .../codemod/register_pytree_node.py.out | 9 +++++++++ torchfix/deprecated_symbols.yaml | 5 +++++ 5 files changed, 36 insertions(+) create mode 100644 tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py create mode 100644 tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt create mode 100644 tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py create mode 100644 tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py new file mode 100644 index 0000000..b594d68 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py @@ -0,0 +1,9 @@ +from torch.utils._pytree import _register_pytree_node + +_register_pytree_node() + +from torch.utils import 
_pytree as xx +xx._register_pytree_node() + +import torch +torch.utils._pytree._register_pytree_node() diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt new file mode 100644 index 0000000..eb93906 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt @@ -0,0 +1,4 @@ +1:1 TOR103 Import of deprecated function torch.utils._pytree._register_pytree_node +3:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node +6:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node +9:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node \ No newline at end of file diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py new file mode 100644 index 0000000..b594d68 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py @@ -0,0 +1,9 @@ +from torch.utils._pytree import _register_pytree_node + +_register_pytree_node() + +from torch.utils import _pytree as xx +xx._register_pytree_node() + +import torch +torch.utils._pytree._register_pytree_node() diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out new file mode 100644 index 0000000..67d838e --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out @@ -0,0 +1,9 @@ +from torch.utils._pytree import register_pytree_node, _register_pytree_node + +register_pytree_node() + +from torch.utils import _pytree as xx +xx.register_pytree_node() + +import torch +torch.utils._pytree.register_pytree_node() diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index 9cce56d..a24f8f3 100644 --- a/torchfix/deprecated_symbols.yaml +++ 
b/torchfix/deprecated_symbols.yaml @@ -65,6 +65,11 @@ remove_pr: reference: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm +- name: torch.utils._pytree._register_pytree_node + deprecate_pr: https://github.com/pytorch/pytorch/pull/112111 + remove_pr: + replacement: torch.utils._pytree.register_pytree_node + - name: torch.backends.cuda.sdp_kernel deprecate_pr: https://github.com/pytorch/pytorch/pull/114689 remove_pr: From effb27ebd23ba8a32573c4d80d18b762078d2eea Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev <108101595+izaitsevfb@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:13:42 -0700 Subject: [PATCH 32/66] Add pre-commit hooks add lint docs (#51) fix #48 * add pre-commit hooks * add black to CI * add contributor guidelines re: linting * reformat code using black * address mypy warnings re: generators ### Testing: 1. ``` pip install -r requirements-dev.txt pre-commit install pre-commit run --all-files ``` image 2. ``` git commit -m 'test' ``` 3. Black CI works: https://github.com/pytorch-labs/torchfix/actions/runs/8791096314/job/24124561579?pr=51 --- .flake8 | 1 + .github/workflows/test-torchfix.yml | 3 +++ .pre-commit-config.yaml | 23 ++++++++++++++++++++ CONTRIBUTING.md | 28 +++++++++++++++++++++++-- requirements-dev.txt | 2 ++ torchfix/common.py | 12 ++++++----- torchfix/torchfix.py | 4 ++-- torchfix/visitors/misc/__init__.py | 8 ++++--- torchfix/visitors/nonpublic/__init__.py | 6 ++++-- torchfix/visitors/security/__init__.py | 12 +++++------ 10 files changed, 79 insertions(+), 20 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.flake8 b/.flake8 index 13dea6a..4e741ee 100644 --- a/.flake8 +++ b/.flake8 @@ -3,3 +3,4 @@ exclude = ./tests/fixtures/ # Match black tool's default. 
max-line-length = 88 +extend-ignore = E203 diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml index a80c845..e6a1aac 100644 --- a/.github/workflows/test-torchfix.yml +++ b/.github/workflows/test-torchfix.yml @@ -25,3 +25,6 @@ jobs: - name: Run mypy run: | mypy . + - name: Run black + run: | + black --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2f41268 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: local + hooks: + - id: black + name: black + entry: black + language: system + types: [python] + args: ["--config=./pyproject.toml"] + exclude: ^tests/fixtures/ + - id: flake8 + name: flake8 + entry: flake8 + language: system + types: [python] + args: ["--config=./.flake8"] + exclude: ^tests/fixtures/ + - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + exclude: ^tests/fixtures/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6e27b16..ea84a36 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,10 +14,34 @@ We actively welcome your pull requests. 1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. -4. Ensure the test suite passes. -5. Make sure your code lints. +4. Ensure the test suite passes (`pytest tests`). +5. Make sure your code lints (see Linting section below). 6. If you haven't already, complete the Contributor License Agreement ("CLA"). +## Linting + +We use `black`, `flake8`, and `mypy` to lint the code. +``` +pip install -r requirements-dev.txt +``` + +Linting via pre-commit hook: +``` +# install pre-commit hooks for the first time +pre-commit install + +# manually run pre-commit hooks on all files (runs all linters) +pre-commit run --all-files +``` + +Manually running individual linters: +``` +black . +flake8 +mypy . 
+``` + + ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only diff --git a/requirements-dev.txt b/requirements-dev.txt index 134840c..7910c57 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,5 @@ pytest==7.2.0 libcst==1.1.0 types-PyYAML==6.0.7 mypy==1.7.0 +black==24.4.0 +pre-commit==3.7.0 diff --git a/torchfix/common.py b/torchfix/common.py index 36d144a..73c0a35 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -78,15 +78,17 @@ def get_specific_arg( @staticmethod def has_specific_arg( - node: cst.Call, arg_name: str, position: Optional[int] = None + node: cst.Call, arg_name: str, position: Optional[int] = None ) -> bool: """ Check if the specific argument is present in a call. """ - return TorchVisitor.get_specific_arg( - node, arg_name, - position if position is not None else -1 - ) is not None + return ( + TorchVisitor.get_specific_arg( + node, arg_name, position if position is not None else -1 + ) + is not None + ) def add_violation( self, diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 85c3943..369bd98 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -48,7 +48,7 @@ def GET_ALL_ERROR_CODES(): codes = set() for cls in ALL_VISITOR_CLS: - codes |= set(error.error_code for error in cls.ERRORS) + codes |= {error.error_code for error in cls.ERRORS} return codes @@ -83,7 +83,7 @@ def get_visitors_with_error_codes(error_codes): # only correspond to one visitor. 
found = False for visitor_cls in ALL_VISITOR_CLS: - if error_code in list(err.error_code for err in visitor_cls.ERRORS): + if any(error_code == err.error_code for err in visitor_cls.ERRORS): visitor_classes.add(visitor_cls) found = True break diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 08601c0..be5f1c9 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -59,9 +59,11 @@ class TorchReentrantCheckpointVisitor(TorchVisitor): ] def visit_Call(self, node): - if (self.get_qualified_name_for_call(node) == - "torch.utils.checkpoint.checkpoint" and - not self.has_specific_arg(node, "use_reentrant")): + if self.get_qualified_name_for_call( + node + ) == "torch.utils.checkpoint.checkpoint" and not self.has_specific_arg( + node, "use_reentrant" + ): # This codemod maybe unsafe correctness-wise # if reentrant behavior is actually needed, # so the changes need to be verified/tested. diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 575ad9d..839feab 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -19,13 +19,15 @@ class TorchNonPublicAliasVisitor(TorchVisitor): ERRORS: List[TorchError] = [ TorchError( - "TOR104", ( + "TOR104", + ( "Use of non-public function `{qualified_name}`, " "please use `{public_name}` instead" ), ), TorchError( - "TOR105", ( + "TOR105", + ( "Import of non-public function `{qualified_name}`, " "please use `{public_name}` instead" ), diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index c53c43f..e0ecc92 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -15,15 +15,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor): ( "`torch.load` without `weights_only` parameter is unsafe. 
" "Explicitly set `weights_only` to False only if you trust " - "the data you load " "and full pickle functionality is needed," + "the data you load " + "and full pickle functionality is needed," " otherwise set `weights_only=True`." ), ) ] def visit_Call(self, node): - if self.get_qualified_name_for_call(node) == "torch.load" and \ - not self.has_specific_arg(node, "weights_only"): + if self.get_qualified_name_for_call( + node + ) == "torch.load" and not self.has_specific_arg(node, "weights_only"): # Add `weights_only=True` if there is no `pickle_module`. # (do not add `weights_only=False` with `pickle_module`, as it # needs to be an explicit choice). @@ -37,9 +39,7 @@ def visit_Call(self, node): weights_only_arg = cst.ensure_type( cst.parse_expression("f(weights_only=True)"), cst.Call ).args[0] - replacement = node.with_changes( - args=node.args + (weights_only_arg,) - ) + replacement = node.with_changes(args=node.args + (weights_only_arg,)) self.add_violation( node, error_code=self.ERRORS[0].error_code, From 894113f259270b7b6260b25fe36b5e09910bfd1d Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 24 Apr 2024 15:05:27 -0700 Subject: [PATCH 33/66] Fix small bug after https://github.com/pytorch-labs/torchfix/pull/46 (#52) --- torchfix/visitors/deprecated_symbols/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index f1cf61f..adefda3 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -19,7 +19,7 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): ERRORS: List[TorchError] = [ TorchError("TOR001", "Use of removed function {qualified_name}"), - TorchError("TOR101", "Import of deprecated function {qualified_name}"), + TorchError("TOR101", "Use of deprecated function {qualified_name}"), TorchError("TOR004", "Import of removed function {qualified_name}"), 
TorchError("TOR103", "Import of deprecated function {qualified_name}"), ] @@ -94,7 +94,6 @@ def visit_Call(self, node) -> None: if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERRORS[1].error_code message = self.ERRORS[1].message(qualified_name=qualified_name) - message = f"Use of deprecated function {qualified_name}" else: error_code = self.ERRORS[0].error_code message = self.ERRORS[0].message(qualified_name=qualified_name) From 615e3657b8840f7e33ec3e771329b5b5962101d8 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 25 Apr 2024 11:53:55 -0700 Subject: [PATCH 34/66] Refactor how import and call updates are handled (#53) This PR refactors code to update calls and imports from `TorchNonPublicAliasVisitor` into common functionality and uses the functionality for `TorchDeprecatedSymbolsVisitor`. This allowed to fix https://github.com/pytorch-labs/torchfix/issues/50 and resolve a TODO to remove `_UpdateFunctorchImports`. --- .../codemod/ger-outer.py.out | 2 +- .../codemod/register_pytree_node.py.out | 2 +- torchfix/common.py | 110 +++++++++++++----- torchfix/torchfix.py | 11 +- .../visitors/deprecated_symbols/__init__.py | 81 +++++-------- torchfix/visitors/nonpublic/__init__.py | 88 +++++--------- 6 files changed, 148 insertions(+), 146 deletions(-) diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out index 3303ed0..3378fde 100644 --- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out +++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out @@ -1,5 +1,5 @@ import torch -from torch import outer, ger +from torch import outer deprecated = torch.norm() sinusoid_inp = torch.outer(pos_seq, inv_freq) other = something.ger(pos_seq, inv_freq) diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out index 67d838e..28bfe2b 100644 --- 
a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out +++ b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out @@ -1,4 +1,4 @@ -from torch.utils._pytree import register_pytree_node, _register_pytree_node +from torch.utils._pytree import register_pytree_node register_pytree_node() diff --git a/torchfix/common.py b/torchfix/common.py index 73c0a35..e2c9c22 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -1,7 +1,8 @@ import sys from abc import ABC from dataclasses import dataclass -from typing import List, Optional, Set, Tuple +from os.path import commonprefix +from typing import Dict, List, Optional, Sequence, Set, Tuple import libcst as cst from libcst.codemod.visitors import ImportItem @@ -132,51 +133,106 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: def call_with_name_changes( - node: cst.Call, old_qualified_name: str, new_qualified_name: str + node: cst.Call, qualified_name: str, new_qualified_name: str ) -> Optional[Tuple[cst.Call, Set[ImportItem]]]: """ Return an optional tuple: new `Call` node with name changes and a set of newly needed imports. """ - old_begin, _, old_last = old_qualified_name.rpartition(".") - new_begin, _, new_last = new_qualified_name.rpartition(".") needed_imports: Set[ImportItem] = set() - - # If the only difference is the last name part. - if old_begin == new_begin: - if isinstance(node.func, cst.Attribute): - replacement = node.with_deep_changes( - old_node=node.func.attr, - value=new_last, - ) - elif isinstance(node.func, cst.Name): - replacement = node.with_deep_changes( - old_node=node.func, - value=new_last, - ) + call_name = cst.helpers.get_full_name_for_node(node) + assert call_name is not None + replacement = None + + alias_prefix = "" + if not qualified_name.endswith(call_name): + # This means we have an alias (`import from as`). 
+ common_suffix = commonprefix([qualified_name[::-1], call_name[::-1]])[::-1] + alias_prefix = call_name.removesuffix(common_suffix) + "." + + if not new_qualified_name.endswith(call_name): + # We need to change the call name as it's not a part of the new qualified name. + # Get the new call name on the same hierarchical level. + new_call_name = new_qualified_name.removeprefix( + commonprefix([qualified_name.removesuffix(call_name), new_qualified_name]) + ) + new_call_name = new_call_name + new_module_name = new_qualified_name.removesuffix(new_call_name).removesuffix( + "." + ) + if new_module_name: needed_imports.add( ImportItem( - module_name=new_begin, - obj_name=new_last, + module_name=new_module_name, + obj_name=new_call_name.split(".")[0], ) ) - - # If the last name part is the same and - # originally called without a dot: don't change the call site, - # just change the imports elsewhere. - elif old_last == new_last and isinstance(node.func, cst.Name): - replacement = None + replacement = node.with_changes( + func=cst.parse_expression(alias_prefix + new_call_name) + ) # Replace with new_qualified_name. - else: - replacement = node.with_changes(func=cst.parse_expression(new_qualified_name)) if replacement is None: return None else: return replacement, needed_imports +def check_old_names_in_import_from( + node: cst.ImportFrom, old_new_name_map: Dict[str, Optional[str]] +) -> Tuple[List[str], Optional[cst.ImportFrom]]: + """ + Using `old_new_name_map`, check if there are any old names in the import from. + Return a tuple of two elements: + 1. List of all founds old names. + 2. Optional replacement for the ImportFrom node. + """ + if node.module is None: + return [], None + + old_names: List[str] = [] + replacement = None + if isinstance(node.names, Sequence): + new_names: List[str] = [] + module = cst.helpers.get_full_name_for_node(node.module) + + # `possible_new_modules` and `has_non_updated_names` are used + # to decide if we can replace the ImportFrom node. 
+ new_modules: Set[str] = set() + has_non_updated_names = False + + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + if qualified_name in old_new_name_map: + old_names.append(qualified_name) + new_qualified_name = old_new_name_map[qualified_name] + if new_qualified_name is not None: + new_module = ".".join(new_qualified_name.split(".")[:-1]) + new_name = new_qualified_name.split(".")[-1] + new_names.append(new_name) + new_modules.add(new_module) + else: + has_non_updated_names = True + else: + has_non_updated_names = True + + # Replace only if the new module is the same for all names in the import. + if not has_non_updated_names and len(new_modules) == 1: + new_module = new_modules.pop() + import_aliases = list(node.names) + for i in range(len(import_aliases)): + import_aliases[i] = import_aliases[i].with_changes( + name=cst.Name(new_names[i]) + ) + replacement = node.with_changes( + module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501 + names=import_aliases, + ) + + return old_names, replacement + + def deep_multi_replace(tree, replacement_map): class MultiChildReplacementTransformer(cst.CSTTransformer): def __init__(self, replacement_map) -> None: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 369bd98..32e3b9d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -6,11 +6,7 @@ import libcst.codemod as codemod from .common import deep_multi_replace -from .visitors.deprecated_symbols import ( - TorchDeprecatedSymbolsVisitor, - _UpdateFunctorchImports, -) - +from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor from .visitors.internal import TorchScopedLibraryVisitor from .visitors.performance import TorchSynchronizedDataLoaderVisitor @@ -223,10 +219,7 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module: ) new_module = new_module.visit(add_imports_visitor) - update_functorch_imports_visitor = _UpdateFunctorchImports() - new_module = 
new_module.visit(update_functorch_imports_visitor) - - if fixes_count == 0 and not update_functorch_imports_visitor.changed: + if fixes_count == 0: raise codemod.SkipFile("No changes") return new_module diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index adefda3..f2b9cf2 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -2,12 +2,12 @@ import pkgutil import yaml from typing import Optional, List -from collections.abc import Sequence from ...common import ( TorchVisitor, TorchError, call_with_name_changes, + check_old_names_in_import_from, ) from .range import call_replacement_range @@ -18,10 +18,10 @@ class TorchDeprecatedSymbolsVisitor(TorchVisitor): ERRORS: List[TorchError] = [ - TorchError("TOR001", "Use of removed function {qualified_name}"), - TorchError("TOR101", "Use of deprecated function {qualified_name}"), - TorchError("TOR004", "Import of removed function {qualified_name}"), - TorchError("TOR103", "Import of deprecated function {qualified_name}"), + TorchError("TOR001", "Use of removed function {old_name}"), + TorchError("TOR101", "Use of deprecated function {old_name}"), + TorchError("TOR004", "Import of removed function {old_name}"), + TorchError("TOR103", "Import of deprecated function {old_name}"), ] def __init__(self, deprecated_config_path=None): @@ -35,6 +35,10 @@ def read_deprecated_config(path=None): super().__init__() self.deprecated_config = read_deprecated_config(deprecated_config_path) + self.old_new_name_map = {} + for name in self.deprecated_config: + new_name = self.deprecated_config[name].get("replacement") + self.old_new_name_map[name] = new_name def _call_replacement( self, node: cst.Call, qualified_name: str @@ -67,23 +71,27 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: if node.module is None: return - module = cst.helpers.get_full_name_for_node(node.module) - if isinstance(node.names, 
Sequence): - for name in node.names: - qualified_name = f"{module}.{name.name.value}" - if qualified_name in self.deprecated_config: - if self.deprecated_config[qualified_name]["remove_pr"] is None: - error_code = self.ERRORS[3].error_code - message = self.ERRORS[3].message(qualified_name=qualified_name) - else: - error_code = self.ERRORS[2].error_code - message = self.ERRORS[2].message(qualified_name=qualified_name) + old_names, replacement = check_old_names_in_import_from( + node, self.old_new_name_map + ) + for qualified_name in old_names: + if self.deprecated_config[qualified_name]["remove_pr"] is None: + error_code = self.ERRORS[3].error_code + message = self.ERRORS[3].message(old_name=qualified_name) + else: + error_code = self.ERRORS[2].error_code + message = self.ERRORS[2].message(old_name=qualified_name) - reference = self.deprecated_config[qualified_name].get("reference") - if reference is not None: - message = f"{message}: {reference}" + reference = self.deprecated_config[qualified_name].get("reference") + if reference is not None: + message = f"{message}: {reference}" - self.add_violation(node, error_code=error_code, message=message) + self.add_violation( + node, + error_code=error_code, + message=message, + replacement=replacement, + ) def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) @@ -93,10 +101,10 @@ def visit_Call(self, node) -> None: if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERRORS[1].error_code - message = self.ERRORS[1].message(qualified_name=qualified_name) + message = self.ERRORS[1].message(old_name=qualified_name) else: error_code = self.ERRORS[0].error_code - message = self.ERRORS[0].message(qualified_name=qualified_name) + message = self.ERRORS[0].message(old_name=qualified_name) replacement = self._call_replacement(node, qualified_name) reference = self.deprecated_config[qualified_name].get("reference") @@ 
-106,32 +114,3 @@ def visit_Call(self, node) -> None: self.add_violation( node, error_code=error_code, message=message, replacement=replacement ) - - -# TODO: refactor/generalize this. -class _UpdateFunctorchImports(cst.CSTTransformer): - REPLACEMENTS = { - "vmap", - "grad", - "vjp", - "jvp", - "jacrev", - "jacfwd", - "hessian", - "functionalize", - } - - def __init__(self): - self.changed = False - - def leave_ImportFrom( - self, node: cst.ImportFrom, updated_node: cst.ImportFrom - ) -> cst.ImportFrom: - if ( - getattr(node.module, "value", None) == "functorch" - and isinstance(node.names, Sequence) - and all(name.name.value in self.REPLACEMENTS for name in node.names) - ): - self.changed = True - return updated_node.with_changes(module=cst.parse_expression("torch.func")) - return updated_node diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 839feab..6b052d2 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -1,10 +1,13 @@ -from os.path import commonprefix -from typing import Sequence, List +from typing import List import libcst as cst -from libcst.codemod.visitors import ImportItem -from ...common import TorchError, TorchVisitor +from ...common import ( + TorchError, + TorchVisitor, + call_with_name_changes, + check_old_names_in_import_from, +) class TorchNonPublicAliasVisitor(TorchVisitor): @@ -21,14 +24,14 @@ class TorchNonPublicAliasVisitor(TorchVisitor): TorchError( "TOR104", ( - "Use of non-public function `{qualified_name}`, " + "Use of non-public function `{private_name}`, " "please use `{public_name}` instead" ), ), TorchError( "TOR105", ( - "Import of non-public function `{qualified_name}`, " + "Import of non-public function `{private_name}`, " "please use `{public_name}` instead" ), ), @@ -50,30 +53,17 @@ def visit_Call(self, node): public_name = self.ALIASES[qualified_name] error_code = self.ERRORS[0].error_code message = self.ERRORS[0].message( - 
qualified_name=qualified_name, public_name=public_name + private_name=qualified_name, public_name=public_name ) - call_name = cst.helpers.get_full_name_for_node(node) - replacement = None - if not public_name.endswith(call_name): - # We need to change the call name as it's not in the public name. - # Get the new call name on the same hierarchical level. - new_call_name = public_name.removeprefix( - commonprefix([qualified_name.removesuffix(call_name), public_name]) - ) - new_module_name = public_name.removesuffix(new_call_name).removesuffix( - "." - ) - if new_module_name: - self.needed_imports.add( - ImportItem( - module_name=new_module_name, - obj_name=new_call_name.split(".")[0], - ) - ) - replacement = node.with_changes( - func=cst.parse_expression(new_call_name) - ) + replacement_and_imports = call_with_name_changes( + node, qualified_name, public_name + ) + if replacement_and_imports is not None: + replacement, imports = replacement_and_imports + self.needed_imports.update(imports) + else: + replacement = None self.add_violation( node, error_code=error_code, message=message, replacement=replacement @@ -83,32 +73,16 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: if node.module is None: return - module = cst.helpers.get_full_name_for_node(node.module) - if not isinstance(node.names, Sequence): - return - - for name in node.names: - qualified_name = f"{module}.{name.name.value}" - if qualified_name in self.ALIASES: - public_name = self.ALIASES[qualified_name] - error_code = self.ERRORS[1].error_code - message = self.ERRORS[1].message( - qualified_name=qualified_name, public_name=public_name - ) - - new_module = ".".join(public_name.split(".")[:-1]) - new_name = public_name.split(".")[-1] - # Replace only if the import statement has no other names - if len(node.names) == 1: - replacement = cst.ImportFrom( - module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501 - names=[cst.ImportAlias(name=cst.Name(new_name))], - ) - else: - 
replacement = None - self.add_violation( - node, - error_code=error_code, - message=message, - replacement=replacement, - ) + private_names, replacement = check_old_names_in_import_from(node, self.ALIASES) # type: ignore[arg-type] # noqa: E501 + for qualified_name in private_names: + public_name = self.ALIASES[qualified_name] + error_code = self.ERRORS[1].error_code + message = self.ERRORS[1].message( + private_name=qualified_name, public_name=public_name + ) + self.add_violation( + node, + error_code=error_code, + message=message, + replacement=replacement, + ) From fdad98692ed80a5d165affd0be073c53cd300cac Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 8 May 2024 18:29:36 -0700 Subject: [PATCH 35/66] Enable mypy check_untyped_defs (#55) --- pyproject.toml | 1 + tests/test_torchfix.py | 5 +++-- torchfix/torchfix.py | 4 +++- torchfix/visitors/deprecated_symbols/__init__.py | 1 + torchfix/visitors/vision/pretrained.py | 2 +- torchfix/visitors/vision/to_tensor.py | 2 +- 6 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 87bbb70..8c27031 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ exclude = "tests/fixtures/*" [tool.mypy] exclude = ["tests/fixtures", "build"] +check_untyped_defs = true [tool.setuptools.dynamic] version = {attr = "torchfix.torchfix.__version__"} diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 5f5dff9..d8be743 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -27,9 +27,10 @@ def _codemod_results(source_path): config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) - if isinstance(new_module, codemod.TransformFailure): + if isinstance(new_module, codemod.TransformSuccess): + return new_module.code + elif isinstance(new_module, codemod.TransformFailure): raise new_module.error - return new_module.code 
def test_empty(): diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 32e3b9d..4ffaaa4 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -5,7 +5,7 @@ import libcst as cst import libcst.codemod as codemod -from .common import deep_multi_replace +from .common import deep_multi_replace, TorchVisitor from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor from .visitors.internal import TorchScopedLibraryVisitor @@ -44,6 +44,7 @@ def GET_ALL_ERROR_CODES(): codes = set() for cls in ALL_VISITOR_CLS: + assert issubclass(cls, TorchVisitor) codes |= {error.error_code for error in cls.ERRORS} return codes @@ -79,6 +80,7 @@ def get_visitors_with_error_codes(error_codes): # only correspond to one visitor. found = False for visitor_cls in ALL_VISITOR_CLS: + assert issubclass(visitor_cls, TorchVisitor) if any(error_code == err.error_code for err in visitor_cls.ERRORS): visitor_classes.add(visitor_cls) found = True diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index f2b9cf2..cf088e8 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -29,6 +29,7 @@ def read_deprecated_config(path=None): deprecated_config = {} if path is not None: data = pkgutil.get_data("torchfix", path) + assert data is not None for item in yaml.load(data, yaml.SafeLoader): deprecated_config[item["name"]] = item return deprecated_config diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index af52a0f..acbe564 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -177,7 +177,7 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): def visit_Call(self, node): def _new_arg_and_import( - old_arg: cst.Arg, is_backbone: bool + old_arg: Optional[cst.Arg], is_backbone: bool ) -> Optional[cst.Arg]: old_arg_name = "pretrained_backbone" if is_backbone else 
"pretrained" if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS: diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 791a9e5..02a5915 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -39,4 +39,4 @@ def visit_Attribute(self, node): if len(qualified_names) != 1: return - self._maybe_add_violation(qualified_names.pop().name, node) + self._maybe_add_violation(list(qualified_names)[0].name, node) From f3dd6943a7fd8c00e1b79df5a8086e86b961e8ad Mon Sep 17 00:00:00 2001 From: XIIFulminata Date: Thu, 13 Jun 2024 15:23:28 -0700 Subject: [PATCH 36/66] Add rules for `datasets` and `transforms` imports (#61) ## Context * There is a linter rule that recommends changing `import torchvision.models as models` to `from torchvision import models` * The same behavior is expected for `torchvision.datasets` and `torchvision.transforms` ## Changes * Extend the `models` rule to also address the first instance of the `datasets` and `transforms` imports * Rename the models import checker and related fixtures to be more generic since more imports are checked now ## Testing * Updated existing unit test and fixture * No lint errors for modified files --- .../fixtures/vision/checker/models_import.txt | 2 - .../{models_import.py => singleton_import.py} | 2 + .../vision/checker/singleton_import.txt | 4 ++ .../fixtures/vision/codemod/models_import.py | 5 -- .../vision/codemod/models_import.py.out | 5 -- .../vision/codemod/singleton_import.py | 9 +++ .../vision/codemod/singleton_import.py.out | 9 +++ torchfix/torchfix.py | 4 +- torchfix/visitors/vision/__init__.py | 2 +- torchfix/visitors/vision/models_import.py | 42 ------------- torchfix/visitors/vision/singleton_import.py | 60 +++++++++++++++++++ 11 files changed, 87 insertions(+), 57 deletions(-) delete mode 100644 tests/fixtures/vision/checker/models_import.txt rename tests/fixtures/vision/checker/{models_import.py => 
singleton_import.py} (72%) create mode 100644 tests/fixtures/vision/checker/singleton_import.txt delete mode 100644 tests/fixtures/vision/codemod/models_import.py delete mode 100644 tests/fixtures/vision/codemod/models_import.py.out create mode 100644 tests/fixtures/vision/codemod/singleton_import.py create mode 100644 tests/fixtures/vision/codemod/singleton_import.py.out delete mode 100644 torchfix/visitors/vision/models_import.py create mode 100644 torchfix/visitors/vision/singleton_import.py diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt deleted file mode 100644 index 7a517da..0000000 --- a/tests/fixtures/vision/checker/models_import.txt +++ /dev/null @@ -1,2 +0,0 @@ -1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. -6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/singleton_import.py similarity index 72% rename from tests/fixtures/vision/checker/models_import.py rename to tests/fixtures/vision/checker/singleton_import.py index 3a16490..ad30130 100644 --- a/tests/fixtures/vision/checker/models_import.py +++ b/tests/fixtures/vision/checker/singleton_import.py @@ -4,3 +4,5 @@ import torchvision.models from torchvision.models import * import torchvision.models as models, torch +import torchvision.datasets as datasets +import torchvision.transforms as transforms diff --git a/tests/fixtures/vision/checker/singleton_import.txt b/tests/fixtures/vision/checker/singleton_import.txt new file mode 100644 index 0000000..d6e7421 --- /dev/null +++ b/tests/fixtures/vision/checker/singleton_import.txt @@ -0,0 +1,4 @@ +1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. 
+6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. +7:1 TOR203 Consider replacing 'import torchvision.datasets as datasets' with 'from torchvision import datasets'. +8:1 TOR203 Consider replacing 'import torchvision.transforms as transforms' with 'from torchvision import transforms'. diff --git a/tests/fixtures/vision/codemod/models_import.py b/tests/fixtures/vision/codemod/models_import.py deleted file mode 100644 index 6b75141..0000000 --- a/tests/fixtures/vision/codemod/models_import.py +++ /dev/null @@ -1,5 +0,0 @@ -import torchvision.models as models -import torchvision.models as cnn - -# don't touch if more than one name imported -import torchvision.models as models, torch diff --git a/tests/fixtures/vision/codemod/models_import.py.out b/tests/fixtures/vision/codemod/models_import.py.out deleted file mode 100644 index 53269c1..0000000 --- a/tests/fixtures/vision/codemod/models_import.py.out +++ /dev/null @@ -1,5 +0,0 @@ -from torchvision import models -import torchvision.models as cnn - -# don't touch if more than one name imported -import torchvision.models as models, torch diff --git a/tests/fixtures/vision/codemod/singleton_import.py b/tests/fixtures/vision/codemod/singleton_import.py new file mode 100644 index 0000000..50d1ecd --- /dev/null +++ b/tests/fixtures/vision/codemod/singleton_import.py @@ -0,0 +1,9 @@ +import torchvision.models as models +import torchvision.models as cnn +import torchvision.datasets as datasets +import torchvision.datasets as datasets_alt +import torchvision.transforms as transforms +import torchvision.transforms as transforms_alt + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/tests/fixtures/vision/codemod/singleton_import.py.out b/tests/fixtures/vision/codemod/singleton_import.py.out new file mode 100644 index 0000000..e17def6 --- /dev/null +++ b/tests/fixtures/vision/codemod/singleton_import.py.out @@ -0,0 +1,9 @@ 
+from torchvision import models +import torchvision.models as cnn +from torchvision import datasets +import torchvision.datasets as datasets_alt +from torchvision import transforms +import torchvision.transforms as transforms_alt + +# don't touch if more than one name imported +import torchvision.models as models, torch diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 4ffaaa4..df78289 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -16,7 +16,7 @@ from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, - TorchVisionModelsImportVisitor, + TorchVisionSingletonImportVisitor, ) from .visitors.security import TorchUnsafeLoadVisitor @@ -33,7 +33,7 @@ TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, - TorchVisionModelsImportVisitor, + TorchVisionSingletonImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, TorchNonPublicAliasVisitor, diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py index 9bc944e..16ec6c4 100644 --- a/torchfix/visitors/vision/__init__.py +++ b/torchfix/visitors/vision/__init__.py @@ -1,3 +1,3 @@ from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 -from .models_import import TorchVisionModelsImportVisitor # noqa: F401 +from .singleton_import import TorchVisionSingletonImportVisitor # noqa: F401 diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py deleted file mode 100644 index f3b0797..0000000 --- a/torchfix/visitors/vision/models_import.py +++ /dev/null @@ -1,42 +0,0 @@ -import libcst as cst -import libcst.matchers as m - -from ...common import TorchError, TorchVisitor - - -class TorchVisionModelsImportVisitor(TorchVisitor): - ERRORS = [ - TorchError( - "TOR203", - ( - "Consider replacing 'import 
torchvision.models as models' " - "with 'from torchvision import models'." - ), - ) - ] - - def visit_Import(self, node: cst.Import) -> None: - replacement = None - for imported_item in node.names: - if m.matches( - imported_item, - m.ImportAlias( - name=m.Attribute( - value=m.Name("torchvision"), attr=m.Name("models") - ), - asname=m.AsName(name=m.Name("models")), - ), - ): - # Replace only if the import statement has no other names - if len(node.names) == 1: - replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=replacement, - ) - break diff --git a/torchfix/visitors/vision/singleton_import.py b/torchfix/visitors/vision/singleton_import.py new file mode 100644 index 0000000..f64257b --- /dev/null +++ b/torchfix/visitors/vision/singleton_import.py @@ -0,0 +1,60 @@ +import libcst as cst +import libcst.matchers as m + +from ...common import TorchError, TorchVisitor + + +class TorchVisionSingletonImportVisitor(TorchVisitor): + ERRORS = [ + TorchError( + "TOR203", + ( + "Consider replacing 'import torchvision.datasets as datasets' " + "with 'from torchvision import datasets'." + ), + ), + TorchError( + "TOR203", + ( + "Consider replacing 'import torchvision.models as models' " + "with 'from torchvision import models'." + ), + ), + TorchError( + "TOR203", + ( + "Consider replacing 'import torchvision.transforms as transforms' " + "with 'from torchvision import transforms'." + ), + ), + ] + + # Keep attr order in sync with ERRORS. 
+ REPLACEABLE_ATTRS = ["datasets", "models", "transforms"] + + def visit_Import(self, node: cst.Import) -> None: + replacement = None + for i, import_attr in enumerate(self.REPLACEABLE_ATTRS): + for imported_item in node.names: + if m.matches( + imported_item, + m.ImportAlias( + name=m.Attribute( + value=m.Name("torchvision"), attr=m.Name(import_attr) + ), + asname=m.AsName(name=m.Name(import_attr)), + ), + ): + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name(import_attr))], + ) + self.add_violation( + node, + error_code=self.ERRORS[i].error_code, + message=self.ERRORS[i].message(), + replacement=replacement, + ) + break From 63cf15228144f0f5cb6146b733a341e3649825a3 Mon Sep 17 00:00:00 2001 From: Arun Date: Tue, 18 Jun 2024 00:39:46 +0530 Subject: [PATCH 37/66] Add docs for torch.chain_matmul, torch.qr, torch.range, torch.cholesky (#59) Partially addresses #56 Added `TOR101` deprecation doc for `torch.cholesky`, `torch.chain_matmul`, `torch.qr`, `torch.range`. --- README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/README.md b/README.md index 9fbb7bd..755344f 100644 --- a/README.md +++ b/README.md @@ -110,5 +110,37 @@ This function is deprecated. Use the `torch.nn.attention.sdpa_kernel` context ma Migration guide: Each boolean input parameter (defaulting to true unless specified) of `sdp_kernel` corresponds to a `SDPBackened`. If the input parameter is true, the corresponding backend should be added to the input list of `sdpa_kernel`. +#### torch.chain_matmul + +This function is deprecated in favor of `torch.linalg.multi_dot`. + +Migration guide: +`multi_dot` accepts a list of two or more tensors whereas `chain_matmul` accepted multiple tensors as input arguments. 
For migration, convert the multiple tensors in argument of `chain_matmul` into a list of two or more tensors for `multi_dot`. + +Example: Replace `torch.chain_matmul(a, b, c)` with `torch.linalg.multi_dot([a, b, c])`. + +#### torch.cholesky + +`torch.cholesky()` is deprecated in favor of `torch.linalg.cholesky()`. + +Migration guide: +* `L = torch.cholesky(A)` should be replaced with `L = torch.linalg.cholesky(A)`. +* `L = torch.cholesky(A, upper=True)` should be replaced with `L = torch.linalg.cholesky(A).mH` + +#### torch.qr + +`torch.qr()` is deprecated in favor of `torch.linalg.qr()`. + +Migration guide: +* The usage `Q, R = torch.qr(A)` should be replaced with `Q, R = torch.linalg.qr(A)`. +* The boolean parameter `some` of `torch.qr` is replaced with a string parameter `mode` in `torch.linalg.qr`. The corresponding change in usage is from `Q, R = torch.qr(A, some=False)` to `Q, R = torch.linalg.qr(A, mode="complete")`. + +#### torch.range + +The function `torch.range()` is deprecated as its usage is incompatible with Python's builtin range. Instead, use `torch.arange()` as it produces values in `[start, end)`. + +Migration guide: +* `torch.range(start, end)` produces values in the range of `[start, end]`. But `torch.arange(start, end)` produces values in `[start, end)`. For step size of 1, migrate usage from `torch.range(start, end, 1)` to `torch.arange(start, end+1, 1)`. + ## License TorchFix is BSD License licensed, as found in the LICENSE file. 
From a0d3b2e5699943a64cbdd8f2f349b82d37a66af5 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Fri, 16 Aug 2024 21:51:59 +0200 Subject: [PATCH 38/66] Document TOR004,TOR102,TOR103 and add `torch.symeig` to TOR001 (#65) Linked to https://github.com/pytorch-labs/torchfix/issues/56 --- README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/README.md b/README.md index 755344f..19b3e73 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,36 @@ To get the LU factorization see `torch.lu`, which can be used with `torch.lu_sol `X = torch.solve(B, A).solution` should be replaced with `X = torch.linalg.solve(A, B)`. +#### torch.symeig + +This function was deprecated since PyTorch version 1.9 and is now removed. + +`torch.symeig` is deprecated in favor of `torch.linalg.eigh`. + +The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion. + +```python +L, _ = torch.symeig(A, upper=upper) +``` + +should be replaced with + +```python +L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L') +``` + +and + +```python +L, V = torch.symeig(A, eigenvectors=True) +``` + +should be replaced with + +```python +L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') +``` + ### TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`? This is a common misspelling that can lead to silent performance issues. @@ -81,6 +111,10 @@ from `True` to `False`. In the meantime, the value needs to be passed explicitly See this [forum post](https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745) for details. +### TOR004 Import of removed function + +See `TOR001`. 
+ ### TOR101 Use of deprecated function #### torch.nn.utils.weight_norm @@ -142,5 +176,14 @@ The function `torch.range()` is deprecated as its usage is incompatible with Pyt Migration guide: * `torch.range(start, end)` produces values in the range of `[start, end]`. But `torch.arange(start, end)` produces values in `[start, end)`. For step size of 1, migrate usage from `torch.range(start, end, 1)` to `torch.arange(start, end+1, 1)`. +### TOR102 `torch.load` without `weights_only` parameter is unsafe. + +Explicitly set `weights_only` to False only if you trust the data you load and full pickle functionality is needed, otherwise set `weights_only=True`. + +### TOR103 Import of deprecated function + +See `TOR101`. + ## License + TorchFix is BSD License licensed, as found in the LICENSE file. From 8aef63bc548d100a0bb3826b4ca6ce05330f5ccf Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 20 Aug 2024 00:35:30 +0200 Subject: [PATCH 39/66] Sort rule codes in CLI (#68) Sort the rule codes in the CLI help is more intuitive to the user. _Before_: ```text --select SELECT Comma-separated list of rules to enable or 'ALL' to enable all rules. Available rules: TOR001, TOR105, TOR402, TOR103, TOR401, TOR201, TOR004, TOR104, TOR501, TOR102, TOR202, TOR403, TOR203, TOR002, TOR901, TOR003, TOR101. Defaults to all except for TOR3, TOR4, TOR9. ``` _After_: ```text --select SELECT Comma-separated list of rules to enable or 'ALL' to enable all rules. Available rules: TOR001, TOR002, TOR003, TOR004, TOR101, TOR102, TOR103, TOR104, TOR105, TOR201, TOR202, TOR203, TOR401, TOR402, TOR403, TOR501, TOR901. Defaults to all except for TOR3, TOR4, TOR9. 
``` Changes: - Sort rule codes - Minor code modernisation for readability (list comprehension, superfluous else) --- tests/test_torchfix.py | 2 +- torchfix/torchfix.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index d8be743..17ff183 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -77,7 +77,7 @@ def test_parse_error_code_str(): ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), - (None, GET_ALL_ERROR_CODES() - exclude_set), + (None, set(GET_ALL_ERROR_CODES()) - exclude_set), ] for case, expected in cases: assert expected == process_error_code_str(case) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index df78289..8dabacf 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -46,7 +46,7 @@ def GET_ALL_ERROR_CODES(): for cls in ALL_VISITOR_CLS: assert issubclass(cls, TorchVisitor) codes |= {error.error_code for error in cls.ERRORS} - return codes + return sorted(codes) @functools.cache @@ -62,15 +62,12 @@ def expand_error_codes(codes): def construct_visitor(cls): if cls is TorchDeprecatedSymbolsVisitor: return cls(DEPRECATED_CONFIG_PATH) - else: - return cls() + + return cls() def GET_ALL_VISITORS(): - out = [] - for v in ALL_VISITOR_CLS: - out.append(construct_visitor(v)) - return out + return [construct_visitor(v) for v in ALL_VISITOR_CLS] def get_visitors_with_error_codes(error_codes): @@ -87,10 +84,7 @@ def get_visitors_with_error_codes(error_codes): break if not found: raise AssertionError(f"Unknown error code: {error_code}") - out = [] - for cls in visitor_classes: - out.append(construct_visitor(cls)) - return out + return [construct_visitor(cls) for cls in visitor_classes] def process_error_code_str(code_str): @@ -100,7 +94,7 @@ def process_error_code_str(code_str): # Default when --select is not provided. 
if code_str is None: exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - return GET_ALL_ERROR_CODES() - exclude_set + return set(GET_ALL_ERROR_CODES()) - exclude_set raw_codes = [s.strip() for s in code_str.split(",")] From 555bb8d59819e9117cc41b0f3df87f3ed1e6e60f Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 31 Aug 2024 00:39:16 +0200 Subject: [PATCH 40/66] Move development dependencies to `pyproject.toml` (#64) Minor PR to kick-off contributing to `torchfix`. I'm developing new functionality, which I plan to contribute in follow-up PRs. Changes: - Move `requirements-dev.txt` to `optional-dependencies` - Explicitly set the setuptools build system - Omit `pip install pre-commit` in CONTRIBUTING.md, already included in dev dependencies --- .github/workflows/test-torchfix.yml | 7 ++++--- CONTRIBUTING.md | 14 +++++++------- pyproject.toml | 21 ++++++++++++++++++++- requirements-dev.txt | 7 ------- torchfix/visitors/__init__.py | 0 5 files changed, 31 insertions(+), 18 deletions(-) delete mode 100644 requirements-dev.txt create mode 100644 torchfix/visitors/__init__.py diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml index e6a1aac..c465286 100644 --- a/.github/workflows/test-torchfix.yml +++ b/.github/workflows/test-torchfix.yml @@ -10,12 +10,13 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Install requirements + - name: Upgrade build dependencies run: | - pip3 install -r requirements-dev.txt + pip3 install -U pip + pip3 install -U setuptools - name: Install TorchFix run: | - pip3 install . + pip3 install ".[dev]" - name: Run pytest run: | pytest tests diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ea84a36..8995e56 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,22 +20,22 @@ We actively welcome your pull requests. ## Linting -We use `black`, `flake8`, and `mypy` to lint the code. 
-``` -pip install -r requirements-dev.txt +We use `black`, `flake8`, and `mypy` to lint the code. Configuration is available to run lints via `pre-commit`. + +```shell +pip install ".[dev]" ``` Linting via pre-commit hook: -``` -# install pre-commit hooks for the first time -pre-commit install +```shell # manually run pre-commit hooks on all files (runs all linters) pre-commit run --all-files ``` Manually running individual linters: -``` + +```shell black . flake8 mypy . diff --git a/pyproject.toml b/pyproject.toml index 8c27031..7114063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[build-system] +requires = ["setuptools >= 65.0"] +build-backend = "setuptools.build_meta" + [project] name = "TorchFix" requires-python = ">=3.9" @@ -9,7 +13,22 @@ classifiers = [ "Programming Language :: Python" ] dynamic = ["version"] -dependencies = ["flake8>=3.8.2", "PyYAML", "libcst>=1.1.0,<1.2.0"] +dependencies = [ + "flake8>=3.8.2", + "PyYAML", + "libcst>=1.1.0,<1.2.0" +] + +[project.optional-dependencies] +dev = [ + "flake8==6.0.0", + "pytest==7.2.0", + "libcst==1.1.0", + "types-PyYAML==6.0.7", + "mypy==1.7.0", + "black==24.4.0", + "pre-commit==3.7.0", +] [project.urls] Repository = "https://github.com/pytorch-labs/torchfix" diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 7910c57..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,7 +0,0 @@ -flake8==6.0.0 -pytest==7.2.0 -libcst==1.1.0 -types-PyYAML==6.0.7 -mypy==1.7.0 -black==24.4.0 -pre-commit==3.7.0 diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py new file mode 100644 index 0000000..e69de29 From c755731b1faa2e7b74e02386540b4b2aa546ab68 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Thu, 5 Sep 2024 21:34:38 +0200 Subject: [PATCH 41/66] Fix distinct error codes test (#73) Closes https://github.com/pytorch-labs/torchfix/issues/71 --- tests/test_torchfix.py | 4 ++-- torchfix/visitors/vision/singleton_import.py | 22 ++++---------------- 2 
files changed, 6 insertions(+), 20 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 17ff183..3d9ddc2 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -65,8 +65,8 @@ def test_errorcodes_distinct(): LOGGER.info("Checking error code for %s", visitor.__class__.__name__) errors = visitor.ERRORS for e in errors: - assert e not in seen - seen.add(e) + assert e.error_code not in seen + seen.add(e.error_code) def test_parse_error_code_str(): diff --git a/torchfix/visitors/vision/singleton_import.py b/torchfix/visitors/vision/singleton_import.py index f64257b..f2b207b 100644 --- a/torchfix/visitors/vision/singleton_import.py +++ b/torchfix/visitors/vision/singleton_import.py @@ -9,22 +9,8 @@ class TorchVisionSingletonImportVisitor(TorchVisitor): TorchError( "TOR203", ( - "Consider replacing 'import torchvision.datasets as datasets' " - "with 'from torchvision import datasets'." - ), - ), - TorchError( - "TOR203", - ( - "Consider replacing 'import torchvision.models as models' " - "with 'from torchvision import models'." - ), - ), - TorchError( - "TOR203", - ( - "Consider replacing 'import torchvision.transforms as transforms' " - "with 'from torchvision import transforms'." + "Consider replacing 'import torchvision.{module} as {module}' " + "with 'from torchvision import {module}'." 
), ), ] @@ -53,8 +39,8 @@ def visit_Import(self, node: cst.Import) -> None: ) self.add_violation( node, - error_code=self.ERRORS[i].error_code, - message=self.ERRORS[i].message(), + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(module=import_attr), replacement=replacement, ) break From 458d6ab74cd9edb12828b2f098a20e24be070ce5 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Fri, 6 Sep 2024 18:47:59 +0200 Subject: [PATCH 42/66] TST: add deprecated false negative example (#74) --- .../fixtures/deprecated_symbols/checker/deprecated_from_nn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py b/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py index 17e9058..a9f2866 100644 --- a/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py +++ b/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py @@ -6,3 +6,6 @@ import torch.nn as yy yy.UpsamplingNearest2d() + +func = torch.nn.UpsamplingNearest2d # not detected currently +func() From 311cdd72caa243882e2ed45860e04676069857c1 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 10 Sep 2024 00:36:15 +0200 Subject: [PATCH 43/66] CI: enable testing on macOS (#75) When I run `torchfix` on macOS the `contextlib.redirect_stderr` works as expected. However, in the code we make an exception for `Darwin` (macOS): https://github.com/pytorch-labs/torchfix/blob/main/torchfix/__main__.py#L25 By enabling this test in CI we could test if removal of the dup2 workaround is possible. 
--- .github/workflows/test-torchfix.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml index c465286..9a61e01 100644 --- a/.github/workflows/test-torchfix.yml +++ b/.github/workflows/test-torchfix.yml @@ -6,10 +6,16 @@ on: jobs: test-torchfix: - runs-on: ubuntu-latest + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} steps: - name: Checkout uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.10' - name: Upgrade build dependencies run: | pip3 install -U pip From 5699fc87dcd27e5aea3420c8251255cfc21284ad Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 10 Sep 2024 14:09:49 +0200 Subject: [PATCH 44/66] Split tests into individual test cases (#67) `metafunc.parametrize` can be used to create more granular test cases by dynamically generating test fixtures (i.e. arguments). This in our case is helpful for debugging failing test cases, by knowing directly which of the paths are failing. 
_Before_ _After_ Changes: - Split checker and codemod tests into individual test cases via metafunc - Few pathlib cleanups - Removed obviated logger calls in these functions --- tests/test_torchfix.py | 78 +++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 3d9ddc2..cc3631e 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -16,20 +16,44 @@ LOGGER = logging.getLogger(__name__) +def pytest_generate_tests(metafunc): + # Dynamically generate test cases from paths + if "checker_source_path" in metafunc.fixturenames: + files = list(FIXTURES_PATH.glob("**/checker/*.py")) + metafunc.parametrize( + "checker_source_path", files, ids=[file_name.stem for file_name in files] + ) + if "codemod_source_path" in metafunc.fixturenames: + files = list(FIXTURES_PATH.glob("**/codemod/*.py")) + metafunc.parametrize( + "codemod_source_path", files, ids=[file_name.stem for file_name in files] + ) + if "case" in metafunc.fixturenames: + exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) + cases = [ + ("ALL", GET_ALL_ERROR_CODES()), + ("ALL,TOR102", GET_ALL_ERROR_CODES()), + ("TOR102", {"TOR102"}), + ("TOR102,TOR101", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), + (None, set(GET_ALL_ERROR_CODES()) - exclude_set), + ] + metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases]) + + def _checker_results(s): checker = TorchChecker(None, s) return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()] -def _codemod_results(source_path): - with open(source_path) as source: - code = source.read() +def _codemod_results(source_path: Path): + code = source_path.read_text() config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) - context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) + context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), 
config) new_module = codemod.transform_module(context, code) if isinstance(new_module, codemod.TransformSuccess): return new_module.code - elif isinstance(new_module, codemod.TransformFailure): + if isinstance(new_module, codemod.TransformFailure): raise new_module.error @@ -37,25 +61,20 @@ def test_empty(): assert _checker_results([""]) == [] -def test_checker_fixtures(): - for source_path in FIXTURES_PATH.glob("**/checker/*.py"): - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) - expected_path = str(source_path)[:-2] + "txt" - expected_results = [] - with open(expected_path) as expected: - for line in expected: - expected_results.append(line.rstrip()) +def test_checker_fixtures(checker_source_path: Path): + expected_path = checker_source_path.with_suffix(".txt") + expected_results = expected_path.read_text().splitlines() - with open(source_path) as source: - assert _checker_results(source.readlines()) == expected_results + assert ( + _checker_results(checker_source_path.read_text().splitlines(keepends=True)) + == expected_results + ) -def test_codemod_fixtures(): - for source_path in FIXTURES_PATH.glob("**/codemod/*.py"): - LOGGER.info("Testing %s", source_path.relative_to(Path.cwd())) - expected_path = source_path.with_suffix(".py.out") - expected_results = expected_path.read_text() - assert _codemod_results(source_path) == expected_results +def test_codemod_fixtures(codemod_source_path: Path): + expected_path = codemod_source_path.with_suffix(".py.out") + expected_results = expected_path.read_text() + assert _codemod_results(codemod_source_path) == expected_results def test_errorcodes_distinct(): @@ -63,21 +82,10 @@ def test_errorcodes_distinct(): seen = set() for visitor in visitors: LOGGER.info("Checking error code for %s", visitor.__class__.__name__) - errors = visitor.ERRORS - for e in errors: + for e in visitor.ERRORS: assert e.error_code not in seen seen.add(e.error_code) -def test_parse_error_code_str(): - exclude_set = 
expand_error_codes(tuple(DISABLED_BY_DEFAULT)) - cases = [ - ("ALL", GET_ALL_ERROR_CODES()), - ("ALL,TOR102", GET_ALL_ERROR_CODES()), - ("TOR102", {"TOR102"}), - ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), - (None, set(GET_ALL_ERROR_CODES()) - exclude_set), - ] - for case, expected in cases: - assert expected == process_error_code_str(case) +def test_parse_error_code_str(case, expected): + assert process_error_code_str(case) == expected From 6b2e85a165a49d997c6595cf61c7798c0c78036c Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:29:39 +0200 Subject: [PATCH 45/66] Unused parentheses --- torchfix/__main__.py | 2 +- torchfix/common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index b8413bf..b8eaebf 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -46,7 +46,7 @@ def main() -> None: parser.add_argument( "path", nargs="+", - help=("Path to check/fix. Can be a directory, a file, or multiple of either."), + help="Path to check/fix. 
Can be a directory, a file, or multiple of either.", ) parser.add_argument( "--fix", diff --git a/torchfix/common.py b/torchfix/common.py index e2c9c22..e71f35d 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -26,7 +26,7 @@ class LintViolation: def flake8_result(self): full_message = f"{self.error_code} {self.message}" - return (self.line, 1 + self.column, full_message, "TorchFix") + return self.line, 1 + self.column, full_message, "TorchFix" def codemod_result(self) -> str: fixable = f" [{CYAN}*{ENDC}]" if self.replacement is not None else "" From 77ea129c1300f585a0728cca781b541b4ca0bbf5 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:31:22 +0200 Subject: [PATCH 46/66] Missing `super().__init__()` calls --- torchfix/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchfix/common.py b/torchfix/common.py index e71f35d..e04164e 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -56,6 +56,7 @@ class TorchVisitor(cst.BatchableCSTVisitor, ABC): ERRORS: List[TorchError] def __init__(self) -> None: + super().__init__() self.violations: List[LintViolation] = [] self.needed_imports: Set[ImportItem] = set() @@ -236,6 +237,7 @@ def check_old_names_in_import_from( def deep_multi_replace(tree, replacement_map): class MultiChildReplacementTransformer(cst.CSTTransformer): def __init__(self, replacement_map) -> None: + super().__init__() self.replacement_map = replacement_map def on_leave(self, original_node, updated_node): From e36657f106fdaf173b9b8e40adc933a11d743af0 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:35:57 +0200 Subject: [PATCH 47/66] Export visitors in `torchfix/visitors/__init__.py` --- torchfix/torchfix.py | 16 ++++++++-------- torchfix/visitors/__init__.py | 24 ++++++++++++++++++++++++ torchfix/visitors/vision/__init__.py | 12 +++++++++--- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 8dabacf..24e88b5 100644 --- 
a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -6,19 +6,19 @@ import libcst.codemod as codemod from .common import deep_multi_replace, TorchVisitor -from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor -from .visitors.internal import TorchScopedLibraryVisitor -from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor -from .visitors.nonpublic import TorchNonPublicAliasVisitor - -from .visitors.vision import ( +from .visitors import ( + TorchDeprecatedSymbolsVisitor, + TorchNonPublicAliasVisitor, + TorchReentrantCheckpointVisitor, + TorchRequireGradVisitor, + TorchScopedLibraryVisitor, + TorchSynchronizedDataLoaderVisitor, + TorchUnsafeLoadVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, ) -from .visitors.security import TorchUnsafeLoadVisitor __version__ = "0.5.0" diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index e69de29..af2b62b 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -0,0 +1,24 @@ +from .deprecated_symbols import TorchDeprecatedSymbolsVisitor +from .internal import TorchScopedLibraryVisitor +from .misc import TorchReentrantCheckpointVisitor, TorchRequireGradVisitor +from .nonpublic import TorchNonPublicAliasVisitor +from .performance import TorchSynchronizedDataLoaderVisitor +from .security import TorchUnsafeLoadVisitor +from .vision import ( + TorchVisionDeprecatedPretrainedVisitor, + TorchVisionDeprecatedToTensorVisitor, + TorchVisionSingletonImportVisitor, +) + +__all__ = [ + "TorchDeprecatedSymbolsVisitor", + "TorchRequireGradVisitor", + "TorchScopedLibraryVisitor", + "TorchSynchronizedDataLoaderVisitor", + "TorchVisionDeprecatedPretrainedVisitor", + "TorchVisionDeprecatedToTensorVisitor", + "TorchVisionSingletonImportVisitor", + "TorchUnsafeLoadVisitor", + "TorchReentrantCheckpointVisitor", + 
"TorchNonPublicAliasVisitor", +] diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py index 16ec6c4..3a1745f 100644 --- a/torchfix/visitors/vision/__init__.py +++ b/torchfix/visitors/vision/__init__.py @@ -1,3 +1,9 @@ -from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 -from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 -from .singleton_import import TorchVisionSingletonImportVisitor # noqa: F401 +from .pretrained import TorchVisionDeprecatedPretrainedVisitor +from .singleton_import import TorchVisionSingletonImportVisitor +from .to_tensor import TorchVisionDeprecatedToTensorVisitor + +__all__ = [ + "TorchVisionDeprecatedPretrainedVisitor", + "TorchVisionDeprecatedToTensorVisitor", + "TorchVisionSingletonImportVisitor", +] From bd86ba630c9906ad9afd5c1a5c9626828e814a8e Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:37:35 +0200 Subject: [PATCH 48/66] Remove unused assignments --- torchfix/common.py | 1 - torchfix/visitors/deprecated_symbols/range.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/torchfix/common.py b/torchfix/common.py index e04164e..de78eb2 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -158,7 +158,6 @@ def call_with_name_changes( new_call_name = new_qualified_name.removeprefix( commonprefix([qualified_name.removesuffix(call_name), new_qualified_name]) ) - new_call_name = new_call_name new_module_name = new_qualified_name.removesuffix(new_call_name).removesuffix( "." 
) diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 26f0a4f..2097af1 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -54,8 +54,6 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: else: return None - updated_end_arg = None - # `end` is a literal (positive) integer if isinstance(end_arg.value, cst.Integer): end = int(end_arg.value.value) + step From 76f61577569b949d04d581d406ebef16e0fbddd1 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:46:29 +0200 Subject: [PATCH 49/66] Use unpacking instead of concatenation --- torchfix/visitors/misc/__init__.py | 2 +- torchfix/visitors/security/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index be5f1c9..ea8c7be 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -70,7 +70,7 @@ def visit_Call(self, node): use_reentrant_arg = cst.ensure_type( cst.parse_expression("f(use_reentrant=False)"), cst.Call ).args[0] - replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) + replacement = node.with_changes(args=(*node.args, use_reentrant_arg)) self.add_violation( node, error_code=self.ERRORS[0].error_code, diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py index e0ecc92..d1a9380 100644 --- a/torchfix/visitors/security/__init__.py +++ b/torchfix/visitors/security/__init__.py @@ -39,7 +39,7 @@ def visit_Call(self, node): weights_only_arg = cst.ensure_type( cst.parse_expression("f(weights_only=True)"), cst.Call ).args[0] - replacement = node.with_changes(args=node.args + (weights_only_arg,)) + replacement = node.with_changes(args=(*node.args, weights_only_arg)) self.add_violation( node, error_code=self.ERRORS[0].error_code, From 
1f9f873f99eb33b6dd2f47e8a958bf884cb2e580 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:49:19 +0200 Subject: [PATCH 50/66] Make mypy happy --- torchfix/common.py | 4 +-- .../visitors/deprecated_symbols/cholesky.py | 9 +++-- torchfix/visitors/deprecated_symbols/range.py | 34 ++++++++++++------- torchfix/visitors/nonpublic/__init__.py | 2 +- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/torchfix/common.py b/torchfix/common.py index de78eb2..b5eeeec 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -2,7 +2,7 @@ from abc import ABC from dataclasses import dataclass from os.path import commonprefix -from typing import Dict, List, Optional, Sequence, Set, Tuple +from typing import List, Optional, Sequence, Set, Tuple, Mapping import libcst as cst from libcst.codemod.visitors import ImportItem @@ -180,7 +180,7 @@ def call_with_name_changes( def check_old_names_in_import_from( - node: cst.ImportFrom, old_new_name_map: Dict[str, Optional[str]] + node: cst.ImportFrom, old_new_name_map: Mapping[str, Optional[str]] ) -> Tuple[List[str], Optional[cst.ImportFrom]]: """ Using `old_new_name_map`, check if there are any old names in the import from. 
diff --git a/torchfix/visitors/deprecated_symbols/cholesky.py b/torchfix/visitors/deprecated_symbols/cholesky.py index c44c831..c80bf7e 100644 --- a/torchfix/visitors/deprecated_symbols/cholesky.py +++ b/torchfix/visitors/deprecated_symbols/cholesky.py @@ -19,9 +19,14 @@ def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode: and cst.ensure_type(upper_arg.value, cst.Name).value == "True" ): replacement = cst.parse_expression(f"{module_name}.linalg.cholesky(A).mH") + + # Make mypy happy + assert isinstance(replacement, (cst.Name, cst.Attribute)) + + old_node = cst.ensure_type(replacement.value, cst.Call).args replacement = replacement.with_deep_changes( - # Ignore type error, see https://github.com/Instagram/LibCST/issues/963 - old_node=cst.ensure_type(replacement.value, cst.Call).args, # type: ignore + # see https://github.com/Instagram/LibCST/issues/963 + old_node=old_node, # type: ignore[arg-type] value=[input_arg], ) else: diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 2097af1..6bfae12 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -46,9 +46,10 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: step_arg, m.Arg(value=m.UnaryOperation(operator=m.Minus(), expression=m.Integer())), ): - # Ignore type error here and further in this file. - # See https://github.com/Instagram/LibCST/issues/964 - step = -int(step_arg.value.expression.value) # type: ignore + # make mypy happy + assert isinstance(step_arg.value, cst.UnaryOperation) + assert isinstance(step_arg.value.expression, cst.Integer) + step = -int(step_arg.value.expression.value) # Bail out, don't know how to update with non-integer `step`. 
else: @@ -75,10 +76,15 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: end_arg, m.Arg(value=m.UnaryOperation(operator=m.Minus(), expression=m.Integer())), ): - end = -int(end_arg.value.expression.value) + step # type: ignore + op = end_arg.value + # make mypy happy + assert isinstance(op, cst.UnaryOperation) + assert isinstance(op.expression, cst.Integer) + end = -int(op.expression.value) + step if end < 0: updated_end_arg = end_arg.with_deep_changes( - old_node=end_arg.value.expression, value=str(-end) # type: ignore + old_node=op.expression, + value=str(-end), ) else: # `end` became non-negative @@ -92,7 +98,9 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: value=m.BinaryOperation(operator=m.Subtract(), right=m.Integer(value="1")) ), ): - updated_end_arg = end_arg.with_changes(value=end_arg.value.left) # type: ignore + # make mypy happy + assert isinstance(end_arg.value, cst.BinaryOperation) + updated_end_arg = end_arg.with_changes(value=end_arg.value.left) # `end` something else: add `+ 1` at the end else: @@ -104,12 +112,14 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: ) ) - replacement = node if updated_end_arg is not None: - # Ignore type error, see https://github.com/Instagram/LibCST/issues/965 - replacement = replacement.deep_replace(end_arg, updated_end_arg) # type: ignore - replacement = replacement.with_deep_changes( + replacement = node.deep_replace(end_arg, updated_end_arg) + + # make mypy happy + assert isinstance(replacement, cst.Call) + else: + replacement = node + + return replacement.with_deep_changes( old_node=cst.ensure_type(replacement.func, cst.Attribute).attr, value="arange" ) - - return replacement diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py index 6b052d2..281c0f3 100644 --- a/torchfix/visitors/nonpublic/__init__.py +++ b/torchfix/visitors/nonpublic/__init__.py @@ -73,7 +73,7 @@ def 
visit_ImportFrom(self, node: cst.ImportFrom) -> None: if node.module is None: return - private_names, replacement = check_old_names_in_import_from(node, self.ALIASES) # type: ignore[arg-type] # noqa: E501 + private_names, replacement = check_old_names_in_import_from(node, self.ALIASES) for qualified_name in private_names: public_name = self.ALIASES[qualified_name] error_code = self.ERRORS[1].error_code From 587ce3b4e14f36acc6683633d6c8c2c0aa402d57 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 11:54:55 +0200 Subject: [PATCH 51/66] Simplify flow for readability --- torchfix/common.py | 75 +++++++++---------- .../visitors/deprecated_symbols/__init__.py | 26 +++---- .../deprecated_symbols/chain_matmul.py | 20 ++--- torchfix/visitors/deprecated_symbols/qr.py | 4 +- torchfix/visitors/deprecated_symbols/range.py | 9 +-- 5 files changed, 62 insertions(+), 72 deletions(-) diff --git a/torchfix/common.py b/torchfix/common.py index b5eeeec..5222576 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -129,8 +129,7 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: name_metadata = list(self.get_metadata(QualifiedNameProvider, node)) if not name_metadata: return None - qualified_name = name_metadata[0].name - return qualified_name + return name_metadata[0].name def call_with_name_changes( @@ -175,8 +174,8 @@ def call_with_name_changes( # Replace with new_qualified_name. if replacement is None: return None - else: - return replacement, needed_imports + + return replacement, needed_imports def check_old_names_in_import_from( @@ -188,47 +187,45 @@ def check_old_names_in_import_from( 1. List of all founds old names. 2. Optional replacement for the ImportFrom node. 
""" - if node.module is None: + if node.module is None or not isinstance(node.names, Sequence): return [], None old_names: List[str] = [] replacement = None - if isinstance(node.names, Sequence): - new_names: List[str] = [] - module = cst.helpers.get_full_name_for_node(node.module) - - # `possible_new_modules` and `has_non_updated_names` are used - # to decide if we can replace the ImportFrom node. - new_modules: Set[str] = set() - has_non_updated_names = False - - for name in node.names: - qualified_name = f"{module}.{name.name.value}" - if qualified_name in old_new_name_map: - old_names.append(qualified_name) - new_qualified_name = old_new_name_map[qualified_name] - if new_qualified_name is not None: - new_module = ".".join(new_qualified_name.split(".")[:-1]) - new_name = new_qualified_name.split(".")[-1] - new_names.append(new_name) - new_modules.add(new_module) - else: - has_non_updated_names = True + new_names: List[str] = [] + module = cst.helpers.get_full_name_for_node(node.module) + + # `possible_new_modules` and `has_non_updated_names` are used + # to decide if we can replace the ImportFrom node. + new_modules: Set[str] = set() + has_non_updated_names = False + + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + if qualified_name in old_new_name_map: + old_names.append(qualified_name) + new_qualified_name = old_new_name_map[qualified_name] + if new_qualified_name is not None: + new_module = ".".join(new_qualified_name.split(".")[:-1]) + new_name = new_qualified_name.split(".")[-1] + new_names.append(new_name) + new_modules.add(new_module) else: has_non_updated_names = True - - # Replace only if the new module is the same for all names in the import. 
- if not has_non_updated_names and len(new_modules) == 1: - new_module = new_modules.pop() - import_aliases = list(node.names) - for i in range(len(import_aliases)): - import_aliases[i] = import_aliases[i].with_changes( - name=cst.Name(new_names[i]) - ) - replacement = node.with_changes( - module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501 - names=import_aliases, - ) + else: + has_non_updated_names = True + + # Replace only if the new module is the same for all names in the import. + if not has_non_updated_names and len(new_modules) == 1: + new_module = new_modules.pop() + import_aliases = [ + import_alias.with_changes(name=cst.Name(new_name)) + for import_alias, new_name in zip(list(node.names), new_names) + ] + replacement = node.with_changes( + module=cst.parse_expression(new_module), + names=import_aliases, + ) return old_names, replacement diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index cf088e8..20b6cca 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -53,19 +53,19 @@ def _call_replacement( replacement = None if qualified_name in replacements_map: - replacement = replacements_map[qualified_name](node) - else: - # Replace names for functions that have drop-in replacement. - function_name_replacement = self.deprecated_config.get( - qualified_name, {} - ).get("replacement", "") - if function_name_replacement: - replacement_and_imports = call_with_name_changes( - node, qualified_name, function_name_replacement - ) - if replacement_and_imports is not None: - replacement, imports = replacement_and_imports - self.needed_imports.update(imports) + return replacements_map[qualified_name](node) + + # Replace names for functions that have drop-in replacement. 
+ function_name_replacement = self.deprecated_config.get(qualified_name, {}).get( + "replacement", "" + ) + if function_name_replacement: + replacement_and_imports = call_with_name_changes( + node, qualified_name, function_name_replacement + ) + if replacement_and_imports is not None: + replacement, imports = replacement_and_imports + self.needed_imports.update(imports) return replacement def visit_ImportFrom(self, node: cst.ImportFrom) -> None: diff --git a/torchfix/visitors/deprecated_symbols/chain_matmul.py b/torchfix/visitors/deprecated_symbols/chain_matmul.py index ca546c3..1e3873f 100644 --- a/torchfix/visitors/deprecated_symbols/chain_matmul.py +++ b/torchfix/visitors/deprecated_symbols/chain_matmul.py @@ -7,21 +7,17 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode: Replace `torch.chain_matmul` with `torch.linalg.multi_dot`, changing multiple parameters to a list. """ - matrices = [] + matrices = [ + cst.Element(value=arg.value) for arg in node.args if arg.keyword is None + ] + matrices_arg = cst.Arg(value=cst.List(elements=matrices)) + out_arg = None for arg in node.args: - if arg.keyword is None: - matrices.append(cst.Element(value=arg.value)) - elif arg.keyword.value == "out": + if arg.keyword is not None and arg.keyword.value == "out": out_arg = arg - matrices_arg = cst.Arg(value=cst.List(elements=matrices)) - if out_arg is None: - replacement_args = [matrices_arg] - else: - replacement_args = [matrices_arg, out_arg] + replacement_args = [matrices_arg] if out_arg is None else [matrices_arg, out_arg] module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)") - replacement = replacement.with_changes(args=replacement_args) - - return replacement + return replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/qr.py b/torchfix/visitors/deprecated_symbols/qr.py index 9fc4874..ab81fcf 100644 --- a/torchfix/visitors/deprecated_symbols/qr.py 
+++ b/torchfix/visitors/deprecated_symbols/qr.py @@ -29,6 +29,4 @@ def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]: replacement_args = [input_arg] module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.qr(args)") - replacement = replacement.with_changes(args=replacement_args) - - return replacement + return replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 6bfae12..45e0580 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -18,11 +18,10 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]: for arg in node.args: if arg.keyword is None: non_kw_args.append(arg) - else: - if arg.keyword.value == "end": - end_arg = arg - elif arg.keyword.value == "step": - step_arg = arg + elif arg.keyword.value == "end": + end_arg = arg + elif arg.keyword.value == "step": + step_arg = arg if end_arg is None: if len(non_kw_args) == 1: end_arg = non_kw_args[0] From 823782f193d01a624ea2e416615c5d404b631523 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 12:20:17 +0200 Subject: [PATCH 52/66] `args.path` is never not set --- torchfix/__main__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index b8eaebf..eb17658 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -40,7 +40,7 @@ def StderrSilencer(redirect: bool = True): libc.close(orig_stderr) -def main() -> None: +def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( @@ -78,11 +78,11 @@ def main() -> None: action="store_true", ) - args = parser.parse_args() + return parser.parse_args() - if not args.path: - parser.print_usage() - sys.exit(1) + +def main() -> None: + args = _parse_args() files = codemod.gather_files(args.path) From 
7855ac1a17700b76d1a46ce9038882b63c62c901 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 2 Sep 2024 12:27:11 +0200 Subject: [PATCH 53/66] Use dict comprehension --- torchfix/visitors/deprecated_symbols/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 20b6cca..6a05472 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -36,10 +36,10 @@ def read_deprecated_config(path=None): super().__init__() self.deprecated_config = read_deprecated_config(deprecated_config_path) - self.old_new_name_map = {} - for name in self.deprecated_config: - new_name = self.deprecated_config[name].get("replacement") - self.old_new_name_map[name] = new_name + self.old_new_name_map = { + name: self.deprecated_config[name].get("replacement") + for name in self.deprecated_config + } def _call_replacement( self, node: cst.Call, qualified_name: str From af3c6ff13dd8dd4cc937825ec11cd00739ff65a9 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Thu, 12 Sep 2024 08:59:43 +0200 Subject: [PATCH 54/66] Improving test case developer experience (#69) This PR updates the test fixtures to have the `.py` extension for codemod output files. This is helpful for writing new tests (I'll be adding some for the new rules) - Codemod output is now stored as `.out.py` instead of `.py.out`. This marks the files as Python source to IDEs, linters etc. For test discovery, the input files were renamed to `.in.py` from `.py`. Follow-up PR for https://github.com/pytorch-labs/torchfix/pull/67. Will rebase this one once that PR is merged. 
--- .../{aliased_import.py => aliased_import.in.py} | 0 ...liased_import.py.out => aliased_import.out.py} | 0 .../{chain_matmul.py => chain_matmul.in.py} | 0 .../{chain_matmul.py.out => chain_matmul.out.py} | 0 .../codemod/{cholesky.py => cholesky.in.py} | 0 .../codemod/{cholesky.py.out => cholesky.out.py} | 0 .../codemod/{functorch.py => functorch.in.py} | 0 .../{functorch.py.out => functorch.out.py} | 0 .../codemod/{ger-outer.py => ger-outer.in.py} | 0 .../{ger-outer.py.out => ger-outer.out.py} | 0 .../codemod/{qr.py => qr.in.py} | 0 .../codemod/{qr.py.out => qr.out.py} | 0 .../{range-arange.py => range-arange.in.py} | 0 .../{range-arange.py.out => range-arange.out.py} | 0 ..._pytree_node.py => register_pytree_node.in.py} | 0 ...ee_node.py.out => register_pytree_node.out.py} | 0 ...t_checkpoint.py => reentrant_checkpoint.in.py} | 0 ...ckpoint.py.out => reentrant_checkpoint.out.py} | 0 .../{require_grad.py => require_grad.in.py} | 0 .../{require_grad.py.out => require_grad.out.py} | 0 ...e_convert.py => default_collate_convert.in.py} | 0 ...vert.py.out => default_collate_convert.out.py} | 0 .../codemod/{pretrained.py => pretrained.in.py} | 0 .../{pretrained.py.out => pretrained.out.py} | 0 ...s_import.py => pretrained_models_import.in.py} | 0 ...ort.py.out => pretrained_models_import.out.py} | 0 ...one_import.py => pretrained_none_import.in.py} | 0 ...mport.py.out => pretrained_none_import.out.py} | 0 ...mport.py => pretrained_submodule_import.in.py} | 0 ....py.out => pretrained_submodule_import.out.py} | 0 ...singleton_import.py => singleton_import.in.py} | 0 ...eton_import.py.out => singleton_import.out.py} | 0 tests/test_torchfix.py | 15 +++++++++------ 33 files changed, 9 insertions(+), 6 deletions(-) rename tests/fixtures/deprecated_symbols/codemod/{aliased_import.py => aliased_import.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{aliased_import.py.out => aliased_import.out.py} (100%) rename 
tests/fixtures/deprecated_symbols/codemod/{chain_matmul.py => chain_matmul.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{chain_matmul.py.out => chain_matmul.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{cholesky.py => cholesky.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{cholesky.py.out => cholesky.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{functorch.py => functorch.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{functorch.py.out => functorch.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{ger-outer.py => ger-outer.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{ger-outer.py.out => ger-outer.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{qr.py => qr.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{qr.py.out => qr.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{range-arange.py => range-arange.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{range-arange.py.out => range-arange.out.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{register_pytree_node.py => register_pytree_node.in.py} (100%) rename tests/fixtures/deprecated_symbols/codemod/{register_pytree_node.py.out => register_pytree_node.out.py} (100%) rename tests/fixtures/misc/codemod/{reentrant_checkpoint.py => reentrant_checkpoint.in.py} (100%) rename tests/fixtures/misc/codemod/{reentrant_checkpoint.py.out => reentrant_checkpoint.out.py} (100%) rename tests/fixtures/misc/codemod/{require_grad.py => require_grad.in.py} (100%) rename tests/fixtures/misc/codemod/{require_grad.py.out => require_grad.out.py} (100%) rename tests/fixtures/nonpublic/codemod/{default_collate_convert.py => default_collate_convert.in.py} (100%) rename tests/fixtures/nonpublic/codemod/{default_collate_convert.py.out => default_collate_convert.out.py} (100%) rename tests/fixtures/vision/codemod/{pretrained.py => pretrained.in.py} (100%) 
rename tests/fixtures/vision/codemod/{pretrained.py.out => pretrained.out.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_models_import.py => pretrained_models_import.in.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_models_import.py.out => pretrained_models_import.out.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_none_import.py => pretrained_none_import.in.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_none_import.py.out => pretrained_none_import.out.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_submodule_import.py => pretrained_submodule_import.in.py} (100%) rename tests/fixtures/vision/codemod/{pretrained_submodule_import.py.out => pretrained_submodule_import.out.py} (100%) rename tests/fixtures/vision/codemod/{singleton_import.py => singleton_import.in.py} (100%) rename tests/fixtures/vision/codemod/{singleton_import.py.out => singleton_import.out.py} (100%) diff --git a/tests/fixtures/deprecated_symbols/codemod/aliased_import.py b/tests/fixtures/deprecated_symbols/codemod/aliased_import.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/aliased_import.py rename to tests/fixtures/deprecated_symbols/codemod/aliased_import.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/aliased_import.py.out b/tests/fixtures/deprecated_symbols/codemod/aliased_import.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/aliased_import.py.out rename to tests/fixtures/deprecated_symbols/codemod/aliased_import.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/chain_matmul.py b/tests/fixtures/deprecated_symbols/codemod/chain_matmul.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/chain_matmul.py rename to tests/fixtures/deprecated_symbols/codemod/chain_matmul.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/chain_matmul.py.out b/tests/fixtures/deprecated_symbols/codemod/chain_matmul.out.py 
similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/chain_matmul.py.out rename to tests/fixtures/deprecated_symbols/codemod/chain_matmul.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/cholesky.py b/tests/fixtures/deprecated_symbols/codemod/cholesky.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/cholesky.py rename to tests/fixtures/deprecated_symbols/codemod/cholesky.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/cholesky.py.out b/tests/fixtures/deprecated_symbols/codemod/cholesky.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/cholesky.py.out rename to tests/fixtures/deprecated_symbols/codemod/cholesky.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/functorch.py b/tests/fixtures/deprecated_symbols/codemod/functorch.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/functorch.py rename to tests/fixtures/deprecated_symbols/codemod/functorch.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/functorch.py.out b/tests/fixtures/deprecated_symbols/codemod/functorch.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/functorch.py.out rename to tests/fixtures/deprecated_symbols/codemod/functorch.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py b/tests/fixtures/deprecated_symbols/codemod/ger-outer.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/ger-outer.py rename to tests/fixtures/deprecated_symbols/codemod/ger-outer.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out b/tests/fixtures/deprecated_symbols/codemod/ger-outer.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out rename to tests/fixtures/deprecated_symbols/codemod/ger-outer.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/qr.py 
b/tests/fixtures/deprecated_symbols/codemod/qr.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/qr.py rename to tests/fixtures/deprecated_symbols/codemod/qr.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/qr.py.out b/tests/fixtures/deprecated_symbols/codemod/qr.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/qr.py.out rename to tests/fixtures/deprecated_symbols/codemod/qr.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/range-arange.py b/tests/fixtures/deprecated_symbols/codemod/range-arange.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/range-arange.py rename to tests/fixtures/deprecated_symbols/codemod/range-arange.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/range-arange.py.out b/tests/fixtures/deprecated_symbols/codemod/range-arange.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/range-arange.py.out rename to tests/fixtures/deprecated_symbols/codemod/range-arange.out.py diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.in.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py rename to tests/fixtures/deprecated_symbols/codemod/register_pytree_node.in.py diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.out.py similarity index 100% rename from tests/fixtures/deprecated_symbols/codemod/register_pytree_node.py.out rename to tests/fixtures/deprecated_symbols/codemod/register_pytree_node.out.py diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py b/tests/fixtures/misc/codemod/reentrant_checkpoint.in.py similarity index 100% rename from tests/fixtures/misc/codemod/reentrant_checkpoint.py rename to 
tests/fixtures/misc/codemod/reentrant_checkpoint.in.py diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out b/tests/fixtures/misc/codemod/reentrant_checkpoint.out.py similarity index 100% rename from tests/fixtures/misc/codemod/reentrant_checkpoint.py.out rename to tests/fixtures/misc/codemod/reentrant_checkpoint.out.py diff --git a/tests/fixtures/misc/codemod/require_grad.py b/tests/fixtures/misc/codemod/require_grad.in.py similarity index 100% rename from tests/fixtures/misc/codemod/require_grad.py rename to tests/fixtures/misc/codemod/require_grad.in.py diff --git a/tests/fixtures/misc/codemod/require_grad.py.out b/tests/fixtures/misc/codemod/require_grad.out.py similarity index 100% rename from tests/fixtures/misc/codemod/require_grad.py.out rename to tests/fixtures/misc/codemod/require_grad.out.py diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.py b/tests/fixtures/nonpublic/codemod/default_collate_convert.in.py similarity index 100% rename from tests/fixtures/nonpublic/codemod/default_collate_convert.py rename to tests/fixtures/nonpublic/codemod/default_collate_convert.in.py diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.py.out b/tests/fixtures/nonpublic/codemod/default_collate_convert.out.py similarity index 100% rename from tests/fixtures/nonpublic/codemod/default_collate_convert.py.out rename to tests/fixtures/nonpublic/codemod/default_collate_convert.out.py diff --git a/tests/fixtures/vision/codemod/pretrained.py b/tests/fixtures/vision/codemod/pretrained.in.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained.py rename to tests/fixtures/vision/codemod/pretrained.in.py diff --git a/tests/fixtures/vision/codemod/pretrained.py.out b/tests/fixtures/vision/codemod/pretrained.out.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained.py.out rename to tests/fixtures/vision/codemod/pretrained.out.py diff --git 
a/tests/fixtures/vision/codemod/pretrained_models_import.py b/tests/fixtures/vision/codemod/pretrained_models_import.in.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_models_import.py rename to tests/fixtures/vision/codemod/pretrained_models_import.in.py diff --git a/tests/fixtures/vision/codemod/pretrained_models_import.py.out b/tests/fixtures/vision/codemod/pretrained_models_import.out.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_models_import.py.out rename to tests/fixtures/vision/codemod/pretrained_models_import.out.py diff --git a/tests/fixtures/vision/codemod/pretrained_none_import.py b/tests/fixtures/vision/codemod/pretrained_none_import.in.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_none_import.py rename to tests/fixtures/vision/codemod/pretrained_none_import.in.py diff --git a/tests/fixtures/vision/codemod/pretrained_none_import.py.out b/tests/fixtures/vision/codemod/pretrained_none_import.out.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_none_import.py.out rename to tests/fixtures/vision/codemod/pretrained_none_import.out.py diff --git a/tests/fixtures/vision/codemod/pretrained_submodule_import.py b/tests/fixtures/vision/codemod/pretrained_submodule_import.in.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_submodule_import.py rename to tests/fixtures/vision/codemod/pretrained_submodule_import.in.py diff --git a/tests/fixtures/vision/codemod/pretrained_submodule_import.py.out b/tests/fixtures/vision/codemod/pretrained_submodule_import.out.py similarity index 100% rename from tests/fixtures/vision/codemod/pretrained_submodule_import.py.out rename to tests/fixtures/vision/codemod/pretrained_submodule_import.out.py diff --git a/tests/fixtures/vision/codemod/singleton_import.py b/tests/fixtures/vision/codemod/singleton_import.in.py similarity index 100% rename from 
tests/fixtures/vision/codemod/singleton_import.py rename to tests/fixtures/vision/codemod/singleton_import.in.py diff --git a/tests/fixtures/vision/codemod/singleton_import.py.out b/tests/fixtures/vision/codemod/singleton_import.out.py similarity index 100% rename from tests/fixtures/vision/codemod/singleton_import.py.out rename to tests/fixtures/vision/codemod/singleton_import.out.py diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index cc3631e..56fd05c 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -24,7 +24,7 @@ def pytest_generate_tests(metafunc): "checker_source_path", files, ids=[file_name.stem for file_name in files] ) if "codemod_source_path" in metafunc.fixturenames: - files = list(FIXTURES_PATH.glob("**/codemod/*.py")) + files = list(FIXTURES_PATH.glob("**/codemod/*.in.py")) metafunc.parametrize( "codemod_source_path", files, ids=[file_name.stem for file_name in files] ) @@ -64,15 +64,18 @@ def test_empty(): def test_checker_fixtures(checker_source_path: Path): expected_path = checker_source_path.with_suffix(".txt") expected_results = expected_path.read_text().splitlines() - - assert ( - _checker_results(checker_source_path.read_text().splitlines(keepends=True)) - == expected_results + results = _checker_results( + checker_source_path.read_text().splitlines(keepends=True) ) + # Overwrite the expected data with the results (useful when updating tests) + # expected_path.write_text("".join([f"{line}\n" for line in results])) + assert results == expected_results def test_codemod_fixtures(codemod_source_path: Path): - expected_path = codemod_source_path.with_suffix(".py.out") + expected_path = codemod_source_path.with_stem( + codemod_source_path.stem.replace(".in", ".out") + ) expected_results = expected_path.read_text() assert _codemod_results(codemod_source_path) == expected_results From aef3ea17db107e9e0c015ffd4a223bc19127cb11 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 14 Sep 2024 00:41:47 +0200 Subject: [PATCH 
55/66] DOC(README): rule code assignment policy (#72) The documentation should explain the rule code assignment policy, so that it's clear which code to pick for newly developed rules. This was my best guess from the existing rules. Is it correct? --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 19b3e73..b51e24b 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,16 @@ To enable them, use standard flake8 configuration options for the plugin mode or If you encounter a bug or some other problem with TorchFix, please file an issue on https://github.com/pytorch-labs/torchfix/issues. +## Rule Code Assignment Policy + +New rule codes are assigned incrementally across the following categories: + +* **TOR0XX, TOR1XX**: General-purpose `torch` functionality. +* **TOR2XX**: Domain-specific rules, such as TorchVision. +* **TOR4XX**: Noisy rules that are disabled by default. +* **TOR9XX**: Internal rules specific for `pytorch/pytorch` repo, other users should not use these. + +TOR0, TOR1 and TOR2 are enabled by default. ## Rules From 7bf6f67f5220aaf9a99b2ba64f9df9814df10738 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 17 Sep 2024 10:43:32 -0700 Subject: [PATCH 56/66] Add TorchLog1pVisitor (#77) Suggest using `torch.log1p(x)` instead of `torch.log(1 + x)`. Only a checker for now, probably not worth spending time on a codemod as it's not a very common issue. 
I used this to create a PR to botorch https://github.com/pytorch/botorch/pull/2539 --- tests/fixtures/misc/checker/log1p.py | 9 ++++++ tests/fixtures/misc/checker/log1p.txt | 4 +++ tests/test_torchfix.py | 5 ++- torchfix/torchfix.py | 8 +++-- torchfix/visitors/__init__.py | 13 +++++--- torchfix/visitors/misc/__init__.py | 44 +++++++++++++++++++++++++++ 6 files changed, 75 insertions(+), 8 deletions(-) create mode 100644 tests/fixtures/misc/checker/log1p.py create mode 100644 tests/fixtures/misc/checker/log1p.txt diff --git a/tests/fixtures/misc/checker/log1p.py b/tests/fixtures/misc/checker/log1p.py new file mode 100644 index 0000000..6afffdb --- /dev/null +++ b/tests/fixtures/misc/checker/log1p.py @@ -0,0 +1,9 @@ +import torch +a = torch.randn(5) +b = torch.log(1 + a) +c = torch.log(a + 1) +b = torch.log(1.0 + a) +c = torch.log(a + 1.0) + +# False negative: can not detect currently +x = (a + 1).log() diff --git a/tests/fixtures/misc/checker/log1p.txt b/tests/fixtures/misc/checker/log1p.txt new file mode 100644 index 0000000..3bcbeac --- /dev/null +++ b/tests/fixtures/misc/checker/log1p.txt @@ -0,0 +1,4 @@ +3:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +4:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +5:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +6:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. 
diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 56fd05c..29f6cf9 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -35,7 +35,10 @@ def pytest_generate_tests(metafunc): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), + ( + "TOR1,TOR102", + {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"}, + ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), ] metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases]) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 24e88b5..80acb1c 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -9,6 +9,7 @@ from .visitors import ( TorchDeprecatedSymbolsVisitor, + TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, @@ -28,15 +29,16 @@ ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, + TorchLog1pVisitor, + TorchNonPublicAliasVisitor, TorchRequireGradVisitor, + TorchReentrantCheckpointVisitor, TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, + TorchUnsafeLoadVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, - TorchUnsafeLoadVisitor, - TorchReentrantCheckpointVisitor, - TorchNonPublicAliasVisitor, ] diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index af2b62b..f63e405 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -1,6 +1,10 @@ from .deprecated_symbols import TorchDeprecatedSymbolsVisitor from .internal import TorchScopedLibraryVisitor -from .misc import TorchReentrantCheckpointVisitor, TorchRequireGradVisitor +from .misc import ( + TorchReentrantCheckpointVisitor, + TorchRequireGradVisitor, + TorchLog1pVisitor, +) from .nonpublic import TorchNonPublicAliasVisitor from .performance import TorchSynchronizedDataLoaderVisitor from 
.security import TorchUnsafeLoadVisitor @@ -12,13 +16,14 @@ __all__ = [ "TorchDeprecatedSymbolsVisitor", + "TorchLog1pVisitor", + "TorchNonPublicAliasVisitor", + "TorchReentrantCheckpointVisitor", "TorchRequireGradVisitor", "TorchScopedLibraryVisitor", "TorchSynchronizedDataLoaderVisitor", + "TorchUnsafeLoadVisitor", "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", - "TorchUnsafeLoadVisitor", - "TorchReentrantCheckpointVisitor", - "TorchNonPublicAliasVisitor", ] diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ea8c7be..edc6809 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -77,3 +77,47 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=replacement, ) + + +class TorchLog1pVisitor(TorchVisitor): + """ + Suggest using `torch.log1p(x)` instead of `torch.log(1 + x)`. + """ + + ERRORS = [ + TorchError( + "TOR106", + ( + "Use `torch.log1p(x)` instead of `torch.log(1 + x)`. " + "It is more accurate for small values of `x`." 
+ ), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + + if m.matches( + node, + m.Call( + args=[ + m.Arg( + value=m.BinaryOperation( + left=m.Integer(value="1") | m.Float(value="1.0"), + operator=m.Add(), + ) + | m.BinaryOperation( + operator=m.Add(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ), + ], + ), + ): + + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From e0988aa1b10ca7d319c752a2a095f00cd676637a Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 17 Sep 2024 14:52:56 -0700 Subject: [PATCH 57/66] Add TorchExpm1Visitor (#78) Follow-up for https://github.com/pytorch-labs/torchfix/pull/77 --- tests/fixtures/misc/checker/expm1.py | 12 ++++++++++ tests/fixtures/misc/checker/expm1.txt | 3 +++ tests/test_torchfix.py | 10 +++++++- torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 4 +++- torchfix/visitors/misc/__init__.py | 33 +++++++++++++++++++++++++++ 6 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/misc/checker/expm1.py create mode 100644 tests/fixtures/misc/checker/expm1.txt diff --git a/tests/fixtures/misc/checker/expm1.py b/tests/fixtures/misc/checker/expm1.py new file mode 100644 index 0000000..4a7646d --- /dev/null +++ b/tests/fixtures/misc/checker/expm1.py @@ -0,0 +1,12 @@ +import torch +a = torch.randn(5) +b = torch.exp(a) - 1 +c = torch.exp(a) - 1.0 + +ret = (torch.exp(a) - 1) * torch.exp(2 * b) + +# False negative: can not detect currently +x = a.exp() - 1 + +# False negative: should be rare and would complicate implementation +x = -1 + torch.exp(a) diff --git a/tests/fixtures/misc/checker/expm1.txt b/tests/fixtures/misc/checker/expm1.txt new file mode 100644 index 0000000..ed24905 --- /dev/null +++ b/tests/fixtures/misc/checker/expm1.txt @@ -0,0 +1,3 @@ +3:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. 
It is more accurate for small values of `x`. +4:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`. +6:7 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 29f6cf9..7b79fb7 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -37,7 +37,15 @@ def pytest_generate_tests(metafunc): ("TOR102,TOR101", {"TOR102", "TOR101"}), ( "TOR1,TOR102", - {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"}, + { + "TOR101", + "TOR102", + "TOR103", + "TOR104", + "TOR105", + "TOR106", + "TOR107", + }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), ] diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 80acb1c..0798a5e 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -9,6 +9,7 @@ from .visitors import ( TorchDeprecatedSymbolsVisitor, + TorchExpm1Visitor, TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, @@ -29,6 +30,7 @@ ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, + TorchExpm1Visitor, TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index f63e405..8e56b4a 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -1,9 +1,10 @@ from .deprecated_symbols import TorchDeprecatedSymbolsVisitor from .internal import TorchScopedLibraryVisitor from .misc import ( + TorchExpm1Visitor, + TorchLog1pVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, - TorchLog1pVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor from .performance import TorchSynchronizedDataLoaderVisitor @@ -16,6 +17,7 @@ __all__ = [ "TorchDeprecatedSymbolsVisitor", + "TorchExpm1Visitor", "TorchLog1pVisitor", "TorchNonPublicAliasVisitor", "TorchReentrantCheckpointVisitor", diff --git 
a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index edc6809..348612c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -121,3 +121,36 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchExpm1Visitor(TorchVisitor): + """ + Suggest using `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. + """ + + ERRORS = [ + TorchError( + "TOR107", + ( + "Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. " + "It is more accurate for small values of `x`." + ), + ) + ] + + def visit_BinaryOperation(self, node): + if m.matches( + node, + m.BinaryOperation( + left=m.Call(), + operator=m.Subtract(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ): + if self.get_qualified_name_for_call(node.left) == "torch.exp": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 6d37cc5287fdd3b2a6911d56af5958975bd1c538 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 17 Sep 2024 17:35:13 -0700 Subject: [PATCH 58/66] Bump version to 0.6.0 Preparing 0.6.0 release, which is probably a bit overdue already. Also it's nice to align the release with PyTorch conference. 
- Added `torch.utils._pytree._register_pytree_node` and `torch.backends.cuda.sdp_kernel` to the deprecated APIs rules - Enhanced rule TOR203 to support `torchvision.datasets` and `transforms` in addition to `models` - Added rules TOR106 and TOR107 to suggest replacing `torch.log(1 + x)` and `torch.exp(x) - 1` with more numerically stable equivalents - Multiple code refactorings, bug fixes, and quality of life and documentation improvements --- torchfix/torchfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 0798a5e..a14cd9f 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -22,7 +22,7 @@ TorchVisionSingletonImportVisitor, ) -__version__ = "0.5.0" +__version__ = "0.6.0" DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml" From 87289c14d4548e9692e038704ca321f8046addee Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Thu, 19 Sep 2024 23:34:54 +0200 Subject: [PATCH 59/66] Augment deprecated symbols (#80) Extended the existing symbols --- torchfix/deprecated_symbols.yaml | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index a24f8f3..e219b2c 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -25,11 +25,11 @@ replacement: torch.outer - name: torch.lu_solve - deprecate_pr: TBA + deprecate_pr: https://github.com/pytorch/pytorch/pull/73806 remove_pr: - name: torch.norm - deprecate_pr: TBA + deprecate_pr: https://github.com/pytorch/pytorch/pull/57986 remove_pr: - name: torch.range @@ -45,9 +45,17 @@ remove_pr: - name: torch.lu - deprecate_pr: TBA + deprecate_pr: https://github.com/pytorch/pytorch/pull/73804 remove_pr: +- name: torch.matrix_rank + deprecate_pr: https://github.com/pytorch/pytorch/pull/57734 + remove_pr: https://github.com/pytorch/pytorch/pull/70981 + +- name: torch.lstsq + deprecate_pr: https://github.com/pytorch/pytorch/pull/57743 + 
remove_pr: https://github.com/pytorch/pytorch/pull/70980 + - name: torch.nn.UpsamplingNearest2d deprecate_pr: TBA remove_pr: From 6bfffd9e09900f0c50aca88b367b5725aa10a874 Mon Sep 17 00:00:00 2001 From: Zack Leman Date: Mon, 11 Nov 2024 17:49:59 -0800 Subject: [PATCH 60/66] Fix: Handle no valid python files in the directory. (#83) If you run torchfix script on a directory without Python files, it will terminate with an error. This modifies the script to just do nothing in such cases, without terminating with an error. ### Testing: Added test "test_no_python_files" ### Without fix the new test fails: assert 1 == 0 ...... raise Exception("Must have at least one job to process!")\nException: Must have at least one job to process!\n').returncode ...... --- tests/test_torchfix.py | 29 +++++++++++++++++++++++------ torchfix/__main__.py | 15 +++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 7b79fb7..a31bc1b 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -1,16 +1,18 @@ +import logging +import subprocess from pathlib import Path + +import libcst.codemod as codemod from torchfix.torchfix import ( - TorchChecker, - TorchCodemod, - TorchCodemodConfig, DISABLED_BY_DEFAULT, expand_error_codes, - GET_ALL_VISITORS, GET_ALL_ERROR_CODES, + GET_ALL_VISITORS, process_error_code_str, + TorchChecker, + TorchCodemod, + TorchCodemodConfig, ) -import logging -import libcst.codemod as codemod FIXTURES_PATH = Path(__file__).absolute().parent / "fixtures" LOGGER = logging.getLogger(__name__) @@ -103,3 +105,18 @@ def test_errorcodes_distinct(): def test_parse_error_code_str(case, expected): assert process_error_code_str(case) == expected + + +def test_no_python_files(tmp_path): + # Create a temporary directory with no Python files + non_python_file = tmp_path / "not_a_python_file.txt" + non_python_file.write_text("This is not a Python file") + + # Run torchfix on the temporary directory + 
result = subprocess.run( + ["python3", "-m", "torchfix", str(tmp_path)], + capture_output=True, + text=True, + ) + # Check that the script exits successfully + assert result.returncode == 0 diff --git a/torchfix/__main__.py b/torchfix/__main__.py index eb17658..4964e01 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -1,20 +1,22 @@ import argparse -import libcst.codemod as codemod import contextlib import ctypes -import sys import io +import sys + +import libcst.codemod as codemod + +from .common import CYAN, ENDC from .torchfix import ( - TorchCodemod, - TorchCodemodConfig, __version__ as TorchFixVersion, DISABLED_BY_DEFAULT, GET_ALL_ERROR_CODES, process_error_code_str, + TorchCodemod, + TorchCodemodConfig, ) -from .common import CYAN, ENDC # Should get rid of this code eventually. @@ -83,7 +85,6 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - files = codemod.gather_files(args.path) # Filter out files that don't have "torch" string in them. 
@@ -97,6 +98,8 @@ def main() -> None: torch_files.append(file) break + if not torch_files: + return config = TorchCodemodConfig() config.select = list(process_error_code_str(args.select)) command_instance = TorchCodemod(codemod.CodemodContext(), config) From 9bd4fb532757f14cceaa31d5312a6748904019b0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:06:27 -0800 Subject: [PATCH 61/66] Update libcst dependency to 1.5 (#85) To enable python-3.13 support, see https://github.com/pytorch-labs/torchfix/issues/84 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7114063..48ac050 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,14 +16,14 @@ dynamic = ["version"] dependencies = [ "flake8>=3.8.2", "PyYAML", - "libcst>=1.1.0,<1.2.0" + "libcst>=1.5.0,<1.6.0" ] [project.optional-dependencies] dev = [ "flake8==6.0.0", "pytest==7.2.0", - "libcst==1.1.0", + "libcst==1.5.0", "types-PyYAML==6.0.7", "mypy==1.7.0", "black==24.4.0", From 86186f4b678bc3534fe0fe4a43997bc90cfc45a5 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 14 Nov 2024 17:16:45 -0800 Subject: [PATCH 62/66] Bump version to 0.7.0 (#86) Preparing 0.7.0 release - Updated libCST dependency to 1.5.0 to support running on Python 3.13 - Added `torch.matrix_rank` and `torch.lstsq` to the list of deprecated/removed APIs --- torchfix/torchfix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a14cd9f..1cb3e69 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -22,7 +22,7 @@ TorchVisionSingletonImportVisitor, ) -__version__ = "0.6.0" +__version__ = "0.7.0" DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml" From 4ff3caf4e690bb44a5ee9ca8367b8bf65b1bbe73 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 12 Dec 2024 21:59:30 -0800 Subject: [PATCH 63/66] Add rules for deprecated AMP APIs (#87) 
Add codemods for `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, and checkers for `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`. --- .../deprecated_symbols/checker/amp.py | 10 +++++++ .../deprecated_symbols/checker/amp.txt | 6 +++++ .../deprecated_symbols/codemod/amp.in.py | 11 ++++++++ .../deprecated_symbols/codemod/amp.out.py | 11 ++++++++ torchfix/deprecated_symbols.yaml | 16 ++++++++++++ .../visitors/deprecated_symbols/__init__.py | 17 +++++++----- torchfix/visitors/deprecated_symbols/amp.py | 26 +++++++++++++++++++ 7 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 tests/fixtures/deprecated_symbols/checker/amp.py create mode 100644 tests/fixtures/deprecated_symbols/checker/amp.txt create mode 100644 tests/fixtures/deprecated_symbols/codemod/amp.in.py create mode 100644 tests/fixtures/deprecated_symbols/codemod/amp.out.py create mode 100644 torchfix/visitors/deprecated_symbols/amp.py diff --git a/tests/fixtures/deprecated_symbols/checker/amp.py b/tests/fixtures/deprecated_symbols/checker/amp.py new file mode 100644 index 0000000..278ac39 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/amp.py @@ -0,0 +1,10 @@ +import torch + +torch.cuda.amp.autocast() +torch.cuda.amp.custom_fwd() +torch.cuda.amp.custom_bwd() + +dtype = torch.float32 +maybe_autocast = torch.cpu.amp.autocast() +maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) diff --git a/tests/fixtures/deprecated_symbols/checker/amp.txt b/tests/fixtures/deprecated_symbols/checker/amp.txt new file mode 100644 index 0000000..71939e9 --- /dev/null +++ b/tests/fixtures/deprecated_symbols/checker/amp.txt @@ -0,0 +1,6 @@ +3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast +4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd +5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd +8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast +9:18 TOR101 Use of 
deprecated function torch.cpu.amp.autocast +10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.in.py b/tests/fixtures/deprecated_symbols/codemod/amp.in.py new file mode 100644 index 0000000..6a1227c --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/amp.in.py @@ -0,0 +1,11 @@ +import torch + +dtype = torch.float32 + +maybe_autocast = torch.cuda.amp.autocast() +maybe_autocast = torch.cuda.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cuda.amp.autocast(dtype=dtype) + +maybe_autocast = torch.cpu.amp.autocast() +maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16) +maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.out.py b/tests/fixtures/deprecated_symbols/codemod/amp.out.py new file mode 100644 index 0000000..da39d0a --- /dev/null +++ b/tests/fixtures/deprecated_symbols/codemod/amp.out.py @@ -0,0 +1,11 @@ +import torch + +dtype = torch.float32 + +maybe_autocast = torch.amp.autocast("cuda") +maybe_autocast = torch.amp.autocast("cuda", dtype=torch.bfloat16) +maybe_autocast = torch.amp.autocast("cuda", dtype=dtype) + +maybe_autocast = torch.amp.autocast("cpu") +maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16) +maybe_autocast = torch.amp.autocast("cpu", dtype=dtype) diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml index e219b2c..eaa5119 100644 --- a/torchfix/deprecated_symbols.yaml +++ b/torchfix/deprecated_symbols.yaml @@ -83,6 +83,22 @@ remove_pr: reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel +- name: torch.cuda.amp.autocast + deprecate_pr: TBA + remove_pr: + +- name: torch.cuda.amp.custom_fwd + deprecate_pr: TBA + remove_pr: + +- name: torch.cuda.amp.custom_bwd + deprecate_pr: TBA + remove_pr: + +- name: torch.cpu.amp.autocast + deprecate_pr: TBA + remove_pr: + # functorch - name: functorch.vmap deprecate_pr: TBA diff --git 
a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 6a05472..40885ee 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -1,20 +1,23 @@ -import libcst as cst import pkgutil +from typing import List, Optional + +import libcst as cst import yaml -from typing import Optional, List from ...common import ( - TorchVisitor, - TorchError, call_with_name_changes, check_old_names_in_import_from, + TorchError, + TorchVisitor, ) -from .range import call_replacement_range -from .cholesky import call_replacement_cholesky +from .amp import call_replacement_cpu_amp_autocast, call_replacement_cuda_amp_autocast from .chain_matmul import call_replacement_chain_matmul +from .cholesky import call_replacement_cholesky from .qr import call_replacement_qr +from .range import call_replacement_range + class TorchDeprecatedSymbolsVisitor(TorchVisitor): ERRORS: List[TorchError] = [ @@ -49,6 +52,8 @@ def _call_replacement( "torch.range": call_replacement_range, "torch.chain_matmul": call_replacement_chain_matmul, "torch.qr": call_replacement_qr, + "torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast, + "torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast, } replacement = None diff --git a/torchfix/visitors/deprecated_symbols/amp.py b/torchfix/visitors/deprecated_symbols/amp.py new file mode 100644 index 0000000..9aa87c7 --- /dev/null +++ b/torchfix/visitors/deprecated_symbols/amp.py @@ -0,0 +1,26 @@ +import libcst as cst + +from ...common import get_module_name + + +def call_replacement_cpu_amp_autocast(node: cst.Call) -> cst.CSTNode: + return _call_replacement_amp(node, "cpu") + + +def call_replacement_cuda_amp_autocast(node: cst.Call) -> cst.CSTNode: + return _call_replacement_amp(node, "cuda") + + +def _call_replacement_amp(node: cst.Call, device: str) -> cst.CSTNode: + """ + Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and + 
Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`. + """ + device_arg = cst.ensure_type(cst.parse_expression(f'f("{device}")'), cst.Call).args[ + 0 + ] + + module_name = get_module_name(node, "torch") + replacement = cst.parse_expression(f"{module_name}.amp.autocast(args)") + replacement = replacement.with_changes(args=(device_arg, *node.args)) + return replacement From 28f1a5f71524aafc354db5b02b4e920506165500 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 6 Jan 2025 16:00:54 -0800 Subject: [PATCH 64/66] Add TorchLogsumexpVisitor (#89) Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x)))`. https://pytorch.org/docs/stable/generated/torch.logsumexp.html --- tests/fixtures/misc/checker/logsumexp.py | 14 ++++++++ tests/fixtures/misc/checker/logsumexp.txt | 2 ++ tests/test_torchfix.py | 1 + torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 2 ++ torchfix/visitors/misc/__init__.py | 40 +++++++++++++++++++++-- 6 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/misc/checker/logsumexp.py create mode 100644 tests/fixtures/misc/checker/logsumexp.txt diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py new file mode 100644 index 0000000..6473f99 --- /dev/null +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -0,0 +1,14 @@ +import torch +a = torch.randn(5) +b = torch.randn(5) + +# logsumexp +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) + +# not logsumexp +y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) +y = torch.log(torch.sum(torch.exp(x) + 2.5, 1)) +y = torch.log(2 + x) +y = torch.sum(torch.log(torch.exp(x)), 1) +y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt new file mode 100644 index 0000000..4a4f5ec --- /dev/null +++
b/tests/fixtures/misc/checker/logsumexp.txt @@ -0,0 +1,2 @@ +6:5 TOR108 Use numerically stabilized `torch.logsumexp`. +7:5 TOR108 Use numerically stabilized `torch.logsumexp`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index a31bc1b..5baa12a 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -47,6 +47,7 @@ def pytest_generate_tests(metafunc): "TOR105", "TOR106", "TOR107", + "TOR108", }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 1cb3e69..dae1a24 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,6 +11,7 @@ TorchDeprecatedSymbolsVisitor, TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, @@ -32,6 +33,7 @@ TorchDeprecatedSymbolsVisitor, TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, TorchReentrantCheckpointVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 8e56b4a..5317d1b 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -3,6 +3,7 @@ from .misc import ( TorchExpm1Visitor, TorchLog1pVisitor, + TorchLogsumexpVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, ) @@ -19,6 +20,7 @@ "TorchDeprecatedSymbolsVisitor", "TorchExpm1Visitor", "TorchLog1pVisitor", + "TorchLogsumexpVisitor", "TorchNonPublicAliasVisitor", "TorchReentrantCheckpointVisitor", "TorchRequireGradVisitor", diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index 348612c..e77de4f 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -96,7 +96,6 @@ class TorchLog1pVisitor(TorchVisitor): def visit_Call(self, node): if self.get_qualified_name_for_call(node) == "torch.log": - if m.matches( node, m.Call( @@ -114,7 +113,6 @@ def visit_Call(self, node): ], ), ): - 
self.add_violation( node, error_code=self.ERRORS[0].error_code, @@ -154,3 +152,41 @@ def visit_BinaryOperation(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchLogsumexpVisitor(TorchVisitor): + """ + Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`. + """ + + ERRORS = [ + TorchError( + "TOR108", + ("Use numerically stabilized `torch.logsumexp`."), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + if m.matches( + node, + m.Call( + args=[ + m.Arg(m.Call(args=[m.Arg(m.Call()), m.ZeroOrMore()])), + m.ZeroOrMore(), + ] + ), + ): + if self.get_qualified_name_for_call(node.args[0].value) == "torch.sum": + if ( + self.get_qualified_name_for_call( + node.args[0].value.args[0].value + ) + == "torch.exp" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 00954c918e3e6b46b1f55b7d1eb16f22bf4e6a3d Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 3 Feb 2025 18:59:17 -0800 Subject: [PATCH 65/66] Don't suggest logsumexp if sum's dim is None (#91) See https://github.com/pytorch/pytorch/issues/144339 --- tests/fixtures/misc/checker/logsumexp.py | 11 +++++++++-- tests/fixtures/misc/checker/logsumexp.txt | 2 ++ torchfix/visitors/misc/__init__.py | 21 ++++++++++++++++----- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py index 6473f99..d4399f0 100644 --- a/tests/fixtures/misc/checker/logsumexp.py +++ b/tests/fixtures/misc/checker/logsumexp.py @@ -1,10 +1,12 @@ import torch -a = torch.randn(5) -b = torch.randn(5) + +x = torch.randn(5) # logsumexp y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), dim=1, keepdim=True)) y = torch.log(torch.sum(torch.exp(2.5 + x), 1)) +y = torch.log(torch.sum(torch.exp(2.5 + x), dim=1)) # not logsumexp y 
= torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5) @@ -12,3 +14,8 @@ y = torch.log(2 + x) y = torch.sum(torch.log(torch.exp(x)), 1) y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True)) + +# not logsumexp because of https://github.com/pytorch/pytorch/issues/144339 +y = torch.log(torch.sum(torch.exp(x), None, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), dim=None, keepdim=True)) +y = torch.log(torch.sum(torch.exp(x), keepdim=True)) diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt index 4a4f5ec..5298d5f 100644 --- a/tests/fixtures/misc/checker/logsumexp.txt +++ b/tests/fixtures/misc/checker/logsumexp.txt @@ -1,2 +1,4 @@ 6:5 TOR108 Use numerically stabilized `torch.logsumexp`. 7:5 TOR108 Use numerically stabilized `torch.logsumexp`. +8:5 TOR108 Use numerically stabilized `torch.logsumexp`. +9:5 TOR108 Use numerically stabilized `torch.logsumexp`. diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index e77de4f..8f0c70c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -184,9 +184,20 @@ def visit_Call(self, node): ) == "torch.exp" ): - self.add_violation( - node, - error_code=self.ERRORS[0].error_code, - message=self.ERRORS[0].message(), - replacement=None, + + # if `dim` is not provided or None for sum, skip: + # https://github.com/pytorch/pytorch/issues/144339 + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 ) + if dim_arg is not None: + if not ( + isinstance(dim_arg.value, cst.Name) + and dim_arg.value.value == "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) From 0ee382c8bc4d9388fe08d42c8963fba53b73273b Mon Sep 17 00:00:00 2001 From: Shivam Agarwal <35878114+shivam096@users.noreply.github.com> Date: Fri, 7 Feb 2025 12:59:35 -0800 Subject: [PATCH 66/66] Reimplementation of GradNotSetToNonePattern 
from Torchtidy (#92) Adding rules to check for `set_to_none` parameter for `zero_grad()`. By setting set_to_none=True, we can gain speedup --- .../fixtures/performance/checker/zerograd.py | 16 +++++++++ .../fixtures/performance/checker/zerograd.txt | 2 ++ torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 6 +++- torchfix/visitors/performance/__init__.py | 33 +++++++++++++++++++ 5 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/performance/checker/zerograd.py create mode 100644 tests/fixtures/performance/checker/zerograd.txt diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py new file mode 100644 index 0000000..8f0d6fc --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + +x = torch.ones((100, 100)) +model = nn.Sequential() +optimizer = torch.optim.Adam(model.parameters()) + +# This should raise flags +optimizer.zero_grad(set_to_none=False) +model.zero_grad(set_to_none=False) + +# This should not raise flags +optimizer.zero_grad() +model.zero_grad() + + diff --git a/tests/fixtures/performance/checker/zerograd.txt b/tests/fixtures/performance/checker/zerograd.txt new file mode 100644 index 0000000..ed29bf4 --- /dev/null +++ b/tests/fixtures/performance/checker/zerograd.txt @@ -0,0 +1,2 @@ +9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). +10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad(). 
\ No newline at end of file diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index dae1a24..5e96e38 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -21,6 +21,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ) __version__ = "0.7.0" @@ -43,6 +44,7 @@ TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, + TorchGradNotSetToNonePatternVisitor, ] diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index 5317d1b..45f2438 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -8,7 +8,10 @@ TorchRequireGradVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor -from .performance import TorchSynchronizedDataLoaderVisitor +from .performance import ( + TorchSynchronizedDataLoaderVisitor, + TorchGradNotSetToNonePatternVisitor, +) from .security import TorchUnsafeLoadVisitor from .vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -30,4 +33,5 @@ "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", + "TorchGradNotSetToNonePatternVisitor", ] diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py index 249df4c..0558af5 100644 --- a/torchfix/visitors/performance/__init__.py +++ b/torchfix/visitors/performance/__init__.py @@ -32,3 +32,36 @@ def visit_Call(self, node): error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message(), ) + + +class TorchGradNotSetToNonePatternVisitor(TorchVisitor): + """ + Reimplementation of GradNotSetToNonePattern from + https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py + """ + + ERRORS = [ + TorchError( + "TOR402", + ( + "Detected gradient set to zero instead of None. " + "Please add 'set_to_none=True' when calling zero_grad()." 
+ ), + ) + ] + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + + if qualified_name and qualified_name.endswith("zero_grad"): + + set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0) + + # hasattr check to handle mypy error + if set_to_none_arg and hasattr(set_to_none_arg.value, "value"): + if set_to_none_arg.value.value == "False": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + )