@@ -50,10 +50,12 @@ def tensor_to_numpy(t: Tensor) -> Array:
50
50
51
51
52
52
def tensor_to_backend_jittable (t : Tensor ) -> Tensor :
53
- if which_backend (t , return_backend = False ) == backend .name :
54
- return t
55
53
if isinstance (t , int ) or isinstance (t , float ):
56
54
return t
55
+ if isinstance (t , QuOperator ):
56
+ return t
57
+ if which_backend (t , return_backend = False ) == backend .name :
58
+ return t
57
59
return backend .convert_to_tensor (which_backend (t ).numpy (t ))
58
60
59
61
@@ -281,9 +283,20 @@ def wrapper(*args: Any, **kws: Any) -> Any:
281
283
partial (qop_to_matrix , is_reshapem = qop_as_matrix ), arg
282
284
)
283
285
arg = backend .tree_map (tensor_to_backend_jittable , arg )
286
+
284
287
# arg = backend.tree_map(backend.convert_to_tensor, arg)
288
+ def _cast (a : Tensor , dtype : str ) -> Tensor :
289
+ if isinstance (a , QuOperator ):
290
+ return a
291
+ return backend .cast (a , dtype )
292
+
293
+ def _reshapem (a : Tensor ) -> Tensor :
294
+ if isinstance (a , QuOperator ):
295
+ return a
296
+ return backend .reshapem (a )
297
+
285
298
if cast_dtype :
286
- arg = backend .tree_map (partial (backend . cast , dtype = dtypestr ), arg )
299
+ arg = backend .tree_map (partial (_cast , dtype = dtypestr ), arg )
287
300
if tensor_as_matrix :
288
301
arg = backend .tree_map (backend .reshapem , arg )
289
302
0 commit comments