Skip to content

Commit 3643b59

Browse files
detensorflow
1 parent 893eeff commit 3643b59

File tree

3 files changed

+11
-17
lines changed

3 files changed

+11
-17
lines changed

setup.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import setuptools
22

33
from tensorcircuit import __version__, __author__
4-
from tensorcircuit.utils import is_m1mac
54

65
with open("README.md", "r") as fh:
76
long_description = fh.read()
87

9-
install_requires = ["numpy", "scipy", "tensornetwork", "networkx"]
10-
11-
if not is_m1mac():
12-
install_requires.append("tensorflow")
13-
# avoid the embarassing macos M1 chip case, where the package is called tensorflow-macos
148

159
setuptools.setup(
1610
name="tensorcircuit",
@@ -23,14 +17,13 @@
2317
url="https://github.com/tencent-quantum-lab/tensorcircuit",
2418
packages=setuptools.find_packages(),
2519
include_package_data=True,
26-
install_requires=install_requires,
27-
tests_require=[
28-
"pytest",
29-
"pytest-lazy-fixture",
30-
"pytest-cov",
31-
"pytest-benchmark",
32-
"pytest-xdist",
33-
],
20+
install_requires=["numpy", "scipy", "tensornetwork", "networkx"],
21+
extras_require={
22+
"tensorflow": ["tensorflow"],
23+
"jax": ["jax", "jaxlib"],
24+
"torch": ["torch"],
25+
"qiskit": ["qiskit"],
26+
},
3427
classifiers=[
3528
"Programming Language :: Python :: 3",
3629
"Operating System :: OS Independent",

tensorcircuit/interfaces/tensorflow.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from typing import Any, Callable, Tuple
66
from functools import wraps
77

8-
import tensorflow as tf
9-
108
from ..cons import backend
119
from ..utils import return_partial
1210
from .tensortrans import general_args_to_backend
@@ -30,6 +28,8 @@ def fun_tf(*x: Any) -> Any:
3028

3129

3230
def tf_dtype(dtype: str) -> Any:
31+
import tensorflow as tf
32+
3333
if isinstance(dtype, str):
3434
return getattr(tf, dtype)
3535
return dtype
@@ -72,6 +72,7 @@ def f(params):
7272
while AD is also supported
7373
:rtype: Callable[..., Any]
7474
"""
75+
import tensorflow as tf
7576

7677
if jit is True:
7778
fun = backend.jit(fun)

tensorcircuit/translation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def qiskit2tc(
254254
base_gate = gate_info[0].base_gate
255255
ctrl_state = [1] * base_gate.num_qubits
256256
idx = idx[: -base_gate.num_qubits] + idx[-base_gate.num_qubits :][::-1]
257-
print(idx)
257+
# print(idx)
258258
tc_circuit.multicontrol(
259259
*idx, ctrl=ctrl_state, unitary=base_gate.to_matrix()
260260
)

0 commit comments

Comments
 (0)