Skip to content

Commit 29be87a

Browse files
committedJun 5, 2022
better m1 compatibility
1 parent 44aaf56 commit 29be87a

File tree

4 files changed

+23
-7
lines changed

4 files changed

+23
-7
lines changed
 

‎check_all.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ pylint tensorcircuit tests examples/*.py
99
echo "pytest check"
1010
pytest -n 4 --cov=tensorcircuit -vv -W ignore::DeprecationWarning
1111
echo "sphinx check"
12-
cd docs && sphinx-build source build/html && sphinx-build source -D language="zh" -D master_doc=index_cn build/html_cn
12+
cd docs && sphinx-build source build/html && sphinx-build source -D language="zh" build/html_cn
1313
echo "all checks passed, congratulation! 💐"

‎setup.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
import setuptools
2-
import platform
32

43
from tensorcircuit import __version__, __author__
5-
4+
from tensorcircuit.utils import is_m1mac
65

76
with open("README.md", "r") as fh:
87
long_description = fh.read()
98

109
install_requires = ["numpy", "scipy", "tensornetwork", "networkx"]
1110

12-
if platform.processor() != "arm":
11+
if not is_m1mac():
1312
install_requires.append("tensorflow")
1413
# avoid the embarassing macos M1 chip case, where the package is called tensorflow-macos
1514

16-
# TODO(@refraction-ray): add better check_m1 function
17-
1815
setuptools.setup(
1916
name="tensorcircuit",
2017
version=__version__,

‎tensorcircuit/quantum.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from .cons import backend, contractor, dtypestr, npdtype
3939
from .backends import get_backend # type: ignore
40+
from .utils import is_m1mac
4041

4142
Tensor = Any
4243
Graph = Any
@@ -1103,7 +1104,10 @@ def quimb2qop(qb_mpo: Any) -> QuOperator:
11031104

11041105

11051106
try:
1106-
compiled_jit = partial(get_backend("tensorflow").jit, jit_compile=True)
1107+
if is_m1mac():
1108+
compiled_jit = lambda x: x
1109+
else:
1110+
compiled_jit = partial(get_backend("tensorflow").jit, jit_compile=True)
11071111

11081112
def heisenberg_hamiltonian(
11091113
g: Graph,

‎tensorcircuit/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import Any, Callable, Union, Sequence
66
from functools import wraps
7+
import platform
78

89

910
def return_partial(
@@ -81,3 +82,17 @@ def wrapper(*args: Any, **kws: Any) -> Any:
8182
return rs
8283

8384
return wrapper
85+
86+
87+
def is_m1mac() -> bool:
88+
"""
89+
check whether the running platform is MAC with M1 chip
90+
91+
:return: True for MAC M1 platform
92+
:rtype: bool
93+
"""
94+
if platform.processor() != "arm":
95+
return False
96+
if not platform.platform().startswith("macOS"):
97+
return False
98+
return True

0 commit comments

Comments
 (0)