Skip to content

Commit b0d672b

Browse files
update test suits fiiting gpu platform
1 parent 0918eb2 commit b0d672b

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

check_all.sh

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ echo "pylint check"
88
pylint tensorcircuit tests examples/*.py
99
echo "pytest check"
1010
pytest -n auto --cov=tensorcircuit -vv -W ignore::DeprecationWarning
11+
# for test on gpu machine, please set `export TF_FORCE_GPU_ALLOW_GROWTH=true` for tf
12+
# and `export XLA_PYTHON_CLIENT_PREALLOCATE=false` for jax to avoid OOM in testing
1113
echo "sphinx check"
1214
cd docs && sphinx-build source build/html && sphinx-build source -D language="zh" build/html_cn
1315
echo "all checks passed, congratulation! 💐"

tests/test_backends.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def test_backend_methods_2(backend):
300300
def test_device_cpu_only(backend):
301301
a = tc.backend.ones([])
302302
dev_str = tc.backend.device(a)
303-
assert dev_str == "cpu"
303+
assert dev_str in ["cpu", "gpu:0"]
304304
tc.backend.device_move(a, dev_str)
305305

306306

tests/test_interfaces.py

+1
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def test_dlpack_transformation(backend):
297297
target_backend=b,
298298
enable_dlpack=True,
299299
)
300+
ans = tc.interfaces.which_backend(ans).device_move(ans, "cpu")
300301
np.testing.assert_allclose(ans, np.ones([2]))
301302

302303

0 commit comments

Comments
 (0)