Skip to content

Commit 947eeb6

Browse files
authored
[TorchFix] Add torch.solve as a removed function (#4705)
Add torch.solve as a removed function (I somehow missed it before and noticed after pytorch/tutorials#2642). Also added a mechanism to show reference links in the error messages.
1 parent 01921ff commit 947eeb6

File tree

5 files changed

+30
-0
lines changed

5 files changed

+30
-0
lines changed

Diff for: README.md

+16
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,19 @@ To enable them, use standard flake8 configuration options for the plugin mode or
4848

4949
If you encounter a bug or some other problem with TorchFix, please file an issue on
5050
https://github.com/pytorch/test-infra/issues, mentioning [TorchFix] in the title.
51+
52+
53+
## Rules
54+
55+
### TOR001 Use of removed function
56+
57+
#### torch.solve
58+
59+
This function was deprecated since PyTorch version 1.9 and is now removed.
60+
61+
`torch.solve` is deprecated in favor of `torch.linalg.solve`.
62+
`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.
63+
64+
To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.
65+
66+
`X = torch.solve(B, A).solution` should be replaced with `X = torch.linalg.solve(A, B)`.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import torch
2+
A = torch.randn(3, 3)
3+
b = torch.randn(3)
4+
torch.solve(A, b).solution
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
4:1 TOR001 Use of removed function torch.solve: https://github.com/pytorch/test-infra/tree/main/tools/torchfix#torchsolve

Diff for: torchfix/deprecated_symbols.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
- name: torch.solve
2+
deprecate_pr: https://github.com/pytorch/pytorch/pull/57741
3+
remove_pr: https://github.com/pytorch/pytorch/pull/70986
4+
reference: https://github.com/pytorch/test-infra/tree/main/tools/torchfix#torchsolve
5+
16
- name: torch.qr
27
deprecate_pr: https://github.com/pytorch/pytorch/pull/57745
38
remove_pr:

Diff for: torchfix/visitors/deprecated_symbols/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def visit_Call(self, node):
7272
message = f"Use of removed function {qualified_name}"
7373
replacement = self._call_replacement(node, qualified_name)
7474

75+
reference = self.deprecated_config[qualified_name].get("reference")
76+
if reference is not None:
77+
message = f"{message}: {reference}"
78+
7579
self.violations.append(
7680
LintViolation(
7781
error_code=error_code,

0 commit comments

Comments
 (0)