diff --git a/.flake8 b/.flake8
index 13dea6a..4e741ee 100644
--- a/.flake8
+++ b/.flake8
@@ -3,3 +3,4 @@
 exclude = ./tests/fixtures/
 # Match black tool's default.
 max-line-length = 88
+extend-ignore = E203
diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml
index 2047373..9a61e01 100644
--- a/.github/workflows/test-torchfix.yml
+++ b/.github/workflows/test-torchfix.yml
@@ -6,16 +6,23 @@ on:
 jobs:
   test-torchfix:
-    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
     steps:
     - name: Checkout
-      uses: actions/checkout@v3
-    - name: Install requirements
+      uses: actions/checkout@v4
+    - uses: actions/setup-python@v5
+      with:
+        python-version: '3.10'
+    - name: Upgrade build dependencies
       run: |
-        pip3 install -r requirements-dev.txt
+        pip3 install -U pip
+        pip3 install -U setuptools
     - name: Install TorchFix
       run: |
-        pip3 install .
+        pip3 install ".[dev]"
     - name: Run pytest
       run: |
         pytest tests
@@ -25,3 +32,6 @@ jobs:
     - name: Run mypy
       run: |
         mypy .
+    - name: Run black
+      run: |
+        black --check .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..2f41268
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,23 @@
+repos:
+  - repo: local
+    hooks:
+      - id: black
+        name: black
+        entry: black
+        language: system
+        types: [python]
+        args: ["--config=./pyproject.toml"]
+        exclude: ^tests/fixtures/
+      - id: flake8
+        name: flake8
+        entry: flake8
+        language: system
+        types: [python]
+        args: ["--config=./.flake8"]
+        exclude: ^tests/fixtures/
+      - id: mypy
+        name: mypy
+        entry: mypy
+        language: system
+        types: [python]
+        exclude: ^tests/fixtures/
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 6e27b16..8995e56 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -14,10 +14,34 @@ We actively welcome your pull requests.
 1. Fork the repo and create your branch from `main`.
 2. If you've added code that should be tested, add tests.
 3. If you've changed APIs, update the documentation.
-4. Ensure the test suite passes.
-5. Make sure your code lints.
+4. Ensure the test suite passes (`pytest tests`).
+5. Make sure your code lints (see the Linting section below).
 6. If you haven't already, complete the Contributor License Agreement ("CLA").
 
+## Linting
+
+We use `black`, `flake8`, and `mypy` to lint the code; configuration to run them via `pre-commit` is also provided. Install the development dependencies first:
+
+```shell
+pip install ".[dev]"
+```
+
+Linting via the pre-commit hooks:
+
+```shell
+# manually run pre-commit hooks on all files (runs all linters)
+pre-commit run --all-files
+```
+
+Manually running individual linters:
+
+```shell
+black .
+flake8
+mypy .
+```
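+
+To run these checks automatically before each commit, you can also install the
+git hook (a standard `pre-commit` feature, suggested here as an optional workflow):
+
+```shell
+pre-commit install
+```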
+
 
 ## Contributor License Agreement ("CLA")
 
 In order to accept your pull request, we need you to submit a CLA. You only
diff --git a/README.md b/README.md
index 30fac53..b51e24b 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 # TorchFix - a linter for PyTorch-using code with autofix support
 
+[![PyPI](https://img.shields.io/pypi/v/torchfix.svg)](https://pypi.org/project/torchfix/)
+
 TorchFix is a Python code static analysis tool - a linter with autofix capabilities -
 for users of PyTorch. It can be used to find and fix issues like usage of deprecated
 PyTorch functions and non-public symbols, and to adopt PyTorch best practices in general.
@@ -11,7 +13,8 @@ reporting issues.
 TorchFix can be used as a Flake8 plugin (linting only) or as a standalone
 program (with autofix available for a subset of the lint violations).
 
-Currently TorchFix is in a **beta version** stage, so there are still a lot of rough
-edges and many things can and will change.
+> [!WARNING]
+> Currently TorchFix is in a **beta version** stage, so there are still a lot of rough
+> edges and many things can and will change.
 
 ## Installation
@@ -34,7 +37,8 @@ TorchFix can also be run as a standalone program: `torchfix .`
 Add `--fix` parameter to try to autofix some of the issues (the files will be overwritten!)
 To see some additional debug info, add `--show-stderr` parameter.
 
-Please keep in mind that autofix is a best-effort mechanism. Given the dynamic nature of Python,
-and especially the beta version status of TorchFix, it's very difficult to have certainty
-when making changes to code, even for the seemingly trivial fixes.
+> [!CAUTION]
+> Please keep in mind that autofix is a best-effort mechanism. Given the dynamic nature of Python,
+> and especially the beta version status of TorchFix, it's very difficult to have certainty
+> when making changes to code, even for the seemingly trivial fixes.
 
@@ -49,6 +53,16 @@ To enable them, use standard flake8 configuration options for the plugin mode or
 If you encounter a bug or some other problem with TorchFix, please file an issue on
 https://github.com/pytorch-labs/torchfix/issues.
 
+## Rule Code Assignment Policy
+
+New rule codes are assigned incrementally across the following categories:
+
+* **TOR0XX, TOR1XX**: General-purpose `torch` functionality.
+* **TOR2XX**: Domain-specific rules, such as TorchVision.
+* **TOR4XX**: Noisy rules that are disabled by default.
+* **TOR9XX**: Internal rules specific to the `pytorch/pytorch` repo; other users should not use these.
+
+TOR0, TOR1, and TOR2 are enabled by default.
 
 ## Rules
 
@@ -65,5 +79,121 @@ To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve`.
 
 `X = torch.solve(B, A).solution` should be replaced with `X = torch.linalg.solve(A, B)`.
 
+#### torch.symeig
+
+This function was deprecated in PyTorch 1.9 and has now been removed. Use `torch.linalg.eigh` instead.
+
+Note that the default behavior has changed: `torch.symeig` used the upper triangular portion of the matrix by default, while `torch.linalg.eigh` uses the lower triangular portion.
+
+```python
+L, _ = torch.symeig(A, upper=upper)
+```
+
+should be replaced with
+
+```python
+L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
+```
+
+and
+
+```python
+L, V = torch.symeig(A, eigenvectors=True)
+```
+
+should be replaced with
+
+```python
+L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')
+```
+
+### TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
+
+This is a common misspelling: assigning to `require_grad` silently creates a new attribute instead of toggling gradient tracking, which can lead to silent performance issues.
+
+### TOR003 Please pass `use_reentrant` explicitly to `checkpoint`
+
+The default value of the `use_reentrant` parameter in `torch.utils.checkpoint` is being changed
+from `True` to `False`. Until then, the value needs to be passed explicitly.
+
+See this [forum post](https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745)
+for details.
+
+### TOR004 Import of removed function
+
+See `TOR001`.
+
+### TOR101 Use of deprecated function
+
+#### torch.nn.utils.weight_norm
+
+This function is deprecated. Use `torch.nn.utils.parametrizations.weight_norm`
+instead, which uses the modern parametrization API. The new `weight_norm` is compatible
+with `state_dict`s generated by the old `weight_norm`.
+
+Migration guide:
+
+* The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
+  as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
+  respectively.
+
+* To remove the weight normalization reparametrization, use
+  `torch.nn.utils.parametrize.remove_parametrizations`.
+
+* The weight is no longer recomputed once per module forward; instead, it will
+  be recomputed on every access. To restore the old behavior, use
+  `torch.nn.utils.parametrize.cached` before invoking the module
+  in question.
+
+#### torch.backends.cuda.sdp_kernel
+
+This function is deprecated. Use the `torch.nn.attention.sdpa_kernel` context manager instead.
+
+Migration guide:
+Each boolean input parameter of `sdp_kernel` (defaulting to `True` unless specified) corresponds to an `SDPBackend`. If an input parameter is `True`, the corresponding backend should be added to the list passed to `sdpa_kernel`.
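+
+As an illustration, a migration might look like this (a sketch, not TorchFix autofix output; `q`, `k`, `v` stand for the usual query, key, and value tensors):
+
+```python
+# Before: deprecated context manager with boolean backend flags.
+with torch.backends.cuda.sdp_kernel(
+    enable_flash=True, enable_math=False, enable_mem_efficient=False
+):
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+# After: the enabled backends are passed explicitly as a list.
+from torch.nn.attention import sdpa_kernel, SDPBackend
+
+with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+```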
+
+#### torch.chain_matmul
+
+This function is deprecated in favor of `torch.linalg.multi_dot`.
+
+Migration guide:
+`multi_dot` accepts a list of two or more tensors, whereas `chain_matmul` accepted the tensors as separate positional arguments. To migrate, wrap the arguments of `chain_matmul` into a list for `multi_dot`.
+
+Example: Replace `torch.chain_matmul(a, b, c)` with `torch.linalg.multi_dot([a, b, c])`.
+
+#### torch.cholesky
+
+`torch.cholesky()` is deprecated in favor of `torch.linalg.cholesky()`.
+
+Migration guide:
+* `L = torch.cholesky(A)` should be replaced with `L = torch.linalg.cholesky(A)`.
+* `L = torch.cholesky(A, upper=True)` should be replaced with `L = torch.linalg.cholesky(A).mH`.
+
+#### torch.qr
+
+`torch.qr()` is deprecated in favor of `torch.linalg.qr()`.
+
+Migration guide:
+* The usage `Q, R = torch.qr(A)` should be replaced with `Q, R = torch.linalg.qr(A)`.
+* The boolean parameter `some` of `torch.qr` is replaced with a string parameter `mode` in `torch.linalg.qr`. The corresponding change in usage is from `Q, R = torch.qr(A, some=False)` to `Q, R = torch.linalg.qr(A, mode="complete")`.
+
+#### torch.range
+
+The function `torch.range()` is deprecated because its usage is incompatible with Python's built-in `range`. Instead, use `torch.arange()`, which produces values in `[start, end)`.
+
+Migration guide:
+* `torch.range(start, end)` produces values in `[start, end]`, while `torch.arange(start, end)` produces values in `[start, end)`. For a step size of 1, migrate `torch.range(start, end, 1)` to `torch.arange(start, end + 1, 1)`.
+
+### TOR102 `torch.load` without `weights_only` parameter is unsafe.
+
+Explicitly set `weights_only=False` only if you trust the data you load and need full pickle functionality; otherwise set `weights_only=True`.
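+
+For example (a sketch; the checkpoint path is a placeholder):
+
+```python
+# Safe default: only tensors and other allowlisted types are unpickled.
+state_dict = torch.load("checkpoint.pt", weights_only=True)
+
+# Only for fully trusted files that require arbitrary pickled objects.
+obj = torch.load("checkpoint.pt", weights_only=False)
+```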
+
+### TOR103 Import of deprecated function
+
+See `TOR101`.
+
 ## License
+
 TorchFix is BSD License licensed, as found in the LICENSE file.
diff --git a/pyproject.toml b/pyproject.toml
index 244c9aa..48ac050 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,10 @@
+[build-system]
+requires = ["setuptools >= 65.0"]
+build-backend = "setuptools.build_meta"
+
 [project]
 name = "TorchFix"
+requires-python = ">=3.9"
 description = "TorchFix - a linter for PyTorch-using code with autofix support"
 readme = "README.md"
 license = {file = "LICENSE"}
@@ -8,7 +13,22 @@ classifiers = [
   "Programming Language :: Python"
 ]
 dynamic = ["version"]
-dependencies = ["flake8>=3.8.2", "PyYAML", "libcst>=1.1.0,<1.2.0"]
+dependencies = [
+    "flake8>=3.8.2",
+    "PyYAML",
+    "libcst>=1.5.0,<1.6.0"
+]
+
+[project.optional-dependencies]
+dev = [
+    "flake8==6.0.0",
+    "pytest==7.2.0",
+    "libcst==1.5.0",
+    "types-PyYAML==6.0.7",
+    "mypy==1.7.0",
+    "black==24.4.0",
+    "pre-commit==3.7.0",
+]
 
 [project.urls]
 Repository = "https://github.com/pytorch-labs/torchfix"
@@ -35,6 +55,7 @@ exclude = "tests/fixtures/*"
 
 [tool.mypy]
 exclude = ["tests/fixtures", "build"]
+check_untyped_defs = true
 
 [tool.setuptools.dynamic]
 version = {attr = "torchfix.torchfix.__version__"}
diff --git a/requirements-dev.txt b/requirements-dev.txt
deleted file mode 100644
index 225a8fe..0000000
--- a/requirements-dev.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-flake8==6.0.0
-pytest==7.2.0
-libcst==1.1.0
-types-PyYAML==6.0.7
-mypy==1.4.1
diff --git a/tests/fixtures/deprecated_symbols/checker/amp.py b/tests/fixtures/deprecated_symbols/checker/amp.py
new file mode 100644
index 0000000..278ac39
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/amp.py
@@ -0,0 +1,10 @@
+import torch
+
+torch.cuda.amp.autocast()
+torch.cuda.amp.custom_fwd()
+torch.cuda.amp.custom_bwd()
+
+dtype = torch.float32
+maybe_autocast = torch.cpu.amp.autocast()
+maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
diff --git a/tests/fixtures/deprecated_symbols/checker/amp.txt b/tests/fixtures/deprecated_symbols/checker/amp.txt
new file mode 100644
index 0000000..71939e9
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/amp.txt
@@ -0,0 +1,6 @@
+3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast
+4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd
+5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd
+8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
+9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
+10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py b/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py
index 17e9058..a9f2866 100644
--- a/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py
+++ b/tests/fixtures/deprecated_symbols/checker/deprecated_from_nn.py
@@ -6,3 +6,6 @@ import torch.nn as yy
 
 yy.UpsamplingNearest2d()
+
+func = torch.nn.UpsamplingNearest2d  # not detected currently
+func()
diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt
index 9768b1f..3da6ee5 100644
--- a/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt
+++ b/tests/fixtures/deprecated_symbols/checker/deprecated_qr.txt
@@ -1,4 +1,6 @@
 2:7 TOR101 Use of deprecated function torch.qr
 6:7 TOR101 Use of deprecated function torch.qr
+9:1 TOR103 Import of deprecated function torch.qr
 10:7 TOR101 Use of deprecated function torch.qr
+13:1 TOR103 Import of deprecated function torch.qr
 16:7 TOR101 Use of deprecated function torch.qr
diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py
new file mode 100644
index 0000000..b594d68
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.py
@@ -0,0 +1,9 @@
+from torch.utils._pytree import _register_pytree_node
+
+_register_pytree_node()
+
+from torch.utils import _pytree as xx
+xx._register_pytree_node()
+
+import torch
+torch.utils._pytree._register_pytree_node()
diff --git a/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt
new file mode 100644
index 0000000..eb93906
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/deprecated_register_pytree_node.txt
@@ -0,0 +1,4 @@
+1:1 TOR103 Import of deprecated function torch.utils._pytree._register_pytree_node
+3:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node
+6:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node
+9:1 TOR101 Use of deprecated function torch.utils._pytree._register_pytree_node
\ No newline at end of file
diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.py b/tests/fixtures/deprecated_symbols/checker/functorch.py
index f072cbb..044240f 100644
--- a/tests/fixtures/deprecated_symbols/checker/functorch.py
+++ b/tests/fixtures/deprecated_symbols/checker/functorch.py
@@ -2,3 +2,5 @@
 # Check that we get only one warning for the line
 functorch.vmap(tdmodule, (None, 0))(td, params)
+
+from functorch import vmap, jacrev
diff --git a/tests/fixtures/deprecated_symbols/checker/functorch.txt b/tests/fixtures/deprecated_symbols/checker/functorch.txt
index 336c7ae..e4f802c 100644
--- a/tests/fixtures/deprecated_symbols/checker/functorch.txt
+++ b/tests/fixtures/deprecated_symbols/checker/functorch.txt
@@ -1 +1,3 @@
 4:1 TOR101 Use of deprecated function functorch.vmap
+6:1 TOR103 Import of deprecated function functorch.vmap
+6:1 TOR103 Import of deprecated function functorch.jacrev
diff --git a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt
index 06b13e7..7610c28 100644
--- a/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt
+++ b/tests/fixtures/deprecated_symbols/checker/removed_symeig.txt
@@ -1,2 +1,3 @@
+2:1 TOR004 Import of removed function torch.symeig
 4:8 TOR001 Use of removed function torch.symeig
 5:8 TOR001 Use of removed function torch.symeig
diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py
new file mode 100644
index 0000000..06d14a8
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.py
@@ -0,0 +1,12 @@
+import torch
+from torch.backends import cuda
+from torch.backends.cuda import sdp_kernel
+
+with torch.backends.cuda.sdp_kernel() as context:
+    pass
+
+with cuda.sdp_kernel() as context:
+    pass
+
+with sdp_kernel() as context:
+    pass
diff --git a/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt
new file mode 100644
index 0000000..d18f1ee
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/sdp_kernel.txt
@@ -0,0 +1,4 @@
+3:1 TOR103 Import of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
+5:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
+8:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
+11:6 TOR101 Use of deprecated function torch.backends.cuda.sdp_kernel: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
diff --git a/tests/fixtures/deprecated_symbols/checker/weight_norm.py b/tests/fixtures/deprecated_symbols/checker/weight_norm.py
new file mode 100644
index 0000000..e4bb515
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/weight_norm.py
@@ -0,0 +1,2 @@
+from torch import nn
+m = nn.utils.weight_norm(nn.Linear(20, 40), name='weight')
diff --git a/tests/fixtures/deprecated_symbols/checker/weight_norm.txt b/tests/fixtures/deprecated_symbols/checker/weight_norm.txt
new file mode 100644
index 0000000..74e8c05
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/checker/weight_norm.txt
@@ -0,0 +1 @@
+2:5 TOR101 Use of deprecated function torch.nn.utils.weight_norm: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm
diff --git a/tests/fixtures/deprecated_symbols/codemod/aliased_import.py b/tests/fixtures/deprecated_symbols/codemod/aliased_import.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/aliased_import.py
rename to tests/fixtures/deprecated_symbols/codemod/aliased_import.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/aliased_import.py.out b/tests/fixtures/deprecated_symbols/codemod/aliased_import.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/aliased_import.py.out
rename to tests/fixtures/deprecated_symbols/codemod/aliased_import.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.in.py b/tests/fixtures/deprecated_symbols/codemod/amp.in.py
new file mode 100644
index 0000000..6a1227c
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/codemod/amp.in.py
@@ -0,0 +1,11 @@
+import torch
+
+dtype = torch.float32
+
+maybe_autocast = torch.cuda.amp.autocast()
+maybe_autocast = torch.cuda.amp.autocast(dtype=torch.bfloat16)
+maybe_autocast = torch.cuda.amp.autocast(dtype=dtype)
+
+maybe_autocast = torch.cpu.amp.autocast()
+maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
diff --git a/tests/fixtures/deprecated_symbols/codemod/amp.out.py b/tests/fixtures/deprecated_symbols/codemod/amp.out.py
new file mode 100644
index 0000000..da39d0a
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/codemod/amp.out.py
@@ -0,0 +1,11 @@
+import torch
+
+dtype = torch.float32
+
+maybe_autocast = torch.amp.autocast("cuda")
+maybe_autocast = torch.amp.autocast("cuda", dtype=torch.bfloat16)
+maybe_autocast = torch.amp.autocast("cuda", dtype=dtype)
+
+maybe_autocast = torch.amp.autocast("cpu")
+maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
+maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)
diff --git a/tests/fixtures/deprecated_symbols/codemod/chain_matmul.py b/tests/fixtures/deprecated_symbols/codemod/chain_matmul.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/chain_matmul.py
rename to tests/fixtures/deprecated_symbols/codemod/chain_matmul.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/chain_matmul.py.out b/tests/fixtures/deprecated_symbols/codemod/chain_matmul.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/chain_matmul.py.out
rename to tests/fixtures/deprecated_symbols/codemod/chain_matmul.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/cholesky.py b/tests/fixtures/deprecated_symbols/codemod/cholesky.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/cholesky.py
rename to tests/fixtures/deprecated_symbols/codemod/cholesky.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/cholesky.py.out b/tests/fixtures/deprecated_symbols/codemod/cholesky.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/cholesky.py.out
rename to tests/fixtures/deprecated_symbols/codemod/cholesky.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/functorch.py b/tests/fixtures/deprecated_symbols/codemod/functorch.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/functorch.py
rename to tests/fixtures/deprecated_symbols/codemod/functorch.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/functorch.py.out b/tests/fixtures/deprecated_symbols/codemod/functorch.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/functorch.py.out
rename to tests/fixtures/deprecated_symbols/codemod/functorch.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py b/tests/fixtures/deprecated_symbols/codemod/ger-outer.in.py
similarity index 76%
rename from tests/fixtures/deprecated_symbols/codemod/ger-outer.py
rename to tests/fixtures/deprecated_symbols/codemod/ger-outer.in.py
index c5e64c4..6fce087 100644
--- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py
+++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.in.py
@@ -1,6 +1,9 @@
 import torch
+from torch import ger
 deprecated = torch.norm()
 sinusoid_inp = torch.ger(pos_seq, inv_freq)
 other = something.ger(pos_seq, inv_freq)
 deprecated = torch.norm()
 one_more = torch.ger(pos_seq, inv_freq)
+
+just_name = ger(pos_seq, inv_freq)
diff --git a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out b/tests/fixtures/deprecated_symbols/codemod/ger-outer.out.py
similarity index 75%
rename from tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out
rename to tests/fixtures/deprecated_symbols/codemod/ger-outer.out.py
index 45f3d84..3378fde 100644
--- a/tests/fixtures/deprecated_symbols/codemod/ger-outer.py.out
+++ b/tests/fixtures/deprecated_symbols/codemod/ger-outer.out.py
@@ -1,6 +1,9 @@
 import torch
+from torch import outer
 deprecated = torch.norm()
 sinusoid_inp = torch.outer(pos_seq, inv_freq)
 other = something.ger(pos_seq, inv_freq)
 deprecated = torch.norm()
 one_more = torch.outer(pos_seq, inv_freq)
+
+just_name = outer(pos_seq, inv_freq)
diff --git a/tests/fixtures/deprecated_symbols/codemod/qr.py b/tests/fixtures/deprecated_symbols/codemod/qr.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/qr.py
rename to tests/fixtures/deprecated_symbols/codemod/qr.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/qr.py.out b/tests/fixtures/deprecated_symbols/codemod/qr.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/qr.py.out
rename to tests/fixtures/deprecated_symbols/codemod/qr.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/range-arange.py b/tests/fixtures/deprecated_symbols/codemod/range-arange.in.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/range-arange.py
rename to tests/fixtures/deprecated_symbols/codemod/range-arange.in.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/range-arange.py.out b/tests/fixtures/deprecated_symbols/codemod/range-arange.out.py
similarity index 100%
rename from tests/fixtures/deprecated_symbols/codemod/range-arange.py.out
rename to tests/fixtures/deprecated_symbols/codemod/range-arange.out.py
diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.in.py b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.in.py
new file mode 100644
index 0000000..b594d68
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.in.py
@@ -0,0 +1,9 @@
+from torch.utils._pytree import _register_pytree_node
+
+_register_pytree_node()
+
+from torch.utils import _pytree as xx
+xx._register_pytree_node()
+
+import torch
+torch.utils._pytree._register_pytree_node()
diff --git a/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.out.py b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.out.py
new file mode 100644
index 0000000..28bfe2b
--- /dev/null
+++ b/tests/fixtures/deprecated_symbols/codemod/register_pytree_node.out.py
@@ -0,0 +1,9 @@
+from torch.utils._pytree import register_pytree_node
+
+register_pytree_node()
+
+from torch.utils import _pytree as xx
+xx.register_pytree_node()
+
+import torch
+torch.utils._pytree.register_pytree_node()
diff --git a/tests/fixtures/internal/checker/scoped_library.py b/tests/fixtures/internal/checker/scoped_library.py
new file mode 100644
index 0000000..54d5316
--- /dev/null
+++ b/tests/fixtures/internal/checker/scoped_library.py
@@ -0,0 +1,3 @@
+import torch
+from torch.library import Library, impl, fallthrough_kernel
+my_lib1 = Library("aten", "IMPL")
diff --git a/tests/fixtures/internal/checker/scoped_library.txt b/tests/fixtures/internal/checker/scoped_library.txt
new file mode 100644
index 0000000..1f1e7f8
--- /dev/null
+++ b/tests/fixtures/internal/checker/scoped_library.txt
@@ -0,0 +1 @@
+3:11 TOR901 Use `torch.library._scoped_library` instead of `torch.library.Library` in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 for details.
diff --git a/tests/fixtures/misc/checker/expm1.py b/tests/fixtures/misc/checker/expm1.py
new file mode 100644
index 0000000..4a7646d
--- /dev/null
+++ b/tests/fixtures/misc/checker/expm1.py
@@ -0,0 +1,12 @@
+import torch
+a = torch.randn(5)
+b = torch.exp(a) - 1
+c = torch.exp(a) - 1.0
+
+ret = (torch.exp(a) - 1) * torch.exp(2 * b)
+
+# False negative: can not detect currently
+x = a.exp() - 1
+
+# False negative: should be rare and would complicate implementation
+x = -1 + torch.exp(a)
diff --git a/tests/fixtures/misc/checker/expm1.txt b/tests/fixtures/misc/checker/expm1.txt
new file mode 100644
index 0000000..ed24905
--- /dev/null
+++ b/tests/fixtures/misc/checker/expm1.txt
@@ -0,0 +1,3 @@
+3:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
+4:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
+6:7 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
diff --git a/tests/fixtures/misc/checker/log1p.py b/tests/fixtures/misc/checker/log1p.py
new file mode 100644
index 0000000..6afffdb
--- /dev/null
+++ b/tests/fixtures/misc/checker/log1p.py
@@ -0,0 +1,9 @@
+import torch
+a = torch.randn(5)
+b = torch.log(1 + a)
+c = torch.log(a + 1)
+b = torch.log(1.0 + a)
+c = torch.log(a + 1.0)
+
+# False negative: can not detect currently
+x = (a + 1).log()
diff --git a/tests/fixtures/misc/checker/log1p.txt b/tests/fixtures/misc/checker/log1p.txt
new file mode 100644
index 0000000..3bcbeac
--- /dev/null
+++ b/tests/fixtures/misc/checker/log1p.txt
@@ -0,0 +1,4 @@
+3:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
+4:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
+5:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
+6:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
diff --git a/tests/fixtures/misc/checker/logsumexp.py b/tests/fixtures/misc/checker/logsumexp.py
new file mode 100644
index 0000000..d4399f0
--- /dev/null
+++ b/tests/fixtures/misc/checker/logsumexp.py
@@ -0,0 +1,21 @@
+import torch
+
+x = torch.randn(5)
+
+# logsumexp
+y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True))
+y = torch.log(torch.sum(torch.exp(x), dim=1, keepdim=True))
+y = torch.log(torch.sum(torch.exp(2.5 + x), 1))
+y = torch.log(torch.sum(torch.exp(2.5 + x), dim=1))
+
+# not logsumexp
+y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5)
+y = torch.log(torch.sum(torch.exp(x) + 2.5, 1))
+y = torch.log(2 + x)
+y = torch.sum(torch.log(torch.exp(x)), 1)
+y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True))
+
+# not logsumexp because of https://github.com/pytorch/pytorch/issues/144339
+y = torch.log(torch.sum(torch.exp(x), None, keepdim=True))
+y = torch.log(torch.sum(torch.exp(x), dim=None, keepdim=True))
+y = torch.log(torch.sum(torch.exp(x), keepdim=True))
diff --git a/tests/fixtures/misc/checker/logsumexp.txt b/tests/fixtures/misc/checker/logsumexp.txt
new file mode 100644
index 0000000..5298d5f
--- /dev/null
+++ b/tests/fixtures/misc/checker/logsumexp.txt
@@ -0,0 +1,4 @@
+6:5 TOR108 Use numerically stabilized `torch.logsumexp`.
+7:5 TOR108 Use numerically stabilized `torch.logsumexp`.
+8:5 TOR108 Use numerically stabilized `torch.logsumexp`.
+9:5 TOR108 Use numerically stabilized `torch.logsumexp`.
diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.py b/tests/fixtures/misc/checker/reentrant_checkpoint.py
new file mode 100644
index 0000000..938a41f
--- /dev/null
+++ b/tests/fixtures/misc/checker/reentrant_checkpoint.py
@@ -0,0 +1,13 @@
+import torch
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+
+import torch.utils.checkpoint
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)
+
+from torch.utils.checkpoint import checkpoint
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.txt b/tests/fixtures/misc/checker/reentrant_checkpoint.txt
new file mode 100644
index 0000000..af867d6
--- /dev/null
+++ b/tests/fixtures/misc/checker/reentrant_checkpoint.txt
@@ -0,0 +1,2 @@
+7:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
+12:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.in.py b/tests/fixtures/misc/codemod/reentrant_checkpoint.in.py
new file mode 100644
index 0000000..3d0051d
--- /dev/null
+++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.in.py
@@ -0,0 +1,6 @@
+import torch
+from torch.utils.checkpoint import checkpoint
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.out.py b/tests/fixtures/misc/codemod/reentrant_checkpoint.out.py
new file mode 100644
index 0000000..57c69b7
--- /dev/null
+++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.out.py
@@ -0,0 +1,6 @@
+import torch
+from torch.utils.checkpoint import checkpoint
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)
diff --git a/tests/fixtures/misc/codemod/require_grad.py b/tests/fixtures/misc/codemod/require_grad.in.py
similarity index 100%
rename from tests/fixtures/misc/codemod/require_grad.py
rename to tests/fixtures/misc/codemod/require_grad.in.py
diff --git a/tests/fixtures/misc/codemod/require_grad.py.out b/tests/fixtures/misc/codemod/require_grad.out.py
similarity index 100%
rename from tests/fixtures/misc/codemod/require_grad.py.out
rename to tests/fixtures/misc/codemod/require_grad.out.py
diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.py b/tests/fixtures/nonpublic/checker/default_collate_convert.py
new file mode 100644
index 0000000..c7b7c65
--- /dev/null
+++ b/tests/fixtures/nonpublic/checker/default_collate_convert.py
@@ -0,0 +1,11 @@
+from torch.utils.data import _utils
+batch = _utils.collate.default_collate(batch)
+
+from torch.utils.data._utils.collate import default_collate
+inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)
+
+from torch.utils.data._utils.collate import default_convert
+values = default_convert(values)
+
+import torch
+values = torch.utils.data._utils.collate.default_convert(values)
diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.txt b/tests/fixtures/nonpublic/checker/default_collate_convert.txt
new file mode 100644
index 0000000..edfc9a9
--- /dev/null
+++ b/tests/fixtures/nonpublic/checker/default_collate_convert.txt
@@ -0,0 +1,8 @@
+2:9 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
+4:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
+5:29 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
+5:54 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
+5:79 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
+7:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
+8:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
+11:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.in.py b/tests/fixtures/nonpublic/codemod/default_collate_convert.in.py
new file mode 100644
index 0000000..b61bad6
--- /dev/null
+++ b/tests/fixtures/nonpublic/codemod/default_collate_convert.in.py
@@ -0,0 +1,14 @@
+from torch.utils.data import _utils  # will not be removed as it could be used for something besides default_collate
+batch = _utils.collate.default_collate(batch)
+
+from torch.utils.data._utils import collate  # also will not be removed
+batch = collate.default_collate(batch)
+
+from torch.utils.data._utils.collate import default_collate
+inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)
+
+from torch.utils.data._utils.collate import default_convert
+values = default_convert(values)
+
+import torch
+values = torch.utils.data._utils.collate.default_convert(values)
diff --git a/tests/fixtures/nonpublic/codemod/default_collate_convert.out.py b/tests/fixtures/nonpublic/codemod/default_collate_convert.out.py
new file mode 100644
index 0000000..11bfb53
--- /dev/null
+++ b/tests/fixtures/nonpublic/codemod/default_collate_convert.out.py
@@ -0,0 +1,14 @@
+from torch.utils.data import dataloader, _utils  # will not be removed as it could be used for something besides default_collate
+batch = dataloader.default_collate(batch)
+
+from torch.utils.data._utils import collate  # also will not be removed
+batch = dataloader.default_collate(batch)
+
+from torch.utils.data.dataloader import default_collate
+inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)
+
+from torch.utils.data.dataloader import default_convert
+values = default_convert(values)
+
+import torch
+values = torch.utils.data.dataloader.default_convert(values)
diff --git a/tests/fixtures/performance/checker/zerograd.py b/tests/fixtures/performance/checker/zerograd.py
new file mode 100644
index 0000000..8f0d6fc
--- /dev/null
+++ b/tests/fixtures/performance/checker/zerograd.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+x = torch.ones((100, 100))
+model = nn.Sequential()
+optimizer = torch.optim.Adam(model.parameters())
+
+# This should raise flags
+optimizer.zero_grad(set_to_none=False)
+model.zero_grad(set_to_none=False)
+
+# This should not raise flags
+optimizer.zero_grad()
+model.zero_grad()
+
+
diff --git a/tests/fixtures/performance/checker/zerograd.txt b/tests/fixtures/performance/checker/zerograd.txt
new file mode 100644
index 0000000..ed29bf4
--- /dev/null
+++ b/tests/fixtures/performance/checker/zerograd.txt
@@ -0,0 +1,2 @@
+9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
+10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
\ No newline at end of file
diff --git a/tests/fixtures/vision/checker/singleton_import.py b/tests/fixtures/vision/checker/singleton_import.py
new file mode 100644
index 0000000..ad30130
--- /dev/null
+++ b/tests/fixtures/vision/checker/singleton_import.py
@@ -0,0 +1,8 @@
+import torchvision.models as models
+import torchvision.models as cnn
+from torchvision.models import resnet50, resnet101
+import torchvision.models
+from torchvision.models import *
+import torchvision.models as models, torch
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
diff --git a/tests/fixtures/vision/checker/singleton_import.txt b/tests/fixtures/vision/checker/singleton_import.txt
new file mode 100644
index 0000000..d6e7421
--- /dev/null
+++ b/tests/fixtures/vision/checker/singleton_import.txt
@@ -0,0 +1,4 @@
+1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
+6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
+7:1 TOR203 Consider replacing 'import torchvision.datasets as datasets' with 'from torchvision import datasets'.
+8:1 TOR203 Consider replacing 'import torchvision.transforms as transforms' with 'from torchvision import transforms'.
diff --git a/tests/fixtures/vision/codemod/pretrained.py b/tests/fixtures/vision/codemod/pretrained.in.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained.py
rename to tests/fixtures/vision/codemod/pretrained.in.py
diff --git a/tests/fixtures/vision/codemod/pretrained.py.out b/tests/fixtures/vision/codemod/pretrained.out.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained.py.out
rename to tests/fixtures/vision/codemod/pretrained.out.py
diff --git a/tests/fixtures/vision/codemod/pretrained_models_import.py b/tests/fixtures/vision/codemod/pretrained_models_import.in.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_models_import.py
rename to tests/fixtures/vision/codemod/pretrained_models_import.in.py
diff --git a/tests/fixtures/vision/codemod/pretrained_models_import.py.out b/tests/fixtures/vision/codemod/pretrained_models_import.out.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_models_import.py.out
rename to tests/fixtures/vision/codemod/pretrained_models_import.out.py
diff --git a/tests/fixtures/vision/codemod/pretrained_none_import.py b/tests/fixtures/vision/codemod/pretrained_none_import.in.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_none_import.py
rename to tests/fixtures/vision/codemod/pretrained_none_import.in.py
diff --git a/tests/fixtures/vision/codemod/pretrained_none_import.py.out b/tests/fixtures/vision/codemod/pretrained_none_import.out.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_none_import.py.out
rename to tests/fixtures/vision/codemod/pretrained_none_import.out.py
diff --git a/tests/fixtures/vision/codemod/pretrained_submodule_import.py b/tests/fixtures/vision/codemod/pretrained_submodule_import.in.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_submodule_import.py
rename to tests/fixtures/vision/codemod/pretrained_submodule_import.in.py
diff --git a/tests/fixtures/vision/codemod/pretrained_submodule_import.py.out b/tests/fixtures/vision/codemod/pretrained_submodule_import.out.py
similarity index 100%
rename from tests/fixtures/vision/codemod/pretrained_submodule_import.py.out
rename to tests/fixtures/vision/codemod/pretrained_submodule_import.out.py
diff --git a/tests/fixtures/vision/codemod/singleton_import.in.py b/tests/fixtures/vision/codemod/singleton_import.in.py
new file mode 100644
index 0000000..50d1ecd
--- /dev/null
+++ b/tests/fixtures/vision/codemod/singleton_import.in.py
@@ -0,0 +1,9 @@
+import torchvision.models as models
+import torchvision.models as cnn
+import torchvision.datasets as datasets
+import torchvision.datasets as datasets_alt
+import torchvision.transforms as transforms
+import torchvision.transforms as transforms_alt
+
+# don't touch if more than one name imported
+import torchvision.models as models, torch
diff --git a/tests/fixtures/vision/codemod/singleton_import.out.py b/tests/fixtures/vision/codemod/singleton_import.out.py
new file mode 100644
index 0000000..e17def6
--- /dev/null
+++ b/tests/fixtures/vision/codemod/singleton_import.out.py
@@ -0,0 +1,9 @@
+from torchvision import models
+import torchvision.models as cnn
+from torchvision import datasets
+import torchvision.datasets as datasets_alt
+from torchvision import transforms
+import torchvision.transforms as transforms_alt
+
+# don't touch if more than one name imported
+import torchvision.models as models, torch
diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py
index b9f0be0..5baa12a 100644
--- a/tests/test_torchfix.py
+++ b/tests/test_torchfix.py
@@ -1,54 +1,97 @@
+import logging
+import subprocess
 from pathlib import Path
+
+import libcst.codemod as codemod
 from torchfix.torchfix import (
+    DISABLED_BY_DEFAULT,
+    expand_error_codes,
+    GET_ALL_ERROR_CODES,
+    GET_ALL_VISITORS,
+    process_error_code_str,
     TorchChecker,
     TorchCodemod,
     TorchCodemodConfig,
-    GET_ALL_VISITORS,
 )
-import logging
-import libcst.codemod as codemod
 
 FIXTURES_PATH = Path(__file__).absolute().parent / "fixtures"
 LOGGER = logging.getLogger(__name__)
 
 
+def pytest_generate_tests(metafunc):
+    # Dynamically generate test cases from paths
+    if "checker_source_path" in metafunc.fixturenames:
+        files = list(FIXTURES_PATH.glob("**/checker/*.py"))
+        metafunc.parametrize(
+            "checker_source_path", files, ids=[file_name.stem for file_name in files]
+        )
+    if "codemod_source_path" in metafunc.fixturenames:
+        files = list(FIXTURES_PATH.glob("**/codemod/*.in.py"))
+        metafunc.parametrize(
+            "codemod_source_path", files, ids=[file_name.stem for file_name in files]
+        )
+    if "case" in metafunc.fixturenames:
+        exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
+        cases = [
+            ("ALL", GET_ALL_ERROR_CODES()),
+            ("ALL,TOR102", GET_ALL_ERROR_CODES()),
+            ("TOR102", {"TOR102"}),
+            ("TOR102,TOR101", {"TOR102", "TOR101"}),
+            (
+                "TOR1,TOR102",
+                {
+                    "TOR101",
+                    "TOR102",
+                    "TOR103",
+                    "TOR104",
+                    "TOR105",
+                    "TOR106",
+                    "TOR107",
+                    "TOR108",
+                },
+            ),
+            (None, set(GET_ALL_ERROR_CODES()) - exclude_set),
+        ]
+        metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases])
+
+
 def _checker_results(s):
     checker = TorchChecker(None, s)
     return [f"{line}:{col} {msg}" for line, col, msg, _ in checker.run()]
 
 
-def _codemod_results(source_path):
-    with open(source_path) as source:
-        code = source.read()
-    config = TorchCodemodConfig(select="ALL")
-    context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
+def _codemod_results(source_path: Path):
+    code = source_path.read_text()
+    config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
+    context = TorchCodemod(codemod.CodemodContext(filename=str(source_path)), config)
     new_module = codemod.transform_module(context, code)
-    return new_module.code
+    if isinstance(new_module, codemod.TransformSuccess):
+        return new_module.code
+    if isinstance(new_module, codemod.TransformFailure):
+        raise new_module.error
 
 
 def test_empty():
     assert _checker_results([""]) == []
 
 
-def test_checker_fixtures():
-    for source_path in FIXTURES_PATH.glob("**/checker/*.py"):
-        LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
-        expected_path = str(source_path)[:-2] + "txt"
-        expected_results = []
-        with open(expected_path) as expected:
-            for line in expected:
-                expected_results.append(line.rstrip())
-
-        with open(source_path) as source:
-            assert _checker_results(source.readlines()) == expected_results
+def test_checker_fixtures(checker_source_path: Path):
+    expected_path = checker_source_path.with_suffix(".txt")
+    expected_results = expected_path.read_text().splitlines()
+    results = _checker_results(
+        checker_source_path.read_text().splitlines(keepends=True)
+    )
+    # Overwrite the expected data with the results (useful when updating tests)
+    # expected_path.write_text("".join([f"{line}\n" for line in results]))
+    assert results == expected_results
 
 
-def test_codemod_fixtures():
-    for source_path in FIXTURES_PATH.glob("**/codemod/*.py"):
-        LOGGER.info("Testing %s", source_path.relative_to(Path.cwd()))
-        expected_path = source_path.with_suffix(".py.out")
-        expected_results = expected_path.read_text()
-        assert _codemod_results(source_path) == expected_results
+def test_codemod_fixtures(codemod_source_path: Path):
+    expected_path = codemod_source_path.with_stem(
+        codemod_source_path.stem.replace(".in", ".out")
+    )
+    expected_results = expected_path.read_text()
+    assert _codemod_results(codemod_source_path) == expected_results
 
 
 def test_errorcodes_distinct():
@@ -56,7 +99,25 @@ def test_errorcodes_distinct():
     seen = set()
     for visitor in visitors:
         LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
-        error_code = visitor.ERROR_CODE
-        for e in error_code if isinstance(error_code, list) else [error_code]:
-            assert e not in seen
-            seen.add(e)
+        for e in visitor.ERRORS:
+            assert e.error_code not in seen
+            seen.add(e.error_code)
+
+
+def test_parse_error_code_str(case, expected):
+    assert process_error_code_str(case) == expected
+
+
+def test_no_python_files(tmp_path):
+    # Create a temporary directory with no Python files
+    non_python_file = tmp_path / "not_a_python_file.txt"
+    non_python_file.write_text("This is not a Python file")
+
+    # Run torchfix on the temporary directory
+    result = subprocess.run(
+        ["python3", "-m", "torchfix", str(tmp_path)],
+        capture_output=True,
+        text=True,
+    )
+    # Check that the script exits successfully
+    assert result.returncode == 0
diff --git a/torchfix/__main__.py b/torchfix/__main__.py
index 0c3a823..4964e01 100644
--- a/torchfix/__main__.py
+++ b/torchfix/__main__.py
@@ -1,22 +1,54 @@
 import argparse
-import libcst.codemod as codemod
 import contextlib
-import sys
+import ctypes
 import io
-import os
+import sys
+
+import libcst.codemod as codemod
 
-from .torchfix import TorchCodemod, TorchCodemodConfig
 from .common import CYAN, ENDC
+from .torchfix import (
+    __version__ as TorchFixVersion,
+    DISABLED_BY_DEFAULT,
+    GET_ALL_ERROR_CODES,
+    process_error_code_str,
+    TorchCodemod,
+    TorchCodemodConfig,
+)
 
-def main() -> None:
+
+# Should get rid of this code eventually.
+@contextlib.contextmanager
+def StderrSilencer(redirect: bool = True):
+    if not redirect:
+        yield
+    elif sys.platform != "darwin":
+        with contextlib.redirect_stderr(io.StringIO()):
+            yield
+    else:
+        # redirect_stderr does not work for some reason.
+        # Work around it by using good old dup2 to redirect
+        # stderr to /dev/null.
+        libc = ctypes.CDLL("libc.dylib")
+        orig_stderr = libc.dup(2)
+        with open("/dev/null", "w") as f:
+            libc.dup2(f.fileno(), 2)
+        try:
+            yield
+        finally:
+            libc.dup2(orig_stderr, 2)
+            libc.close(orig_stderr)
+
+
+def _parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "path",
         nargs="+",
-        help=("Path to check/fix. Can be a directory, a file, or multiple of either."),
+        help="Path to check/fix. Can be a directory, a file, or multiple of either.",
     )
     parser.add_argument(
         "--fix",
@@ -32,11 +64,13 @@ def main() -> None:
     )
     parser.add_argument(
         "--select",
-        help="ALL to enable rules disabled by default",
-        choices=[
-            "ALL",
-        ],
+        help=f"Comma-separated list of rules to enable, or 'ALL' to enable all rules. "
+        f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
+        f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
+        type=str,
+        default=None,
     )
+    parser.add_argument("--version", action="version", version=f"{TorchFixVersion}")
 
     # XXX TODO: Get rid of this!
     # Silence "Failed to determine module name"
@@ -46,8 +80,11 @@ def main() -> None:
         action="store_true",
     )
 
-    args = parser.parse_args()
+    return parser.parse_args()
+
 
+def main() -> None:
+    args = _parse_args()
     files = codemod.gather_files(args.path)
 
     # Filter out files that don't have "torch" string in them.
@@ -55,25 +92,20 @@ def main() -> None:
     MARKER = "torch"  # this will catch import torch or functorch
     torch_files = []
    for file in files:
-        # TODO: remove the check when https://github.com/Instagram/LibCST/pull/994 lands
-        if os.path.isfile(file):  # `codemod.gather_files` can return dirs with ".py"
-            with open(file, errors="replace") as f:
-                for line in f:
-                    if MARKER in line:
-                        torch_files.append(file)
-                        break
+        with open(file, errors="replace") as f:
+            for line in f:
+                if MARKER in line:
+                    torch_files.append(file)
+                    break
+    if not torch_files:
+        return
 
     config = TorchCodemodConfig()
-    config.select = args.select
+    config.select = list(process_error_code_str(args.select))
     command_instance = TorchCodemod(codemod.CodemodContext(), config)
     DIFF_CONTEXT = 5
 
     try:
-        if not args.show_stderr:
-            context = contextlib.redirect_stderr(io.StringIO())
-        else:
-            # Should get rid of this code eventually.
-            context = contextlib.nullcontext()  # type: ignore
-        with context:
+        with StderrSilencer(not args.show_stderr):
             result = codemod.parallel_exec_transform_with_prettyprint(
                 command_instance,
                 torch_files,
diff --git a/torchfix/common.py b/torchfix/common.py
index 52f2f52..5222576 100644
--- a/torchfix/common.py
+++ b/torchfix/common.py
@@ -1,10 +1,12 @@
-from dataclasses import dataclass
 import sys
+from abc import ABC
+from dataclasses import dataclass
+from os.path import commonprefix
+from typing import List, Optional, Sequence, Set, Tuple, Mapping
+
 import libcst as cst
-from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
 from libcst.codemod.visitors import ImportItem
-from typing import Optional, List, Set, Union
-from abc import ABC
+from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
 
 IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
 CYAN = "\033[96m" if IS_TTY else ""
@@ -24,7 +26,7 @@ class LintViolation:
 
     def flake8_result(self):
         full_message = f"{self.error_code} {self.message}"
-        return (self.line, 1 + self.column, full_message, "TorchFix")
+        return self.line, 1 + self.column, full_message, "TorchFix"
 
     def codemod_result(self) -> str:
         fixable = f" [{CYAN}*{ENDC}]" if self.replacement is not None else ""
@@ -34,15 +36,27 @@ def codemod_result(self) -> str:
         return f"{position} {error_code}{fixable} {self.message}"
 
 
+@dataclass(frozen=True)
+class TorchError:
+    """Defines an error along with an explanation"""
+
+    error_code: str
+    message_template: str
+
+    def message(self, **kwargs):
+        return self.message_template.format(**kwargs)
+
+
 class TorchVisitor(cst.BatchableCSTVisitor, ABC):
     METADATA_DEPENDENCIES = (
         QualifiedNameProvider,
         WhitespaceInclusivePositionProvider,
     )
 
-    ERROR_CODE: Union[str, List[str]]
+    ERRORS: List[TorchError]
 
     def __init__(self) -> None:
+        super().__init__()
         self.violations: List[LintViolation] = []
         self.needed_imports: Set[ImportItem] = set()
 
@@ -50,7 +64,10 @@ def __init__(self) -> None:
     @staticmethod
     def get_specific_arg(
         node: cst.Call, arg_name: str, arg_pos: int
     ) -> Optional[cst.Arg]:
-        # `arg_pos` is zero-based.
+        """
+        :param arg_pos: zero-based position of the argument; -1 means it is a keyword argument.
+        :note: consider using `has_specific_arg` if you only need to check for presence.
+        """
         curr_pos = 0
         for arg in node.args:
             if arg.keyword is None:
@@ -61,6 +78,41 @@ def get_specific_arg(
                 return arg
         return None
 
+    @staticmethod
+    def has_specific_arg(
+        node: cst.Call, arg_name: str, position: Optional[int] = None
+    ) -> bool:
+        """
+        Check if the specific argument is present in a call.
+        """
+        return (
+            TorchVisitor.get_specific_arg(
+                node, arg_name, position if position is not None else -1
+            )
+            is not None
+        )
+
+    def add_violation(
+        self,
+        node: cst.CSTNode,
+        error_code: str,
+        message: str,
+        replacement: Optional[cst.CSTNode] = None,
+    ) -> None:
+        position_metadata = self.get_metadata(
+            cst.metadata.WhitespaceInclusivePositionProvider, node
+        )
+        self.violations.append(
+            LintViolation(
+                error_code=error_code,
+                message=message,
+                line=position_metadata.start.line,
+                column=position_metadata.start.column,
+                node=node,
+                replacement=replacement,
+            )
+        )
+
     def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
         # Guard against situations like `vmap(a)(b)`:
         #
@@ -77,41 +129,111 @@ def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
         name_metadata = list(self.get_metadata(QualifiedNameProvider, node))
         if not name_metadata:
             return None
-        qualified_name = name_metadata[0].name
-        return qualified_name
+        return name_metadata[0].name
 
 
 def call_with_name_changes(
-    node: cst.Call, old_qualified_name: str, new_qualified_name: str
-) -> Optional[cst.Call]:
+    node: cst.Call, qualified_name: str, new_qualified_name: str
+) -> Optional[Tuple[cst.Call, Set[ImportItem]]]:
     """
-    Return new `Call` node with name changes.
+    Return an optional tuple: the new `Call` node with name changes,
+    and a set of newly needed imports.
     """
-    old_begin, _, old_last = old_qualified_name.rpartition(".")
-    new_begin, _, new_last = new_qualified_name.rpartition(".")
-
-    # If the only difference is the last name part.
-    if old_begin == new_begin:
-        replacement = node.with_deep_changes(
-            old_node=cst.ensure_type(node.func, cst.Attribute).attr,
-            value=new_last,
+    needed_imports: Set[ImportItem] = set()
+    call_name = cst.helpers.get_full_name_for_node(node)
+    assert call_name is not None
+    replacement = None
+
+    alias_prefix = ""
+    if not qualified_name.endswith(call_name):
+        # This means we have an alias (`import from as`).
+        common_suffix = commonprefix([qualified_name[::-1], call_name[::-1]])[::-1]
+        alias_prefix = call_name.removesuffix(common_suffix) + "."
+
+    if not new_qualified_name.endswith(call_name):
+        # We need to change the call name as it's not a part of the new qualified name.
+        # Get the new call name on the same hierarchical level.
+        new_call_name = new_qualified_name.removeprefix(
+            commonprefix([qualified_name.removesuffix(call_name), new_qualified_name])
+        )
+        new_module_name = new_qualified_name.removesuffix(new_call_name).removesuffix(
+            "."
+        )
+        if new_module_name:
+            needed_imports.add(
+                ImportItem(
+                    module_name=new_module_name,
+                    obj_name=new_call_name.split(".")[0],
+                )
+            )
+        replacement = node.with_changes(
+            func=cst.parse_expression(alias_prefix + new_call_name)
         )
-
-    # If the last name part is the same and
-    # originally called without a dot: don't change the call site,
-    # just change the imports elsewhere.
-    elif old_last == new_last and isinstance(node.func, cst.Name):
-        replacement = None  # Replace with new_qualified_name.
-    else:
-        replacement = node.with_changes(func=cst.parse_expression(new_qualified_name))
-    return replacement
+    if replacement is None:
+        return None
+
+    return replacement, needed_imports
+
+
+def check_old_names_in_import_from(
+    node: cst.ImportFrom, old_new_name_map: Mapping[str, Optional[str]]
+) -> Tuple[List[str], Optional[cst.ImportFrom]]:
+    """
+    Using `old_new_name_map`, check if there are any old names in the import from.
+    Return a tuple of two elements:
+    1. List of all found old names.
+    2. Optional replacement for the ImportFrom node.
+    """
+    if node.module is None or not isinstance(node.names, Sequence):
+        return [], None
+
+    old_names: List[str] = []
+    replacement = None
+    new_names: List[str] = []
+    module = cst.helpers.get_full_name_for_node(node.module)
+
+    # `new_modules` and `has_non_updated_names` are used
+    # to decide if we can replace the ImportFrom node.
+    new_modules: Set[str] = set()
+    has_non_updated_names = False
+
+    for name in node.names:
+        qualified_name = f"{module}.{name.name.value}"
+        if qualified_name in old_new_name_map:
+            old_names.append(qualified_name)
+            new_qualified_name = old_new_name_map[qualified_name]
+            if new_qualified_name is not None:
+                new_module = ".".join(new_qualified_name.split(".")[:-1])
+                new_name = new_qualified_name.split(".")[-1]
+                new_names.append(new_name)
+                new_modules.add(new_module)
+            else:
+                has_non_updated_names = True
+        else:
+            has_non_updated_names = True
+
+    # Replace only if the new module is the same for all names in the import.
+    if not has_non_updated_names and len(new_modules) == 1:
+        new_module = new_modules.pop()
+        import_aliases = [
+            import_alias.with_changes(name=cst.Name(new_name))
+            for import_alias, new_name in zip(list(node.names), new_names)
+        ]
+        replacement = node.with_changes(
+            module=cst.parse_expression(new_module),
+            names=import_aliases,
+        )
+
+    return old_names, replacement
 
 
 def deep_multi_replace(tree, replacement_map):
     class MultiChildReplacementTransformer(cst.CSTTransformer):
         def __init__(self, replacement_map) -> None:
+            super().__init__()
             self.replacement_map = replacement_map
 
         def on_leave(self, original_node, updated_node):
diff --git a/torchfix/deprecated_symbols.yaml b/torchfix/deprecated_symbols.yaml
index 94d8b60..eaa5119 100644
--- a/torchfix/deprecated_symbols.yaml
+++ b/torchfix/deprecated_symbols.yaml
@@ -25,11 +25,11 @@
   replacement: torch.outer
 
 - name: torch.lu_solve
-  deprecate_pr: TBA
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/73806
   remove_pr:
 
 - name: torch.norm
-  deprecate_pr: TBA
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/57986
   remove_pr:
 
 - name: torch.range
@@ -45,9 +45,17 @@
   remove_pr:
 
 - name: torch.lu
-  deprecate_pr: TBA
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/73804
   remove_pr:
 
+- name: torch.matrix_rank
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/57734
+  remove_pr: https://github.com/pytorch/pytorch/pull/70981
+
+- name: torch.lstsq
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/57743
+  remove_pr: https://github.com/pytorch/pytorch/pull/70980
+
 - name: torch.nn.UpsamplingNearest2d
   deprecate_pr: TBA
   remove_pr:
@@ -60,6 +68,37 @@
   deprecate_pr: TBA
   remove_pr:
 
+- name: torch.nn.utils.weight_norm
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/103001
+  remove_pr:
+  reference: https://github.com/pytorch-labs/torchfix#torchnnutilsweight_norm
+
+- name: torch.utils._pytree._register_pytree_node
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/112111
+  remove_pr:
+  replacement: torch.utils._pytree.register_pytree_node
+
+- name: torch.backends.cuda.sdp_kernel
+  deprecate_pr: https://github.com/pytorch/pytorch/pull/114689
+  remove_pr:
+  reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
+
+- name: torch.cuda.amp.autocast
+  deprecate_pr: TBA
+  remove_pr:
+
+- name: torch.cuda.amp.custom_fwd
+  deprecate_pr: TBA
+  remove_pr:
+
+- name: torch.cuda.amp.custom_bwd
+  deprecate_pr: TBA
+  remove_pr:
+
+- name: torch.cpu.amp.autocast
+  deprecate_pr: TBA
+  remove_pr:
+
 # functorch
 - name: functorch.vmap
   deprecate_pr: TBA
diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py
index 1a47e20..5e96e38 100644
--- a/torchfix/torchfix.py
+++ b/torchfix/torchfix.py
@@ -1,39 +1,125 @@
 from dataclasses import dataclass
+import functools
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List
 
 import libcst as cst
 import libcst.codemod as codemod
 
-from .common import deep_multi_replace
-from .visitors.deprecated_symbols import (
+from .common import deep_multi_replace, TorchVisitor
+
+from .visitors import (
     TorchDeprecatedSymbolsVisitor,
-    _UpdateFunctorchImports,
+    TorchExpm1Visitor,
+    TorchLog1pVisitor,
+    TorchLogsumexpVisitor,
+    TorchNonPublicAliasVisitor,
+    TorchReentrantCheckpointVisitor,
+    TorchRequireGradVisitor,
+    TorchScopedLibraryVisitor,
+    TorchSynchronizedDataLoaderVisitor,
+    TorchUnsafeLoadVisitor,
+    TorchVisionDeprecatedPretrainedVisitor,
+    TorchVisionDeprecatedToTensorVisitor,
+    TorchVisionSingletonImportVisitor,
+    TorchGradNotSetToNonePatternVisitor,
 )
-from .visitors.performance import TorchSynchronizedDataLoaderVisitor
-from .visitors.misc import TorchRequireGradVisitor
-from .visitors.vision import (
+
+__version__ = "0.7.0"
+
+DEPRECATED_CONFIG_PATH = "deprecated_symbols.yaml"
+
+DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"]
+
+ALL_VISITOR_CLS = [
+    TorchDeprecatedSymbolsVisitor,
+    TorchExpm1Visitor,
+    TorchLog1pVisitor,
+    TorchLogsumexpVisitor,
+    TorchNonPublicAliasVisitor,
+    TorchRequireGradVisitor,
+    TorchReentrantCheckpointVisitor,
+    TorchScopedLibraryVisitor,
+    TorchSynchronizedDataLoaderVisitor,
+    TorchUnsafeLoadVisitor,
     TorchVisionDeprecatedPretrainedVisitor,
     TorchVisionDeprecatedToTensorVisitor,
-)
-from .visitors.security import TorchUnsafeLoadVisitor
+    TorchVisionSingletonImportVisitor,
+    TorchGradNotSetToNonePatternVisitor,
+]
+
+
+@functools.cache
+def GET_ALL_ERROR_CODES():
+    codes = set()
+    for cls in ALL_VISITOR_CLS:
+        assert issubclass(cls, TorchVisitor)
+        codes |= {error.error_code for error in cls.ERRORS}
+    return sorted(codes)
+
 
-__version__ = "0.2.1"
+@functools.cache
+def expand_error_codes(codes):
+    out_codes = set()
+    for c_a in codes:
+        for c_b in GET_ALL_ERROR_CODES():
+            if c_b.startswith(c_a):
+                out_codes.add(c_b)
+    return out_codes
 
-DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml"
 
-DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]
+def construct_visitor(cls):
+    if cls is TorchDeprecatedSymbolsVisitor:
+        return cls(DEPRECATED_CONFIG_PATH)
+
+    return cls()
 
 
 def GET_ALL_VISITORS():
-    return [
-        TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH),
-        TorchRequireGradVisitor(),
-        TorchSynchronizedDataLoaderVisitor(),
-        TorchVisionDeprecatedPretrainedVisitor(),
-        TorchVisionDeprecatedToTensorVisitor(),
-        TorchUnsafeLoadVisitor(),
-    ]
+    return [construct_visitor(v) for v in ALL_VISITOR_CLS]
+
+
+def get_visitors_with_error_codes(error_codes):
+    visitor_classes = set()
+    for error_code in error_codes:
+        # Assume the error codes have been expanded so each error code can
+        # only correspond to one visitor.
+        found = False
+        for visitor_cls in ALL_VISITOR_CLS:
+            assert issubclass(visitor_cls, TorchVisitor)
+            if any(error_code == err.error_code for err in visitor_cls.ERRORS):
+                visitor_classes.add(visitor_cls)
+                found = True
+                break
+        if not found:
+            raise AssertionError(f"Unknown error code: {error_code}")
+    return [construct_visitor(cls) for cls in visitor_classes]
+
+
+def process_error_code_str(code_str):
+    # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
+    # We deduplicate them here.
+
+    # Default when --select is not provided.
+    if code_str is None:
+        exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
+        return set(GET_ALL_ERROR_CODES()) - exclude_set
+
+    raw_codes = [s.strip() for s in code_str.split(",")]
+
+    # Validate error codes
+    for c in raw_codes:
+        if c == "ALL":
+            continue
+        if len(expand_error_codes((c,))) == 0:
+            raise ValueError(
+                f"Invalid error code: {c}, available error "
+                f"codes: {list(GET_ALL_ERROR_CODES())}"
+            )
+
+    if "ALL" in raw_codes:
+        return GET_ALL_ERROR_CODES()
+
+    return expand_error_codes(tuple(raw_codes))
 
 
 # Flake8 plugin
@@ -76,7 +162,7 @@ def add_options(optmanager):
 # Standalone torchfix command
 @dataclass
 class TorchCodemodConfig:
-    select: Optional[str] = None
+    select: Optional[List[str]] = None
 
 
 class TorchCodemod(codemod.Codemod):
@@ -95,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
         # in that case we would need to use `wrapped_module.module`
         # instead of `module`.
         wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True)
+        if self.config is None or self.config.select is None:
+            raise AssertionError("Expected self.config.select to be set")
+        visitors = get_visitors_with_error_codes(self.config.select)
-        visitors = GET_ALL_VISITORS()
 
         violations = []
         needed_imports = []
         wrapped_module.visit_batched(visitors)
@@ -108,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
         replacement_map = {}
         assert self.context.filename is not None
         for violation in violations:
-            skip_violation = False
-            if self.config is None or self.config.select != "ALL":
-                for disabled_code in DISABLED_BY_DEFAULT:
-                    if violation.error_code.startswith(disabled_code):
-                        skip_violation = True
-                        break
+            # Still need to skip violations here, since a single visitor can
+            # correspond to multiple different types of violations.
+            skip_violation = True
+            for code in self.config.select:
+                if violation.error_code.startswith(code):
+                    skip_violation = False
+                    break
             if skip_violation:
                 continue
 
@@ -134,10 +223,7 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
             )
             new_module = new_module.visit(add_imports_visitor)
 
-        update_functorch_imports_visitor = _UpdateFunctorchImports()
-        new_module = new_module.visit(update_functorch_imports_visitor)
-
-        if fixes_count == 0 and not update_functorch_imports_visitor.changed:
+        if fixes_count == 0:
             raise codemod.SkipFile("No changes")
 
         return new_module
diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py
new file mode 100644
index 0000000..45f2438
--- /dev/null
+++ b/torchfix/visitors/__init__.py
@@ -0,0 +1,37 @@
+from .deprecated_symbols import TorchDeprecatedSymbolsVisitor
+from .internal import TorchScopedLibraryVisitor
+from .misc import (
+    TorchExpm1Visitor,
+    TorchLog1pVisitor,
+    TorchLogsumexpVisitor,
+    TorchReentrantCheckpointVisitor,
+    TorchRequireGradVisitor,
+)
+from .nonpublic import TorchNonPublicAliasVisitor
+from .performance import (
+    TorchSynchronizedDataLoaderVisitor,
+    TorchGradNotSetToNonePatternVisitor,
+)
+from .security import TorchUnsafeLoadVisitor
+from .vision import (
+    TorchVisionDeprecatedPretrainedVisitor,
+    TorchVisionDeprecatedToTensorVisitor,
+    TorchVisionSingletonImportVisitor,
+)
+
+__all__ = [
+    "TorchDeprecatedSymbolsVisitor",
+    "TorchExpm1Visitor",
+    "TorchLog1pVisitor",
+    "TorchLogsumexpVisitor",
+    "TorchNonPublicAliasVisitor",
+    "TorchReentrantCheckpointVisitor",
+    "TorchRequireGradVisitor",
+    "TorchScopedLibraryVisitor",
+    "TorchSynchronizedDataLoaderVisitor",
+    "TorchUnsafeLoadVisitor",
+    "TorchVisionDeprecatedPretrainedVisitor",
+    "TorchVisionDeprecatedToTensorVisitor",
+    "TorchVisionSingletonImportVisitor",
+    "TorchGradNotSetToNonePatternVisitor",
+]
diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py
index fed7032..40885ee 100644
--- a/torchfix/visitors/deprecated_symbols/__init__.py
+++ b/torchfix/visitors/deprecated_symbols/__init__.py
@@ -1,34 +1,48 @@
+import pkgutil
+from typing import List, Optional
+
 import libcst as cst
 import yaml
 
-from typing import Optional
-from collections.abc import Sequence
 
 from ...common import (
-    TorchVisitor,
     call_with_name_changes,
-    LintViolation,
+    check_old_names_in_import_from,
+    TorchError,
+    TorchVisitor,
 )
 
-from .range import call_replacement_range
-from .cholesky import call_replacement_cholesky
+from .amp import call_replacement_cpu_amp_autocast, call_replacement_cuda_amp_autocast
 from .chain_matmul import call_replacement_chain_matmul
+from .cholesky import call_replacement_cholesky
 from .qr import call_replacement_qr
+from .range import call_replacement_range
+
 
 class TorchDeprecatedSymbolsVisitor(TorchVisitor):
-    ERROR_CODE = ["TOR001", "TOR101"]
+    ERRORS: List[TorchError] = [
+        TorchError("TOR001", "Use of removed function {old_name}"),
+        TorchError("TOR101", "Use of deprecated function {old_name}"),
+        TorchError("TOR004", "Import of removed function {old_name}"),
+        TorchError("TOR103", "Import of deprecated function {old_name}"),
+    ]
 
     def __init__(self, deprecated_config_path=None):
         def read_deprecated_config(path=None):
             deprecated_config = {}
             if path is not None:
-                with open(path) as f:
-                    for item in yaml.load(f, yaml.SafeLoader):
-                        deprecated_config[item["name"]] = item
+                data = pkgutil.get_data("torchfix", path)
+                assert data is not None
+                for item in yaml.load(data, yaml.SafeLoader):
+                    deprecated_config[item["name"]] = item
             return deprecated_config
 
         super().__init__()
         self.deprecated_config = read_deprecated_config(deprecated_config_path)
+        self.old_new_name_map = {
+            name: self.deprecated_config[name].get("replacement")
+            for name in self.deprecated_config
+        }
 
     def _call_replacement(
         self, node: cst.Call, qualified_name: str
@@ -38,80 +52,71 @@ def _call_replacement(
             "torch.range": call_replacement_range,
             "torch.chain_matmul": call_replacement_chain_matmul,
             "torch.qr": call_replacement_qr,
+            "torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
+            "torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
         }
 
         replacement = None
         if qualified_name in replacements_map:
-            replacement = replacements_map[qualified_name](node)
-        else:
-            # Replace names for functions that have drop-in replacement.
-            function_name_replacement = self.deprecated_config.get(
-                qualified_name, {}
-            ).get("replacement", "")
-            if function_name_replacement:
-                replacement = call_with_name_changes(
-                    node, qualified_name, function_name_replacement
-                )
-
+            return replacements_map[qualified_name](node)
+
+        # Replace names for functions that have drop-in replacement.
+        function_name_replacement = self.deprecated_config.get(qualified_name, {}).get(
+            "replacement", ""
+        )
+        if function_name_replacement:
+            replacement_and_imports = call_with_name_changes(
+                node, qualified_name, function_name_replacement
+            )
+            if replacement_and_imports is not None:
+                replacement, imports = replacement_and_imports
+                self.needed_imports.update(imports)
         return replacement
 
-    def visit_Call(self, node):
+    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
+        if node.module is None:
+            return
+
+        old_names, replacement = check_old_names_in_import_from(
+            node, self.old_new_name_map
+        )
+        for qualified_name in old_names:
+            if self.deprecated_config[qualified_name]["remove_pr"] is None:
+                error_code = self.ERRORS[3].error_code
+                message = self.ERRORS[3].message(old_name=qualified_name)
+            else:
+                error_code = self.ERRORS[2].error_code
+                message = self.ERRORS[2].message(old_name=qualified_name)
+
+            reference = self.deprecated_config[qualified_name].get("reference")
+            if reference is not None:
+                message = f"{message}: {reference}"
+
+            self.add_violation(
+                node,
+                error_code=error_code,
+                message=message,
+                replacement=replacement,
+            )
+
+    def visit_Call(self, node) -> None:
         qualified_name = self.get_qualified_name_for_call(node)
         if qualified_name is None:
             return
 
         if qualified_name in self.deprecated_config:
-            position_metadata = self.get_metadata(
-                cst.metadata.WhitespaceInclusivePositionProvider, node
-            )
             if self.deprecated_config[qualified_name]["remove_pr"] is None:
-                error_code = self.ERROR_CODE[1]
-                message = f"Use of deprecated function {qualified_name}"
+                error_code = self.ERRORS[1].error_code
+                message = self.ERRORS[1].message(old_name=qualified_name)
             else:
-                error_code = self.ERROR_CODE[0]
-                message = f"Use of removed function {qualified_name}"
+                error_code = self.ERRORS[0].error_code
+                message = self.ERRORS[0].message(old_name=qualified_name)
 
             replacement = self._call_replacement(node, qualified_name)
 
             reference = self.deprecated_config[qualified_name].get("reference")
             if reference is not None:
                 message = f"{message}: {reference}"
 
-            self.violations.append(
-                LintViolation(
-                    error_code=error_code,
-                    message=message,
-                    line=position_metadata.start.line,
-                    column=position_metadata.start.column,
-                    node=node,
-                    replacement=replacement,
-                )
+            self.add_violation(
+                node, error_code=error_code, message=message, replacement=replacement
             )
-
-
-# TODO: refactor/generalize this.
-class _UpdateFunctorchImports(cst.CSTTransformer):
-    REPLACEMENTS = {
-        "vmap",
-        "grad",
-        "vjp",
-        "jvp",
-        "jacrev",
-        "jacfwd",
-        "hessian",
-        "functionalize",
-    }
-
-    def __init__(self):
-        self.changed = False
-
-    def leave_ImportFrom(
-        self, node: cst.ImportFrom, updated_node: cst.ImportFrom
-    ) -> cst.ImportFrom:
-        if (
-            getattr(node.module, "value", None) == "functorch"
-            and isinstance(node.names, Sequence)
-            and all(name.name.value in self.REPLACEMENTS for name in node.names)
-        ):
-            self.changed = True
-            return updated_node.with_changes(module=cst.parse_expression("torch.func"))
-        return updated_node
diff --git a/torchfix/visitors/deprecated_symbols/amp.py b/torchfix/visitors/deprecated_symbols/amp.py
new file mode 100644
index 0000000..9aa87c7
--- /dev/null
+++ b/torchfix/visitors/deprecated_symbols/amp.py
@@ -0,0 +1,26 @@
+import libcst as cst
+
+from ...common import get_module_name
+
+
+def call_replacement_cpu_amp_autocast(node: cst.Call) -> cst.CSTNode:
+    return _call_replacement_amp(node, "cpu")
+
+
+def call_replacement_cuda_amp_autocast(node: cst.Call) -> cst.CSTNode:
+    return _call_replacement_amp(node, "cuda")
+
+
+def _call_replacement_amp(node: cst.Call, device: str) -> cst.CSTNode:
+    """
+    Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and
+    `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`.
+    """
+    device_arg = cst.ensure_type(cst.parse_expression(f'f("{device}")'), cst.Call).args[
+        0
+    ]
+
+    module_name = get_module_name(node, "torch")
+    replacement = cst.parse_expression(f"{module_name}.amp.autocast(args)")
+    replacement = replacement.with_changes(args=(device_arg, *node.args))
+    return replacement
diff --git a/torchfix/visitors/deprecated_symbols/chain_matmul.py b/torchfix/visitors/deprecated_symbols/chain_matmul.py
index 3eab730..1e3873f 100644
--- a/torchfix/visitors/deprecated_symbols/chain_matmul.py
+++ b/torchfix/visitors/deprecated_symbols/chain_matmul.py
@@ -7,21 +7,17 @@ def call_replacement_chain_matmul(node: cst.Call) -> cst.CSTNode:
     Replace `torch.chain_matmul` with `torch.linalg.multi_dot`,
     changing multiple parameters to a list.
""" - matrices = [] + matrices = [ + cst.Element(value=arg.value) for arg in node.args if arg.keyword is None + ] + matrices_arg = cst.Arg(value=cst.List(elements=matrices)) + out_arg = None for arg in node.args: - if arg.keyword is None: - matrices.append(cst.Element(value=arg.value)) - elif arg.keyword.value == "out": + if arg.keyword is not None and arg.keyword.value == "out": out_arg = arg - matrices_arg = cst.Arg(value=cst.List(elements=matrices)) - if out_arg is None: - replacement_args = [matrices_arg] - else: - replacement_args = [matrices_arg, out_arg] - module_name = get_module_name(node, 'torch') + replacement_args = [matrices_arg] if out_arg is None else [matrices_arg, out_arg] + module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.multi_dot(args)") - replacement = replacement.with_changes(args=replacement_args) - - return replacement + return replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/cholesky.py b/torchfix/visitors/deprecated_symbols/cholesky.py index cec5e71..c80bf7e 100644 --- a/torchfix/visitors/deprecated_symbols/cholesky.py +++ b/torchfix/visitors/deprecated_symbols/cholesky.py @@ -1,5 +1,5 @@ import libcst as cst -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode: @@ -19,9 +19,14 @@ def call_replacement_cholesky(node: cst.Call) -> cst.CSTNode: and cst.ensure_type(upper_arg.value, cst.Name).value == "True" ): replacement = cst.parse_expression(f"{module_name}.linalg.cholesky(A).mH") + + # Make mypy happy + assert isinstance(replacement, (cst.Name, cst.Attribute)) + + old_node = cst.ensure_type(replacement.value, cst.Call).args replacement = replacement.with_deep_changes( - # Ignore type error, see https://github.com/Instagram/LibCST/issues/963 - old_node=cst.ensure_type(replacement.value, cst.Call).args, # type: ignore + # see https://github.com/Instagram/LibCST/issues/963 + old_node=old_node, # type: ignore[arg-type] value=[input_arg], ) else: diff --git a/torchfix/visitors/deprecated_symbols/qr.py b/torchfix/visitors/deprecated_symbols/qr.py index f1d96df..ab81fcf 100644 --- a/torchfix/visitors/deprecated_symbols/qr.py +++ b/torchfix/visitors/deprecated_symbols/qr.py @@ -1,6 +1,6 @@ import libcst as cst from typing import Optional -from ...common import (TorchVisitor, get_module_name) +from ...common import TorchVisitor, get_module_name def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]: @@ -29,6 +29,4 @@ def call_replacement_qr(node: cst.Call) -> Optional[cst.CSTNode]: replacement_args = [input_arg] module_name = get_module_name(node, "torch") replacement = cst.parse_expression(f"{module_name}.linalg.qr(args)") - replacement = replacement.with_changes(args=replacement_args) - - return replacement + return replacement.with_changes(args=replacement_args) diff --git a/torchfix/visitors/deprecated_symbols/range.py b/torchfix/visitors/deprecated_symbols/range.py index 97fec69..45e0580 100644 --- a/torchfix/visitors/deprecated_symbols/range.py +++ b/torchfix/visitors/deprecated_symbols/range.py @@ -7,6 +7,7 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]: """Replace `range` with `arange`. Add `step` to the `end` argument as `arange` has the interval `[start, end)`. """ + # `torch.range` documented signature is not a valid Python signature, # so it's hard to generalize this. 
     def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]:
@@ -17,11 +18,10 @@ def _get_range_args(node: cst.Call) -> Tuple[cst.Arg, Optional[cst.Arg]]:
         for arg in node.args:
             if arg.keyword is None:
                 non_kw_args.append(arg)
-            else:
-                if arg.keyword.value == "end":
-                    end_arg = arg
-                elif arg.keyword.value == "step":
-                    step_arg = arg
+            elif arg.keyword.value == "end":
+                end_arg = arg
+            elif arg.keyword.value == "step":
+                step_arg = arg
         if end_arg is None:
             if len(non_kw_args) == 1:
                 end_arg = non_kw_args[0]
@@ -45,16 +45,15 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
             step_arg,
             m.Arg(value=m.UnaryOperation(operator=m.Minus(), expression=m.Integer())),
         ):
-            # Ignore type error here and further in this file.
-            # See https://github.com/Instagram/LibCST/issues/964
-            step = -int(step_arg.value.expression.value)  # type: ignore
+            # make mypy happy
+            assert isinstance(step_arg.value, cst.UnaryOperation)
+            assert isinstance(step_arg.value.expression, cst.Integer)
+            step = -int(step_arg.value.expression.value)
         # Bail out, don't know how to update with non-integer `step`.
         else:
            return None
 
-    updated_end_arg = None
-
     # `end` is a literal (positive) integer
     if isinstance(end_arg.value, cst.Integer):
         end = int(end_arg.value.value) + step
@@ -76,10 +75,15 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
         end_arg,
         m.Arg(value=m.UnaryOperation(operator=m.Minus(), expression=m.Integer())),
     ):
-        end = -int(end_arg.value.expression.value) + step  # type: ignore
+        op = end_arg.value
+        # make mypy happy
+        assert isinstance(op, cst.UnaryOperation)
+        assert isinstance(op.expression, cst.Integer)
+        end = -int(op.expression.value) + step
         if end < 0:
             updated_end_arg = end_arg.with_deep_changes(
-                old_node=end_arg.value.expression, value=str(-end)  # type: ignore
+                old_node=op.expression,
+                value=str(-end),
             )
         else:
             # `end` became non-negative
@@ -93,7 +97,9 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
             value=m.BinaryOperation(operator=m.Subtract(), right=m.Integer(value="1"))
         ),
     ):
-        updated_end_arg = end_arg.with_changes(value=end_arg.value.left)  # type: ignore
+        # make mypy happy
+        assert isinstance(end_arg.value, cst.BinaryOperation)
+        updated_end_arg = end_arg.with_changes(value=end_arg.value.left)
 
     # `end` something else: add `+ 1` at the end
     else:
@@ -105,12 +111,14 @@ def call_replacement_range(node: cst.Call) -> Optional[cst.Call]:
             )
         )
 
-    replacement = node
     if updated_end_arg is not None:
-        # Ignore type error, see https://github.com/Instagram/LibCST/issues/965
-        replacement = replacement.deep_replace(end_arg, updated_end_arg)  # type: ignore
-    replacement = replacement.with_deep_changes(
+        replacement = node.deep_replace(end_arg, updated_end_arg)
+
+        # make mypy happy
+        assert isinstance(replacement, cst.Call)
+    else:
+        replacement = node
+
+    return replacement.with_deep_changes(
         old_node=cst.ensure_type(replacement.func, cst.Attribute).attr, value="arange"
     )
-
-    return replacement
diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py
new file mode 100644
index 0000000..908527a
--- /dev/null
+++ b/torchfix/visitors/internal/__init__.py
@@ -0,0 +1,29 @@
+from ...common import TorchError, TorchVisitor
+
+
+class TorchScopedLibraryVisitor(TorchVisitor):
+    """
+    Suggest `torch.library._scoped_library` for PyTorch tests.
+ """ + + ERRORS = [ + TorchError( + "TOR901", + ( + "Use `torch.library._scoped_library` " + "instead of `torch.library.Library` " + "in PyTorch tests files. " + "See https://github.com/pytorch/pytorch/pull/118318 " + "for details." + ), + ) + ] + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name == "torch.library.Library": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + ) diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ef83d91..8f0c70c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -1,8 +1,7 @@ import libcst as cst import libcst.matchers as m - -from ...common import TorchVisitor, LintViolation +from ...common import TorchError, TorchVisitor class TorchRequireGradVisitor(TorchVisitor): @@ -10,8 +9,12 @@ class TorchRequireGradVisitor(TorchVisitor): Find and fix common misspelling `require_grad` (instead of `requires_grad`). """ - ERROR_CODE = "TOR002" - MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?" + ERRORS = [ + TorchError( + "TOR002", + "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?", + ) + ] def visit_Assign(self, node): # Look for any assignment with `require_grad` attribute on the left. @@ -31,18 +34,170 @@ def visit_Assign(self, node): replacement = node.with_deep_changes( old_node=node.targets[0].target.attr, value="requires_grad" ) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=replacement, + ) + - position_metadata = self.get_metadata( - cst.metadata.WhitespaceInclusivePositionProvider, node +class TorchReentrantCheckpointVisitor(TorchVisitor): + """ + Find and fix common misuse of reentrant checkpoints. + """ + + ERRORS = [ + TorchError( + "TOR003", + ( + "Please pass `use_reentrant` explicitly to `checkpoint`. " + "To maintain old behavior, pass `use_reentrant=True`. " + "It is recommended to use `use_reentrant=False`." + ), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call( + node + ) == "torch.utils.checkpoint.checkpoint" and not self.has_specific_arg( + node, "use_reentrant" + ): + # This codemod maybe unsafe correctness-wise + # if reentrant behavior is actually needed, + # so the changes need to be verified/tested. + use_reentrant_arg = cst.ensure_type( + cst.parse_expression("f(use_reentrant=False)"), cst.Call + ).args[0] + replacement = node.with_changes(args=(*node.args, use_reentrant_arg)) + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=replacement, ) - self.violations.append( - LintViolation( - error_code=self.ERROR_CODE, - message=self.MESSAGE, - line=position_metadata.start.line, - column=position_metadata.start.column, - node=node, - replacement=replacement, + +class TorchLog1pVisitor(TorchVisitor): + """ + Suggest using `torch.log1p(x)` instead of `torch.log(1 + x)`. + """ + + ERRORS = [ + TorchError( + "TOR106", + ( + "Use `torch.log1p(x)` instead of `torch.log(1 + x)`. " + "It is more accurate for small values of `x`." 
+ ), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + if m.matches( + node, + m.Call( + args=[ + m.Arg( + value=m.BinaryOperation( + left=m.Integer(value="1") | m.Float(value="1.0"), + operator=m.Add(), + ) + | m.BinaryOperation( + operator=m.Add(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ), + ], + ), + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, ) - ) + + +class TorchExpm1Visitor(TorchVisitor): + """ + Suggest using `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. + """ + + ERRORS = [ + TorchError( + "TOR107", + ( + "Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. " + "It is more accurate for small values of `x`." + ), + ) + ] + + def visit_BinaryOperation(self, node): + if m.matches( + node, + m.BinaryOperation( + left=m.Call(), + operator=m.Subtract(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ): + if self.get_qualified_name_for_call(node.left) == "torch.exp": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) + + +class TorchLogsumexpVisitor(TorchVisitor): + """ + Suggest using `torch.logsumexp(x)` instead of `torch.log(torch.sum(torch.exp(x))`. + """ + + ERRORS = [ + TorchError( + "TOR108", + ("Use numerically stabilized `torch.logsumexp`."), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + if m.matches( + node, + m.Call( + args=[ + m.Arg(m.Call(args=[m.Arg(m.Call()), m.ZeroOrMore()])), + m.ZeroOrMore(), + ] + ), + ): + if self.get_qualified_name_for_call(node.args[0].value) == "torch.sum": + if ( + self.get_qualified_name_for_call( + node.args[0].value.args[0].value + ) + == "torch.exp" + ): + + # if `dim` is not provided or None for sum, skip: + # https://github.com/pytorch/pytorch/issues/144339 + dim_arg = self.get_specific_arg( + node.args[0].value, arg_name="dim", arg_pos=1 + ) + if dim_arg is not None: + if not ( + isinstance(dim_arg.value, cst.Name) + and dim_arg.value.value == "None" + ): + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + ) diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py new file mode 100644 index 0000000..281c0f3 --- /dev/null +++ b/torchfix/visitors/nonpublic/__init__.py @@ -0,0 +1,88 @@ +from typing import List + +import libcst as cst + +from ...common import ( + TorchError, + TorchVisitor, + call_with_name_changes, + check_old_names_in_import_from, +) + + +class TorchNonPublicAliasVisitor(TorchVisitor): + """ + Suggest to use public APIs instead of non-public aliases. 
+
+    Currently implemented for
+    torch.utils.data._utils.collate.default_collate and
+    torch.utils.data._utils.collate.default_convert,
+    see https://github.com/pytorch/pytorch/pull/69862/files
+    """
+
+    ERRORS: List[TorchError] = [
+        TorchError(
+            "TOR104",
+            (
+                "Use of non-public function `{private_name}`, "
+                "please use `{public_name}` instead"
+            ),
+        ),
+        TorchError(
+            "TOR105",
+            (
+                "Import of non-public function `{private_name}`, "
+                "please use `{public_name}` instead"
+            ),
+        ),
+    ]
+
+    # fmt: off
+    ALIASES = {
+        "torch.utils.data._utils.collate.default_collate": "torch.utils.data.dataloader.default_collate",  # noqa: E501
+        "torch.utils.data._utils.collate.default_convert": "torch.utils.data.dataloader.default_convert",  # noqa: E501
+    }
+    # fmt: on
+
+    def visit_Call(self, node):
+        qualified_name = self.get_qualified_name_for_call(node)
+        if qualified_name is None:
+            return
+
+        if qualified_name in self.ALIASES:
+            public_name = self.ALIASES[qualified_name]
+            error_code = self.ERRORS[0].error_code
+            message = self.ERRORS[0].message(
+                private_name=qualified_name, public_name=public_name
+            )
+
+            replacement_and_imports = call_with_name_changes(
+                node, qualified_name, public_name
+            )
+            if replacement_and_imports is not None:
+                replacement, imports = replacement_and_imports
+                self.needed_imports.update(imports)
+            else:
+                replacement = None
+
+            self.add_violation(
+                node, error_code=error_code, message=message, replacement=replacement
+            )
+
+    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
+        if node.module is None:
+            return
+
+        private_names, replacement = check_old_names_in_import_from(node, self.ALIASES)
+        for qualified_name in private_names:
+            public_name = self.ALIASES[qualified_name]
+            error_code = self.ERRORS[1].error_code
+            message = self.ERRORS[1].message(
+                private_name=qualified_name, public_name=public_name
+            )
+            self.add_violation(
+                node,
+                error_code=error_code,
+                message=message,
+                replacement=replacement,
+            )
diff --git a/torchfix/visitors/performance/__init__.py b/torchfix/visitors/performance/__init__.py
index f838fbe..0558af5 100644
--- a/torchfix/visitors/performance/__init__.py
+++ b/torchfix/visitors/performance/__init__.py
@@ -1,8 +1,6 @@
-import libcst as cst
 import libcst.matchers as m
-
-from ...common import TorchVisitor, LintViolation
+from ...common import TorchError, TorchVisitor
 
 
 class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
@@ -11,12 +9,16 @@ class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
     https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
     """
 
-    ERROR_CODE = "TOR401"
-    MESSAGE = (
-        "Detected DataLoader running with synchronized implementation. "
-        "Please enable asynchronous dataloading by setting num_workers > 0 when "
-        "initializing DataLoader."
-    )
+    ERRORS = [
+        TorchError(
+            "TOR401",
+            (
+                "Detected DataLoader running with synchronized implementation."
+                " Please enable asynchronous dataloading by setting "
+                "num_workers > 0 when initializing DataLoader."
+            ),
+        )
+    ]
 
     def visit_Call(self, node):
         qualified_name = self.get_qualified_name_for_call(node)
@@ -25,17 +27,41 @@ def visit_Call(self, node):
             if num_workers_arg is None or m.matches(
                 num_workers_arg.value, m.Integer(value="0")
             ):
-                position_metadata = self.get_metadata(
-                    cst.metadata.WhitespaceInclusivePositionProvider, node
+                self.add_violation(
+                    node,
+                    error_code=self.ERRORS[0].error_code,
+                    message=self.ERRORS[0].message(),
                 )
-                self.violations.append(
-                    LintViolation(
-                        error_code=self.ERROR_CODE,
-                        message=self.MESSAGE,
-                        line=position_metadata.start.line,
-                        column=position_metadata.start.column,
-                        node=node,
-                        replacement=None,
+
+
+class TorchGradNotSetToNonePatternVisitor(TorchVisitor):
+    """
+    Reimplementation of GradNotSetToNonePattern from
+    https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
+    """
+
+    ERRORS = [
+        TorchError(
+            "TOR402",
+            (
+                "Detected gradient set to zero instead of None. "
+                "Please add 'set_to_none=True' when calling zero_grad()."
+            ),
+        )
+    ]
+
+    def visit_Call(self, node):
+        qualified_name = self.get_qualified_name_for_call(node)
+
+        if qualified_name and qualified_name.endswith("zero_grad"):
+
+            set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0)
+
+            # hasattr check to handle mypy error
+            if set_to_none_arg and hasattr(set_to_none_arg.value, "value"):
+                if set_to_none_arg.value.value == "False":
+                    self.add_violation(
+                        node,
+                        error_code=self.ERRORS[0].error_code,
+                        message=self.ERRORS[0].message(),
                     )
-                )
diff --git a/torchfix/visitors/security/__init__.py b/torchfix/visitors/security/__init__.py
index 010c5f4..d1a9380 100644
--- a/torchfix/visitors/security/__init__.py
+++ b/torchfix/visitors/security/__init__.py
@@ -1,5 +1,6 @@
 import libcst as cst
-from ...common import TorchVisitor, LintViolation
+
+from ...common import TorchError, TorchVisitor
 
 
 class TorchUnsafeLoadVisitor(TorchVisitor):
@@ -8,48 +9,40 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
     See https://github.com/pytorch/pytorch/issues/31875.
     """
 
-    ERROR_CODE = "TOR102"
-    MESSAGE = (
-        "`torch.load` without `weights_only` parameter is unsafe. "
-        "Explicitly set `weights_only` to False only if you trust the data you load "
-        "and full pickle functionality is needed, otherwise set "
-        "`weights_only=True`."
-    )
+    ERRORS = [
+        TorchError(
+            "TOR102",
+            (
+                "`torch.load` without `weights_only` parameter is unsafe. "
+                "Explicitly set `weights_only` to False only if you trust "
+                "the data you load "
+                "and full pickle functionality is needed,"
+                " otherwise set `weights_only=True`."
+            ),
+        )
+    ]
 
     def visit_Call(self, node):
-        qualified_name = self.get_qualified_name_for_call(node)
-        if qualified_name == "torch.load":
-            weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
-            if weights_only_arg is None:
-                position_metadata = self.get_metadata(
-                    cst.metadata.WhitespaceInclusivePositionProvider, node
-                )
-
-                # Add `weights_only=True` if there is no `pickle_module`.
-                # (do not add `weights_only=False` with `pickle_module`, as it
-                # needs to be an explicit choice).
-                #
-                # This codemod is somewhat unsafe correctness-wise
-                # because full pickling functionality may still be needed
-                # even without `pickle_module`,
-                # so the changes need to be verified/tested.
-                replacement = None
-                pickle_module_arg = self.get_specific_arg(node, "pickle_module", 2)
-                if pickle_module_arg is None:
-                    weights_only_arg = cst.ensure_type(
-                        cst.parse_expression("f(weights_only=True)"), cst.Call
-                    ).args[0]
-                    replacement = node.with_changes(
-                        args=node.args + (weights_only_arg,)
-                    )
-
-                self.violations.append(
-                    LintViolation(
-                        error_code=self.ERROR_CODE,
-                        message=self.MESSAGE,
-                        line=position_metadata.start.line,
-                        column=position_metadata.start.column,
-                        node=node,
-                        replacement=replacement,
-                    )
-                )
+        if self.get_qualified_name_for_call(
+            node
+        ) == "torch.load" and not self.has_specific_arg(node, "weights_only"):
+            # Add `weights_only=True` if there is no `pickle_module`.
+            # (do not add `weights_only=False` with `pickle_module`, as it
+            # needs to be an explicit choice).
+            #
+            # This codemod is somewhat unsafe correctness-wise
+            # because full pickling functionality may still be needed
+            # even without `pickle_module`,
+            # so the changes need to be verified/tested.
+            replacement = None
+            if not self.has_specific_arg(node, "pickle_module", 2):
+                weights_only_arg = cst.ensure_type(
+                    cst.parse_expression("f(weights_only=True)"), cst.Call
+                ).args[0]
+                replacement = node.with_changes(args=(*node.args, weights_only_arg))
+            self.add_violation(
+                node,
+                error_code=self.ERRORS[0].error_code,
+                message=self.ERRORS[0].message(),
+                replacement=replacement,
+            )
diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py
index 7adcc19..3a1745f 100644
--- a/torchfix/visitors/vision/__init__.py
+++ b/torchfix/visitors/vision/__init__.py
@@ -1,2 +1,9 @@
-from .pretrained import TorchVisionDeprecatedPretrainedVisitor  # noqa: F401
-from .to_tensor import TorchVisionDeprecatedToTensorVisitor  # noqa: F401
+from .pretrained import TorchVisionDeprecatedPretrainedVisitor
+from .singleton_import import TorchVisionSingletonImportVisitor
+from .to_tensor import TorchVisionDeprecatedToTensorVisitor
+
+__all__ = [
+    "TorchVisionDeprecatedPretrainedVisitor",
+    "TorchVisionDeprecatedToTensorVisitor",
+    "TorchVisionSingletonImportVisitor",
+]
diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py
index 6e17048..acbe564 100644
--- a/torchfix/visitors/vision/pretrained.py
+++ b/torchfix/visitors/vision/pretrained.py
@@ -3,7 +3,7 @@
 import libcst as cst
 from libcst.codemod.visitors import ImportItem
 
-from ...common import LintViolation, TorchVisitor
+from ...common import TorchError, TorchVisitor
 
 
 class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
@@ -16,7 +16,12 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
     otherwise only lint violation is emitted.
     """
 
-    ERROR_CODE = "TOR201"
+    ERRORS = [
+        TorchError(
+            "TOR201",
+            "Parameter `{old_arg_name}` is deprecated, please use `{new_arg_name}` instead.",
+        )
+    ]
 
     # flake8: noqa: E105
     # fmt: off
@@ -172,7 +177,7 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):
     def visit_Call(self, node):
         def _new_arg_and_import(
-            old_arg: cst.Arg, is_backbone: bool
+            old_arg: Optional[cst.Arg], is_backbone: bool
         ) -> Optional[cst.Arg]:
             old_arg_name = "pretrained_backbone" if is_backbone else "pretrained"
             if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS:
@@ -215,13 +220,17 @@ def _new_arg_and_import(
         message = None
         pretrained_arg = self.get_specific_arg(node, "pretrained", 0)
         if pretrained_arg is not None:
-            message = "Parameter `pretrained` is deprecated, please use `weights` instead."
+            message = self.ERRORS[0].message(
+                old_arg_name="pretrained", new_arg_name="weights"
+            )
 
         pretrained_backbone_arg = self.get_specific_arg(
             node, "pretrained_backbone", 1
         )
         if pretrained_backbone_arg is not None:
-            message = "Parameter `pretrained_backbone` is deprecated, please use `weights_backbone` instead."
+            message = self.ERRORS[0].message(
+                old_arg_name="pretrained_backbone", new_arg_name="weights_backbone"
+            )
 
         replacement_args = list(node.args)
@@ -248,16 +257,9 @@ def _new_arg_and_import(
             node.with_changes(args=replacement_args) if has_replacement else None
         )
         if message is not None:
-            position_metadata = self.get_metadata(
-                cst.metadata.WhitespaceInclusivePositionProvider, node
-            )
-            self.violations.append(
-                LintViolation(
-                    error_code=self.ERROR_CODE,
-                    message=message,
-                    line=position_metadata.start.line,
-                    column=position_metadata.start.column,
-                    node=node,
-                    replacement=replacement,
-                )
+            self.add_violation(
+                node,
+                error_code=self.ERRORS[0].error_code,
+                message=message,
+                replacement=replacement,
             )
diff --git a/torchfix/visitors/vision/singleton_import.py b/torchfix/visitors/vision/singleton_import.py
new file mode 100644
index 0000000..f2b207b
--- /dev/null
+++ b/torchfix/visitors/vision/singleton_import.py
@@ -0,0 +1,46 @@
+import libcst as cst
+import libcst.matchers as m
+
+from ...common import TorchError, TorchVisitor
+
+
+class TorchVisionSingletonImportVisitor(TorchVisitor):
+    ERRORS = [
+        TorchError(
+            "TOR203",
+            (
+                "Consider replacing 'import torchvision.{module} as {module}' "
+                "with 'from torchvision import {module}'."
+            ),
+        ),
+    ]
+
+    # Keep attr order in sync with ERRORS.
+    REPLACEABLE_ATTRS = ["datasets", "models", "transforms"]
+
+    def visit_Import(self, node: cst.Import) -> None:
+        replacement = None
+        for i, import_attr in enumerate(self.REPLACEABLE_ATTRS):
+            for imported_item in node.names:
+                if m.matches(
+                    imported_item,
+                    m.ImportAlias(
+                        name=m.Attribute(
+                            value=m.Name("torchvision"), attr=m.Name(import_attr)
+                        ),
+                        asname=m.AsName(name=m.Name(import_attr)),
+                    ),
+                ):
+                    # Replace only if the import statement has no other names
+                    if len(node.names) == 1:
+                        replacement = cst.ImportFrom(
+                            module=cst.Name("torchvision"),
+                            names=[cst.ImportAlias(name=cst.Name(import_attr))],
+                        )
+                    self.add_violation(
+                        node,
+                        error_code=self.ERRORS[0].error_code,
+                        message=self.ERRORS[0].message(module=import_attr),
+                        replacement=replacement,
+                    )
+                    break
diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py
index 6886c41..02a5915 100644
--- a/torchfix/visitors/vision/to_tensor.py
+++ b/torchfix/visitors/vision/to_tensor.py
@@ -1,31 +1,24 @@
 from collections.abc import Sequence
+
 import libcst as cst
 
-from ...common import LintViolation, TorchVisitor
+from ...common import TorchError, TorchVisitor
+
+MESSAGE = (
+    "The transform `v2.ToTensor()` is deprecated and will be removed "
+    "in a future release. Instead, please use "
+    "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."  # noqa: E501
+)
 
 
 class TorchVisionDeprecatedToTensorVisitor(TorchVisitor):
-    ERROR_CODE = "TOR202"
+    ERRORS = [TorchError("TOR202", MESSAGE)]
 
     def _maybe_add_violation(self, qualified_name, node):
         if qualified_name != "torchvision.transforms.v2.ToTensor":
             return
-        position = self.get_metadata(
-            cst.metadata.WhitespaceInclusivePositionProvider, node
-        )
-        self.violations.append(
-            LintViolation(
-                error_code=self.ERROR_CODE,
-                message=(
-                    "The transform `v2.ToTensor()` is deprecated and will be removed "
-                    "in a future release. Instead, please use "
-                    "`v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."  # noqa: E501
-                ),
-                line=position.start.line,
-                column=position.start.column,
-                node=node,
-                replacement=None,
-            )
+        self.add_violation(
+            node, error_code=self.ERRORS[0].error_code, message=self.ERRORS[0].message()
         )
 
     def visit_ImportFrom(self, node):
@@ -43,7 +36,7 @@ def visit_ImportFrom(self, node):
 
     def visit_Attribute(self, node):
         qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node)
-        if not len(qualified_names) == 1:
+        if len(qualified_names) != 1:
             return
-        self._maybe_add_violation(qualified_names.pop().name, node)
+        self._maybe_add_violation(list(qualified_names)[0].name, node)
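
For context on the new TOR106/TOR107/TOR108 checks added in `misc/__init__.py`: the suggested functions are the numerically stable forms of the flagged expressions. A minimal sketch of the difference (illustrative values only, not part of this patch):

```python
import torch

x = torch.tensor([1e-8])

# In float32, 1 + 1e-8 rounds to exactly 1.0, so torch.log(1 + x)
# collapses to 0 while torch.log1p keeps the small value (TOR106).
print(torch.log(1 + x))   # tensor([0.])
print(torch.log1p(x))     # tensor([1.0000e-08])

# The same cancellation affects torch.exp(x) - 1 vs. torch.special.expm1 (TOR107).
print(torch.exp(x) - 1)        # tensor([0.])
print(torch.special.expm1(x))  # tensor([1.0000e-08])

# torch.logsumexp shifts by the max before exponentiating, so it does not
# overflow where the naive log(sum(exp(...))) does (TOR108).
y = torch.tensor([1000.0, 1000.0])
print(torch.log(torch.sum(torch.exp(y), dim=0)))  # tensor(inf)
print(torch.logsumexp(y, dim=0))                  # tensor(1000.6931)
```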
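The new `amp.py` replacement corresponds to this user-facing migration. A hedged before/after sketch (the `dtype` argument is just an example, and the CPU variant is shown so the snippet runs on any build):

```python
import torch

# Before: device-specific context manager, deprecated (flagged by TOR101).
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    pass

# After: the unified torch.amp API takes the device type as the first
# positional argument; the autofix prepends it and keeps the other arguments.
with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    pass
```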
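Finally, the `--select` handling in `torchfix.py` resolves codes by prefix expansion against the registered visitors. A standalone sketch of the same idea (the `ALL_CODES` list and `expand` helper here are illustrative, not the torchfix API):

```python
# Illustrative subset of rule codes; torchfix collects the real list from
# each visitor's ERRORS table via GET_ALL_ERROR_CODES().
ALL_CODES = ["TOR001", "TOR002", "TOR101", "TOR102", "TOR401"]

def expand(prefixes):
    # A selector like "TOR1" matches every known code that starts with it,
    # and duplicates in the input are harmless because of the set.
    return sorted({c for c in ALL_CODES for p in prefixes if c.startswith(p)})

assert expand(["TOR1"]) == ["TOR101", "TOR102"]
assert expand(["TOR001", "TOR0"]) == ["TOR001", "TOR002"]
```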