Skip to content

Commit 9329142

Browse files
add mpo expectation in templates
1 parent 7dcf6af commit 9329142

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77
- add sigmoid method on backends
88

9+
- add MPO expectation template function for MPO evaluation on circuit
10+
11+
### Fixed
12+
13+
- fix the bug in QuOperator.from_local_tensor where the dtype should always be in numpy context
14+
915
## 0.0.220301
1016

1117
### Added

tensorcircuit/quantum.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
except ImportError:
3636
pass
3737

38-
from .cons import backend, contractor, dtypestr
38+
from .cons import backend, contractor, dtypestr, npdtype
3939
from .backends import get_backend # type: ignore
4040

4141
Tensor = Any
@@ -377,9 +377,7 @@ def from_local_tensor(
377377
out_edges = [localn[i] for i in out_axes]
378378
in_edges = [localn[i] for i in in_axes] # type: ignore
379379
id_nodes = [
380-
CopyNode(2, d, dtype=tensor.dtype)
381-
for i, d in enumerate(space)
382-
if i not in loc
380+
CopyNode(2, d, dtype=npdtype) for i, d in enumerate(space) if i not in loc
383381
]
384382
for n in id_nodes:
385383
out_edges.append(n[0])

tensorcircuit/templates/measurements.py

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ..circuit import Circuit
99
from ..cons import backend, dtypestr
10+
from ..quantum import QuOperator
1011
from .. import gates as G
1112

1213
Tensor = Any
@@ -98,6 +99,12 @@ def sparse_expectation(c: Circuit, hamiltonian: Tensor) -> Tensor:
9899
return backend.real(expt)[0, 0]
99100

100101

102+
def mpo_expectation(c: Circuit, mpo: QuOperator) -> Tensor:
103+
mps = c.get_quvector()
104+
e = (mps.adjoint() @ mpo @ mps).eval_matrix()
105+
return backend.real(e)[0, 0]
106+
107+
101108
def heisenberg_measurements(
102109
c: Circuit,
103110
g: Graph,

tests/test_templates.py

+19
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,22 @@ def test_amplitude_encoding(backend):
110110
figs, 2, tc.array_to_tensor(np.array([0, 3, 1, 2]), dtype="int32")
111111
)
112112
np.testing.assert_allclose(states[0], 1 / np.sqrt(2) * np.array([1, 1, 0, 0]))
113+
114+
115+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
116+
def test_mpo_measurement(backend):
117+
def f(theta):
118+
mpo = tc.quantum.QuOperator.from_local_tensor(
119+
tc.array_to_tensor(tc.gates._x_matrix), [2, 2, 2], [0]
120+
)
121+
c = tc.Circuit(3)
122+
c.ry(0, theta=theta)
123+
c.H(1)
124+
c.H(2)
125+
e = tc.templates.measurements.mpo_expectation(c, mpo)
126+
return e
127+
128+
v, g = tc.backend.jit(tc.backend.value_and_grad(f))(tc.backend.ones([]))
129+
130+
np.testing.assert_allclose(v, 0.84147, atol=1e-4)
131+
np.testing.assert_allclose(g, 0.54032, atol=1e-4)

0 commit comments

Comments
 (0)