Skip to content

Commit d7b5012

Browse files
committedMar 29, 2022
add mean and trig methods on backend
1 parent 472041e commit d7b5012

9 files changed

+423
-1
lines changed
 

‎CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
## Unreleased
44

5+
### Added
6+
7+
- add `utils.append` to build function pipeline
8+
9+
- add `mean` method on backends
10+
11+
- add trigonometric methods on backends
12+
513
## 0.0.220328
614

715
### Added

‎tensorcircuit/backends/abstract_backend.py

+154
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,138 @@ def cos(self: Any, a: Tensor) -> Tensor:
9797
"Backend '{}' has not implemented `cos`.".format(self.name)
9898
)
9999

100+
def acos(self: Any, a: Tensor) -> Tensor:
101+
"""
102+
Return the acos of a tensor ``a``.
103+
104+
:param a: tensor in matrix form
105+
:type a: Tensor
106+
:return: acos of ``a``
107+
:rtype: Tensor
108+
"""
109+
raise NotImplementedError(
110+
"Backend '{}' has not implemented `acos`.".format(self.name)
111+
)
112+
113+
def acosh(self: Any, a: Tensor) -> Tensor:
114+
"""
115+
Return the acosh of a tensor ``a``.
116+
117+
:param a: tensor in matrix form
118+
:type a: Tensor
119+
:return: acosh of ``a``
120+
:rtype: Tensor
121+
"""
122+
raise NotImplementedError(
123+
"Backend '{}' has not implemented `acosh`.".format(self.name)
124+
)
125+
126+
def asin(self: Any, a: Tensor) -> Tensor:
127+
"""
128+
Return the acos of a tensor ``a``.
129+
130+
:param a: tensor in matrix form
131+
:type a: Tensor
132+
:return: asin of ``a``
133+
:rtype: Tensor
134+
"""
135+
raise NotImplementedError(
136+
"Backend '{}' has not implemented `asin`.".format(self.name)
137+
)
138+
139+
def asinh(self: Any, a: Tensor) -> Tensor:
140+
"""
141+
Return the asinh of a tensor ``a``.
142+
143+
:param a: tensor in matrix form
144+
:type a: Tensor
145+
:return: asinh of ``a``
146+
:rtype: Tensor
147+
"""
148+
raise NotImplementedError(
149+
"Backend '{}' has not implemented `asinh`.".format(self.name)
150+
)
151+
152+
def atan(self: Any, a: Tensor) -> Tensor:
153+
"""
154+
Return the atan of a tensor ``a``.
155+
156+
:param a: tensor in matrix form
157+
:type a: Tensor
158+
:return: atan of ``a``
159+
:rtype: Tensor
160+
"""
161+
raise NotImplementedError(
162+
"Backend '{}' has not implemented `atan`.".format(self.name)
163+
)
164+
165+
def atan2(self: Any, y: Tensor, x: Tensor) -> Tensor:
166+
"""
167+
Return the atan of a tensor ``y``/``x``.
168+
169+
:param a: tensor in matrix form
170+
:type a: Tensor
171+
:return: atan2 of ``a``
172+
:rtype: Tensor
173+
"""
174+
raise NotImplementedError(
175+
"Backend '{}' has not implemented `atan2`.".format(self.name)
176+
)
177+
178+
def cosh(self: Any, a: Tensor) -> Tensor:
179+
"""
180+
Return the cosh of a tensor ``a``.
181+
182+
:param a: tensor in matrix form
183+
:type a: Tensor
184+
:return: cosh of ``a``
185+
:rtype: Tensor
186+
"""
187+
raise NotImplementedError(
188+
"Backend '{}' has not implemented `cosh`.".format(self.name)
189+
)
190+
191+
def tan(self: Any, a: Tensor) -> Tensor:
192+
"""
193+
Return the tan of a tensor ``a``.
194+
195+
:param a: tensor in matrix form
196+
:type a: Tensor
197+
:return: tan of ``a``
198+
:rtype: Tensor
199+
"""
200+
raise NotImplementedError(
201+
"Backend '{}' has not implemented `tan`.".format(self.name)
202+
)
203+
204+
def tanh(self: Any, a: Tensor) -> Tensor:
205+
"""
206+
Return the tanh of a tensor ``a``.
207+
208+
:param a: tensor in matrix form
209+
:type a: Tensor
210+
:return: tanh of ``a``
211+
:rtype: Tensor
212+
"""
213+
raise NotImplementedError(
214+
"Backend '{}' has not implemented `tanh`.".format(self.name)
215+
)
216+
217+
def sinh(self: Any, a: Tensor) -> Tensor:
218+
"""
219+
Return the sinh of a tensor ``a``.
220+
221+
:param a: tensor in matrix form
222+
:type a: Tensor
223+
:return: sinh of ``a``
224+
:rtype: Tensor
225+
"""
226+
raise NotImplementedError(
227+
"Backend '{}' has not implemented `sinh`.".format(self.name)
228+
)
229+
230+
# acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh (sin) sinh
231+
100232
def abs(self: Any, a: Tensor) -> Tensor:
101233
"""
102234
Return the elementwise abs value of a matrix ``a``.
@@ -283,6 +415,28 @@ def tile(self: Any, a: Tensor, rep: Tensor) -> Tensor:
283415
"Backend '{}' has not implemented `tile`.".format(self.name)
284416
)
285417

418+
def mean(
419+
self: Any,
420+
a: Tensor,
421+
axis: Optional[Sequence[int]] = None,
422+
keepdims: bool = False,
423+
) -> Tensor:
424+
"""
425+
Compute the arithmetic mean for ``a`` along the specified ``axis``.
426+
427+
:param a: tensor to take average
428+
:type a: Tensor
429+
:param axis: the axis to take mean, defaults to None indicating sum over flatten array
430+
:type axis: Optional[Sequence[int]], optional
431+
:param keepdims: _description_, defaults to False
432+
:type keepdims: bool, optional
433+
:return: _description_
434+
:rtype: Tensor
435+
"""
436+
raise NotImplementedError(
437+
"Backend '{}' has not implemented `mean`.".format(self.name)
438+
)
439+
286440
def min(self: Any, a: Tensor, axis: Optional[int] = None) -> Tensor:
287441
"""
288442
Return the minimum of an array or minimum along an axis.

