Skip to content

Commit 0339bf1

Browse files
fix jax breaking changes
1 parent 86026fa commit 0339bf1

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
- Fixed `cu` gate translation from qiskit to avoid qiskit bug
2828

29+
- Fixed jax refactoring (0.4.24) where SVD and QR return a namedtuple instead of a tuple
30+
2931
## 0.11.0
3032

3133
### Added

tensorcircuit/backends/jax_ops.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
@jax.custom_vjp
1616
def adaware_svd(A: Array) -> Any:
17-
return jnp.linalg.svd(A, full_matrices=False)
17+
u, s, v = jnp.linalg.svd(A, full_matrices=False)
18+
return (u, s, v)
1819

1920

2021
def _safe_reciprocal(x: Array, epsilon: float = 1e-15) -> Array:
@@ -77,8 +78,8 @@ def jaxsvd_bwd(r: Sequence[Array], tangents: Sequence[Array]) -> Tuple[Array]:
7778

7879
@jax.custom_vjp
7980
def adaware_qr(A: Array) -> Any:
80-
# q, r = jnp.linalg.qr(A)
81-
return jnp.linalg.qr(A)
81+
q, r = jnp.linalg.qr(A)
82+
return (q, r)
8283

8384

8485
def jaxqr_fwd(A: Array) -> Any:

0 commit comments

Comments
 (0)