
Commit a0a87f8

fix detach numpy in pytorch
1 parent 55b43a5 commit a0a87f8

3 files changed: +34 -2 lines changed

Diff for: CHANGELOG.md (+6 -2)

```diff
@@ -4,16 +4,20 @@
 
 ### Added
 
-- PyTorch backend support multi pytrees version of ``tree_map``
+- PyTorch backend support multi pytrees version of `tree_map`
 
 ### Changed
 
-- Refactor ``interfaces`` code as a submodule and add pytree support for args
+- Refactor `interfaces` code as a submodule and add pytree support for args
 
 - Change the way to register global setup internally, so that we can skip the list of all submodules
 
 - Refactor the tensortrans code to a pytree perspective
 
+### Fixed
+
+- Fixed `numpy` method bug in pytorch backend when the input tensor requires grad (#24)
+
 ## 0.2.1
 
 ### Added
```
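For context on the first `Added` entry: a multi-pytree `tree_map` applies a function leaf-wise across several trees that share one structure. A rough sketch of the idea on top of PyTorch's private `torch.utils._pytree` helpers (not the actual tensorcircuit implementation, whose name and signature may differ):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def multi_tree_map(f, *trees):
    # Flatten every pytree; all trees are assumed to share one structure.
    leaves, spec = tree_flatten(trees[0])
    rest = [tree_flatten(t)[0] for t in trees[1:]]
    # Apply f leaf-wise across the trees, then rebuild the structure.
    return tree_unflatten([f(*xs) for xs in zip(leaves, *rest)], spec)

t1 = {"a": torch.ones(2), "b": torch.zeros(3)}
t2 = {"a": torch.full((2,), 2.0), "b": torch.ones(3)}
print(multi_tree_map(torch.add, t1, t2)["a"])  # tensor([3., 3.])
```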

Diff for: tensorcircuit/backends/pytorch_backend.py (+2)

```diff
@@ -285,6 +285,8 @@ def kron(self, a: Tensor, b: Tensor) -> Tensor:
     def numpy(self, a: Tensor) -> Tensor:
         if a.is_conj():
             return a.resolve_conj().numpy()
+        if a.requires_grad:
+            return a.detach().numpy()
         return a.numpy()
 
     def i(self, dtype: Any = None) -> Tensor:
```
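The fix mirrors stock PyTorch behavior: `Tensor.numpy()` raises a `RuntimeError` ("Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.") whenever the tensor is part of an autograd graph, so the backend now detaches first. A minimal check of the fixed method, assuming a tensorcircuit build that includes this commit:

```python
import numpy as np
import torch
import tensorcircuit as tc

a = torch.ones(2, requires_grad=True)
# a.numpy() would raise RuntimeError here, since the tensor requires grad.
bk = tc.get_backend("pytorch")  # same accessor as in the test below
np.testing.assert_allclose(bk.numpy(a), np.ones(2))  # detached copy
```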

Diff for: tests/test_interfaces.py (+26)

```diff
@@ -105,6 +105,32 @@ def f3(x):
     np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
 
 
+@pytest.mark.skipif(is_torch is False, reason="torch not installed")
+@pytest.mark.xfail(reason="see comment link below")
+@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
+def test_torch_interface_pytree(backend):
+    # pytree cannot support in pytorch autograd function...
+    # https://github.com/pytorch/pytorch/issues/55509
+    def f4(x):
+        return tc.backend.sum(x["a"] ** 2), tc.backend.sum(x["b"] ** 3)
+
+    f4_torch = tc.interfaces.torch_interface(f4, jit=False)
+    param4 = {
+        "a": torch.ones([2], requires_grad=True),
+        "b": torch.ones([2], requires_grad=True),
+    }
+
+    def f4_post(x):
+        r1, r2 = f4_torch(param4)
+        l4 = r1 + r2
+        return l4
+
+    pg = tc.get_backend("pytorch").grad(f4_post)(param4)
+    np.testing.assert_allclose(
+        pg["a"], 2 * np.ones([2]).astype(np.complex64), atol=1e-5
+    )
+
+
 @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
 def test_scipy_interface(backend):
     n = 3
```
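The `xfail` marker points at pytorch/pytorch#55509: `torch.autograd.Function` only connects gradients through arguments that are themselves tensors, so tensors nested inside a dict or other pytree are invisible to autograd. A minimal sketch of that limitation in plain PyTorch (independent of tensorcircuit; the class name is just for illustration):

```python
import torch

class DictSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # x is a dict of tensors; autograd sees a non-tensor argument and
        # never tracks the tensors inside it.
        return x["a"].sum()

    @staticmethod
    def backward(ctx, grad_output):
        return None  # never reached: no graph edge was recorded

params = {"a": torch.ones(2, requires_grad=True)}
out = DictSum.apply(params)
print(out.requires_grad)  # False -- the output is disconnected from params
```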