‎tensorcircuit/backends/jax_backend.py

+41
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,39 @@ def sin(self, a: Tensor) -> Tensor:
247247
def cos(self, a: Tensor) -> Tensor:
248248
return jnp.cos(a)
249249

250+
def acos(self, a: Tensor) -> Tensor:
251+
return jnp.arccos(a)
252+
253+
def acosh(self, a: Tensor) -> Tensor:
254+
return jnp.arccosh(a)
255+
256+
def asin(self, a: Tensor) -> Tensor:
257+
return jnp.arcsin(a)
258+
259+
def asinh(self, a: Tensor) -> Tensor:
260+
return jnp.arcsinh(a)
261+
262+
def atan(self, a: Tensor) -> Tensor:
263+
return jnp.arctan(a)
264+
265+
def atan2(self, y: Tensor, x: Tensor) -> Tensor:
266+
return jnp.arctan2(y, x)
267+
268+
def atanh(self, a: Tensor) -> Tensor:
269+
return jnp.arctanh(a)
270+
271+
def cosh(self, a: Tensor) -> Tensor:
272+
return jnp.cosh(a)
273+
274+
def tan(self, a: Tensor) -> Tensor:
275+
return jnp.tan(a)
276+
277+
def tanh(self, a: Tensor) -> Tensor:
278+
return jnp.tanh(a)
279+
280+
def sinh(self, a: Tensor) -> Tensor:
281+
return jnp.sinh(a)
282+
250283
def size(self, a: Tensor) -> Tensor:
251284
return jnp.size(a)
252285

@@ -288,6 +321,14 @@ def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
288321
def tile(self, a: Tensor, rep: Tensor) -> Tensor:
289322
return jnp.tile(a, rep)
290323

324+
def mean(
325+
self,
326+
a: Tensor,
327+
axis: Optional[Sequence[int]] = None,
328+
keepdims: bool = False,
329+
) -> Tensor:
330+
return jnp.mean(a, axis=axis, keepdims=keepdims)
331+
291332
def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
292333
return jnp.min(a, axis=axis)
293334

‎tensorcircuit/backends/numpy_backend.py

+42
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,40 @@ def sin(self, a: Tensor) -> Tensor:
9191
def cos(self, a: Tensor) -> Tensor:
9292
return np.cos(a)
9393

94+
# acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
95+
def acos(self, a: Tensor) -> Tensor:
96+
return np.arccos(a)
97+
98+
def acosh(self, a: Tensor) -> Tensor:
99+
return np.arccosh(a)
100+
101+
def asin(self, a: Tensor) -> Tensor:
102+
return np.arcsin(a)
103+
104+
def asinh(self, a: Tensor) -> Tensor:
105+
return np.arcsinh(a)
106+
107+
def atan(self, a: Tensor) -> Tensor:
108+
return np.arctan(a)
109+
110+
def atan2(self, y: Tensor, x: Tensor) -> Tensor:
111+
return np.arctan2(y, x)
112+
113+
def atanh(self, a: Tensor) -> Tensor:
114+
return np.arctanh(a)
115+
116+
def cosh(self, a: Tensor) -> Tensor:
117+
return np.cosh(a)
118+
119+
def tan(self, a: Tensor) -> Tensor:
120+
return np.tan(a)
121+
122+
def tanh(self, a: Tensor) -> Tensor:
123+
return np.tanh(a)
124+
125+
def sinh(self, a: Tensor) -> Tensor:
126+
return np.sinh(a)
127+
94128
def size(self, a: Tensor) -> Tensor:
95129
return a.size
96130

