Skip to content

Commit 44acbdc

Browse files
author
Erertertet
committedJan 20, 2023
change of formatting and better compatibility
1 parent 8ffc993 commit 44acbdc

File tree

3 files changed

+57
-32
lines changed

3 files changed

+57
-32
lines changed
 

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from tensorcircuit import __version__, __author__
44

5-
with open("README.md", "r", encoding = "utf-8") as fh:
5+
with open("README.md", "r", encoding="utf-8") as fh:
66
long_description = fh.read()
77

88

‎tensorcircuit/translation.py

+54-29
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@
1818
from qiskit.circuit.quantumcircuitdata import CircuitInstruction
1919
from qiskit.circuit.parametervector import ParameterVectorElement
2020
from qiskit.circuit import Parameter, ParameterExpression
21-
import cirq
2221
except ImportError:
2322
logger.warning(
2423
"Please first ``pip install -U qiskit`` to enable related functionality in translation module"
2524
)
2625

26+
try:
27+
import cirq
28+
except ImportError:
29+
logger.warning(
30+
"Please first ``pip install -U cirq`` to enable related functionality in translation module"
31+
)
32+
2733
from . import gates
2834
from .circuit import Circuit
2935
from .densitymatrix import DMCircuit2
@@ -82,21 +88,6 @@ def _merge_extra_qir(
8288
nqir += inds[k]
8389
return nqir
8490

85-
class CustomizedCirqGate(cirq.Gate):
86-
def __init__(self, uMatrix, name, nqubit):
87-
super(CustomizedCirqGate, self)
88-
self.uMatrix = uMatrix
89-
self.name = name
90-
self.nqubit = nqubit
91-
92-
def _num_qubits_(self):
93-
return self.nqubit
94-
95-
def _unitary_(self):
96-
return self.uMatrix
97-
98-
def _circuit_diagram_info_(self, args):
99-
return [self.name] * self.nqubit
10091

10192
def qir2cirq(
10293
qir: List[Dict[str, Any]], n: int, extra_qir: Optional[List[Dict[str, Any]]] = None
@@ -124,51 +115,85 @@ def qir2cirq(
124115
:return: qiskit cirq object
125116
:rtype: Any
126117
127-
todo:
118+
#TODO(@erertertet):
128119
add default theta to iswap gate
129120
add more cirq built-in gate instead of customized
130121
add unitary test with tolerance
131122
add support of cirq built-in ControlledGate for multiplecontroll
132123
support more element in qir, e.g. barrier, measure...
133124
disable outputting controlled bit when creating controlled gate
134125
"""
126+
127+
class CustomizedCirqGate(cirq.Gate):
128+
def __init__(self, uMatrix: Any, name: str, nqubit: int):
129+
super(CustomizedCirqGate, self)
130+
self.uMatrix = uMatrix
131+
self.name = name
132+
self.nqubit = nqubit
133+
134+
def _num_qubits_(self) -> int:
135+
return self.nqubit
136+
137+
def _unitary_(self) -> Any:
138+
return self.uMatrix
139+
140+
def _circuit_diagram_info_(self) -> List[str]:
141+
return [self.name] * self.nqubit
142+
135143
if extra_qir is not None and len(extra_qir) > 0:
136144
qir = _merge_extra_qir(qir, extra_qir)
137145
qbits = cirq.LineQubit.range(n)
138-
cmd = []
146+
cmd = []
139147
for gate_info in qir:
140148
index = [qbits[i] for i in gate_info["index"]]
141149
gate_name = str(gate_info["gatef"])
142150
if "parameters" in gate_info:
143151
parameters = gate_info["parameters"]
144-
if gate_name in ["h","i","x","y","z","s","t","fredkin","toffoli","cnot","swap"]:
152+
if gate_name in [
153+
"h",
154+
"i",
155+
"x",
156+
"y",
157+
"z",
158+
"s",
159+
"t",
160+
"fredkin",
161+
"toffoli",
162+
"cnot",
163+
"swap",
164+
]:
145165
cmd.append(getattr(cirq, gate_name.upper())(*index))
146166
elif gate_name in ["rx", "ry", "rz"]:
147-
cmd.append(getattr(cirq, gate_name)(_get_float(parameters, "theta")).on(*index))
167+
cmd.append(
168+
getattr(cirq, gate_name)(_get_float(parameters, "theta")).on(*index)
169+
)
148170
elif gate_name == "iswap":
149-
if "theta" not in parameters:
150-
cmd.append(cirq.ISWAP(*index))
151-
# when ISWAP theta is not specified, _get_float will return default value of 0.0 instead of 1.0
152-
else:
153-
cmd.append(cirq.ISwapPowGate(exponent = _get_float(parameters, "theta")).on(*index))
171+
cmd.append(
172+
cirq.ISwapPowGate(
173+
exponent=_get_float(parameters, "theta", default=1)
174+
).on(*index)
175+
)
154176
elif gate_name in ["mpo", "multicontrol"]:
155177
gatem = np.reshape(
156-
backend.numpy(gate_info["gatef"](**parameters).eval_matrix()),
157-
[2 ** len(index), 2 ** len(index)],
158-
)
178+
backend.numpy(gate_info["gatef"](**parameters).eval_matrix()),
179+
[2 ** len(index), 2 ** len(index)],
180+
)
159181
ci_name = gate_info["name"]
160182
cgate = CustomizedCirqGate(gatem, ci_name, len(index))
161183
cmd.append(cgate.on(*index))
162184
else:
163185
# Add Customized Gate if there is no match
164-
gatem = np.reshape(gate_info["gate"].tensor,[2 ** len(index), 2 ** len(index)],
186+
gatem = np.reshape(
187+
gate_info["gate"].tensor,
188+
[2 ** len(index), 2 ** len(index)],
165189
)
166190
# Note: unitary test is not working for some of the generated matrix, probably add tolerance unitary test later
167191
cgate = CustomizedCirqGate(gatem, gate_name, len(index))
168192
cmd.append(cgate.on(*index))
169193
cirq_circuit = cirq.Circuit(*cmd)
170194
return cirq_circuit
171195

196+
172197
def qir2qiskit(
173198
qir: List[Dict[str, Any]], n: int, extra_qir: Optional[List[Dict[str, Any]]] = None
174199
) -> Any:

‎tests/test_circuit.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -857,11 +857,11 @@ def test_circuit_quoperator(backend):
857857
qo = c.quoperator()
858858
np.testing.assert_allclose(qo.eval_matrix(), c.matrix(), atol=1e-5)
859859

860+
860861
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
861862
def test_qir2cirq(backend):
862863
try:
863864
import cirq
864-
from tensorcircuit.translation import perm_matrix
865865
except ImportError:
866866
pytest.skip("cirq is not installed")
867867
n = 6
@@ -953,7 +953,7 @@ def test_qir2cirq(backend):
953953
cirq_unitary = cirq.unitary()
954954
cirq_unitary = np.reshape(cirq_unitary, [2**n, 2**n])
955955

956-
np.testing.assert_allclose(tc_unitary, cirq_unitary, atol = 1e-5)
956+
np.testing.assert_allclose(tc_unitary, cirq_unitary, atol=1e-5)
957957

958958

959959
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])

0 commit comments

Comments
 (0)