Skip to content

Commit 3af21a4

Browse files
add mat prod vmap example
1 parent 5f9851f commit 3af21a4

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Add multiple GPU VQE examples using jax pmap
88

9+
- Add benchmark example showcasing new way of implementing matrix product using vmap
10+
911
## 0.10.0
1012

1113
### Added

examples/matprod_vmap.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
matrix product: a new twist
3+
rewrite matrix product in a vmap style
4+
"""
5+
from functools import partial
6+
7+
import numpy as np
8+
import tensorcircuit as tc
9+
10+
for bk in ["jax", "tensorflow"]:
11+
with tc.runtime_backend(bk) as K:
12+
print("~~~~~~~~~~~~~~~~~~~~~")
13+
print(f"using {K.name} backend")
14+
15+
@partial(K.jit, jit_compile=True)
16+
def mul(a, b):
17+
return a @ b
18+
19+
def ij(i, j):
20+
"""
21+
Inner product
22+
"""
23+
return K.tensordot(i, j, 1)
24+
25+
vij = K.vmap(ij, vectorized_argnums=1)
26+
vvij = K.vmap(vij, vectorized_argnums=0)
27+
28+
@partial(K.jit, jit_compile=True)
29+
def mul2(a, b):
30+
b = K.transpose(b)
31+
return vvij(a, b)
32+
33+
for shape in [(256, 4096), (4096, 256), (2048, 2048)]:
34+
print(shape)
35+
a = K.implicit_randn(shape)
36+
b = K.implicit_randn([shape[1], shape[0]])
37+
print("plain matprod")
38+
r1, _, _ = tc.utils.benchmark(mul, a, b, tries=10)
39+
print("vmap matprod")
40+
r2, _, _ = tc.utils.benchmark(mul2, a, b, tries=10)
41+
np.testing.assert_allclose(r1, r2, atol=1e-5)

0 commit comments

Comments
 (0)