@@ -116,6 +150,14 @@ def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
116150
def tile(self, a: Tensor, rep: Tensor) -> Tensor:
117151
return np.tile(a, rep)
118152

153+
def mean(
154+
self,
155+
a: Tensor,
156+
axis: Optional[Sequence[int]] = None,
157+
keepdims: bool = False,
158+
) -> Tensor:
159+
return np.mean(a, axis=axis, keepdims=keepdims)
160+
119161
def unique_with_counts(self, a: Tensor) -> Tuple[Tensor, Tensor]:
120162
return np.unique(a, return_counts=True) # type: ignore
121163

‎tensorcircuit/backends/pytorch_backend.py

+43
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,39 @@ def sin(self, a: Tensor) -> Tensor:
217217
def cos(self, a: Tensor) -> Tensor:
218218
return torchlib.cos(a)
219219

220+
def acos(self, a: Tensor) -> Tensor:
221+
return torchlib.acos(a)
222+
223+
def acosh(self, a: Tensor) -> Tensor:
224+
return torchlib.acosh(a)
225+
226+
def asin(self, a: Tensor) -> Tensor:
227+
return torchlib.asin(a)
228+
229+
def asinh(self, a: Tensor) -> Tensor:
230+
return torchlib.asinh(a)
231+
232+
def atan(self, a: Tensor) -> Tensor:
233+
return torchlib.atan(a)
234+
235+
def atan2(self, y: Tensor, x: Tensor) -> Tensor:
236+
return torchlib.atan2(y, x)
237+
238+
def atanh(self, a: Tensor) -> Tensor:
239+
return torchlib.atanh(a)
240+
241+
def cosh(self, a: Tensor) -> Tensor:
242+
return torchlib.cosh(a)
243+
244+
def tan(self, a: Tensor) -> Tensor:
245+
return torchlib.tan(a)
246+
247+
def tanh(self, a: Tensor) -> Tensor:
248+
return torchlib.tanh(a)
249+
250+
def sinh(self, a: Tensor) -> Tensor:
251+
return torchlib.sinh(a)
252+
220253
def size(self, a: Tensor) -> Tensor:
221254
return a.size()
222255

@@ -258,6 +291,16 @@ def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
258291
def tile(self, a: Tensor, rep: Tensor) -> Tensor:
259292
return torchlib.tile(a, rep)
260293

294+
def mean(
295+
self,
296+
a: Tensor,
297+
axis: Optional[Sequence[int]] = None,
298+
keepdims: bool = False,
299+
) -> Tensor:
300+
if axis is None:
301+
axis = tuple([i for i in range(len(a.shape))])
302+
return torchlib.mean(a, dim=axis, keepdim=keepdims)
303+
261304
def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
262305
if axis is None:
263306
return torchlib.min(a)

‎tensorcircuit/backends/tensorflow_backend.py

+41
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,39 @@ def sin(self, a: Tensor) -> Tensor:
278278
def cos(self, a: Tensor) -> Tensor:
279279
return tf.math.cos(a)
280280

281+
def acos(self, a: Tensor) -> Tensor:
282+
return tf.math.acos(a)
283+
284+
def acosh(self, a: Tensor) -> Tensor:
285+
return tf.math.acosh(a)
286+
287+
def asin(self, a: Tensor) -> Tensor:
288+
return tf.math.asin(a)
289+
290+
def asinh(self, a: Tensor) -> Tensor:
291+
return tf.math.asinh(a)
292+
293+
def atan(self, a: Tensor) -> Tensor:
294+
return tf.math.atan(a)
295+
296+
def atan2(self, y: Tensor, x: Tensor) -> Tensor:
297+
return tf.math.atan2(y, x)
298+
299+
def atanh(self, a: Tensor) -> Tensor:
300+
return tf.math.atanh(a)
301+
302+
def cosh(self, a: Tensor) -> Tensor:
303+
return tf.math.cosh(a)
304+
305+
def tan(self, a: Tensor) -> Tensor:
306+
return tf.math.tan(a)
307+
308+
def tanh(self, a: Tensor) -> Tensor:
309+
return tf.math.tanh(a)
310+
311+
def sinh(self, a: Tensor) -> Tensor:
312+
return tf.math.sinh(a)
313+
281314
def size(self, a: Tensor) -> Tensor:
282315
return tf.size(a)
283316

@@ -325,6 +358,14 @@ def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
325358
def tile(self, a: Tensor, rep: Tensor) -> Tensor:
326359
return tf.tile(a, rep)
327360

