Skip to content

Commit 0020c3f

Browse files
fix arg_to_tensor decorator quoperator behavior
1 parent d3c9f1f commit 0020c3f

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

tensorcircuit/interfaces/tensortrans.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ def tensor_to_numpy(t: Tensor) -> Array:
5050

5151

5252
def tensor_to_backend_jittable(t: Tensor) -> Tensor:
53-
if which_backend(t, return_backend=False) == backend.name:
54-
return t
5553
if isinstance(t, int) or isinstance(t, float):
5654
return t
55+
if isinstance(t, QuOperator):
56+
return t
57+
if which_backend(t, return_backend=False) == backend.name:
58+
return t
5759
return backend.convert_to_tensor(which_backend(t).numpy(t))
5860

5961

@@ -281,9 +283,20 @@ def wrapper(*args: Any, **kws: Any) -> Any:
281283
partial(qop_to_matrix, is_reshapem=qop_as_matrix), arg
282284
)
283285
arg = backend.tree_map(tensor_to_backend_jittable, arg)
286+
284287
# arg = backend.tree_map(backend.convert_to_tensor, arg)
288+
def _cast(a: Tensor, dtype: str) -> Tensor:
289+
if isinstance(a, QuOperator):
290+
return a
291+
return backend.cast(a, dtype)
292+
293+
def _reshapem(a: Tensor) -> Tensor:
294+
if isinstance(a, QuOperator):
295+
return a
296+
return backend.reshapem(a)
297+
285298
if cast_dtype:
286-
arg = backend.tree_map(partial(backend.cast, dtype=dtypestr), arg)
299+
arg = backend.tree_map(partial(_cast, dtype=dtypestr), arg)
287300
if tensor_as_matrix:
288301
arg = backend.tree_map(backend.reshapem, arg)
289302

tests/test_interfaces.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,5 @@ def g(a, b, c):
425425

426426
assert tc.interfaces.which_backend(a[0], return_backend=False) == tc.backend.name
427427
assert tc.backend.shape_tuple(a[1]) == (2, 2, 2, 2)
428-
assert tc.interfaces.which_backend(b, return_backend=False) == tc.backend.name
429-
assert tc.backend.shape_tuple(b) == (2, 2, 2, 2, 2, 2)
428+
assert tc.backend.shape_tuple(b.eval()) == (2, 2, 2, 2, 2, 2)
430429
assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)

0 commit comments

Comments
 (0)