361+
def mean(
362+
self,
363+
a: Tensor,
364+
axis: Optional[Sequence[int]] = None,
365+
keepdims: bool = False,
366+
) -> Tensor:
367+
return tf.math.reduce_mean(a, axis=axis, keepdims=keepdims)
368+
328369
def sigmoid(self, a: Tensor) -> Tensor:
329370
return tf.nn.sigmoid(a)
330371

‎tensorcircuit/circuit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def depolarizing(
761761
) -> float:
762762
# px/y/z here not support differentiation for now
763763
# jit compatible for now
764-
assert px + py + pz < 1 and px >= 0 and py >= 0 and pz >= 0
764+
# assert px + py + pz < 1 and px >= 0 and py >= 0 and pz >= 0
765765

766766
def step_function(x: Tensor) -> Tensor:
767767
r = (

‎tensorcircuit/utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,31 @@ def wrapper(*args: Any, **kws: Any) -> Any:
5353
return tuple(nr)
5454

5555
return wrapper
56+
57+
58+
def append(f: Callable[..., Any], *op: Callable[..., Any]) -> Any:
59+
"""
60+
Functional programming paradigm to build function pipeline
61+
62+
:Example:
63+
64+
>>> f = tc.utils.append(lambda x: x**2, lambda x: x+1, tc.backend.mean)
65+
>>> f(tc.backend.ones(2))
66+
(2+0j)
67+
68+
:param f: The function which are attached with other functions
69+
:type f: Callable[..., Any]
70+
:param op: Function to be attached
71+
:type op: Callable[..., Any]
72+
:return: The final results after function pipeline
73+
:rtype: Any
74+
"""
75+
76+
@wraps(f)
77+
def wrapper(*args: Any, **kws: Any) -> Any:
78+
rs = f(*args, **kws)
79+
for opi in op:
80+
rs = opi(rs)
81+
return rs
82+
83+
return wrapper

‎tests/test_backends.py

+65
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,71 @@ def test_backend_methods(backend):
179179
)
180180

181181

182+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
183+
def test_backend_methods_2(backend):
184+
np.testing.assert_allclose(tc.backend.mean(tc.backend.ones([10])), 1.0, atol=1e-5)
185+
# acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
186+
np.testing.assert_allclose(
187+
tc.backend.acos(tc.backend.ones([2], dtype="float32")),
188+
np.arccos(tc.backend.ones([2])),
189+
atol=1e-5,
190+
)
191+
np.testing.assert_allclose(
192+
tc.backend.acosh(tc.backend.ones([2], dtype="float32")),
193+
np.arccosh(tc.backend.ones([2])),
194+
atol=1e-5,
195+
)
196+
np.testing.assert_allclose(
197+
tc.backend.asin(tc.backend.ones([2], dtype="float32")),
198+
np.arcsin(tc.backend.ones([2])),
199+
atol=1e-5,
200+
)
201+
np.testing.assert_allclose(
202+
tc.backend.asinh(tc.backend.ones([2], dtype="float32")),
203+
np.arcsinh(tc.backend.ones([2])),
204+
atol=1e-5,
205+
)
206+
np.testing.assert_allclose(
207+
tc.backend.atan(0.5 * tc.backend.ones([2], dtype="float32")),
208+
np.arctan(0.5 * tc.backend.ones([2])),
209+
atol=1e-5,
210+
)
211+
np.testing.assert_allclose(
212+
tc.backend.atan2(
213+
tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
214+
),
215+
np.arctan2(
216+
tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
217+
),
218+
atol=1e-5,
219+
)
220+
np.testing.assert_allclose(
221+
tc.backend.atanh(0.5 * tc.backend.ones([2], dtype="float32")),
222+
np.arctanh(0.5 * tc.backend.ones([2])),
223+
atol=1e-5,
224+
)
225+
np.testing.assert_allclose(
226+
tc.backend.cosh(tc.backend.ones([2], dtype="float32")),
227+
np.cosh(tc.backend.ones([2])),
228+
atol=1e-5,
229+
)
230+
np.testing.assert_allclose(
231+
tc.backend.tan(tc.backend.ones([2], dtype="float32")),
232+
np.tan(tc.backend.ones([2])),
233+
atol=1e-5,
234+
)
235+
np.testing.assert_allclose(
236+
tc.backend.tanh(tc.backend.ones([2], dtype="float32")),
237+
np.tanh(tc.backend.ones([2])),
238+
atol=1e-5,
239+
)
240+
np.testing.assert_allclose(
241+
tc.backend.sinh(0.5 * tc.backend.ones([2], dtype="float32")),
242+
np.sinh(0.5 * tc.backend.ones([2])),
243+
atol=1e-5,
244+
)
245+
246+
182247
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
183248
def test_arg_cmp(backend):
184249
np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)

0 commit comments

Comments
 (0)
Please sign in to comment.