|
29 | 29 |
|
30 | 30 |
|
31 | 31 | Tensor = Any
|
32 |
| - |
| 32 | +PRNGKeyArray = Any # libjax.random.PRNGKeyArray |
| 33 | +RGenerator = Any # tf.random.Generator |
33 | 34 |
|
34 | 35 | libjax: Any
|
35 | 36 | jnp: Any
|
@@ -356,6 +357,77 @@ def tree_map( # pylint: disable=unused-variable
|
356 | 357 |
|
357 | 358 | return r
|
358 | 359 |
|
| 360 | + def set_random_state( # pylint: disable=unused-variable |
| 361 | + self: Any, seed: Optional[int] = None |
| 362 | + ) -> None: |
| 363 | + """ |
| 364 | + set random state attached in the backend |
| 365 | +
|
| 366 | + :param seed: int, defaults to None |
| 367 | + :type seed: Optional[int], optional |
| 368 | + """ |
| 369 | + raise NotImplementedError( |
| 370 | + "Backend '{}' has not implemented `set_random_state`.".format(self.name) |
| 371 | + ) |
| 372 | + |
| 373 | + def implicit_randn( # pylint: disable=unused-variable |
| 374 | + self: Any, |
| 375 | + shape: Union[int, Sequence[int]] = 1, |
| 376 | + mean: float = 0, |
| 377 | + stddev: float = 1, |
| 378 | + dtype: str = "32", |
| 379 | + ) -> Tensor: |
| 380 | + """ |
| 381 | + call random normal function with the random state management behind the scene |
| 382 | +
|
| 383 | + :param shape: [description], defaults to 1 |
| 384 | + :type shape: Union[int, Sequence[int]], optional |
| 385 | + :param mean: [description], defaults to 0 |
| 386 | + :type mean: float, optional |
| 387 | + :param stddev: [description], defaults to 1 |
| 388 | + :type stddev: float, optional |
| 389 | + :param dtype: [description], defaults to "32" |
| 390 | + :type dtype: str, optional |
| 391 | + :return: [description] |
| 392 | + :rtype: Tensor |
| 393 | + """ |
| 394 | + g = getattr(self, "g", None) |
| 395 | + if g is None: |
| 396 | + self.set_random_state() |
| 397 | + g = getattr(self, "g", None) |
| 398 | + r = self.stateful_randn(g, shape, mean, stddev, dtype) |
| 399 | + return r |
| 400 | + |
| 401 | + def stateful_randn( # pylint: disable=unused-variable |
| 402 | + self: Any, |
| 403 | + g: Any, |
| 404 | + shape: Union[int, Sequence[int]] = 1, |
| 405 | + mean: float = 0, |
| 406 | + stddev: float = 1, |
| 407 | + dtype: str = "32", |
| 408 | + ) -> Tensor: |
| 409 | + """ |
| 410 | + [summary] |
| 411 | +
|
| 412 | + :param self: [description] |
| 413 | + :type self: Any |
| 414 | + :param g: stateful register for each package |
| 415 | + :type g: Any |
| 416 | + :param shape: shape of output sampling tensor |
| 417 | + :type shape: Union[int, Sequence[int]] |
| 418 | + :param mean: [description], defaults to 0 |
| 419 | + :type mean: float, optional |
| 420 | + :param stddev: [description], defaults to 1 |
| 421 | + :type stddev: float, optional |
| 422 | + :param dtype: only real data type is supported, "32" or "64", defaults to "32" |
| 423 | + :type dtype: str, optional |
| 424 | + :return: [description] |
| 425 | + :rtype: Tensor |
| 426 | + """ |
| 427 | + raise NotImplementedError( |
| 428 | + "Backend '{}' has not implemented `stateful_randn`.".format(self.name) |
| 429 | + ) |
| 430 | + |
359 | 431 | def grad( # pylint: disable=unused-variable
|
360 | 432 | self: Any, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
|
361 | 433 | ) -> Callable[..., Any]:
|
@@ -580,6 +652,33 @@ def real(self, a: Tensor) -> Tensor:
|
580 | 652 | def cast(self, a: Tensor, dtype: str) -> Tensor:
|
581 | 653 | return a.astype(getattr(np, dtype))
|
582 | 654 |
|
| 655 | + def set_random_state(self, seed: Optional[int] = None) -> None: |
| 656 | + g = np.random.default_rng(seed) # None auto supported |
| 657 | + self.g = g |
| 658 | + |
| 659 | + def stateful_randn( |
| 660 | + self, |
| 661 | + g: np.random.Generator, |
| 662 | + shape: Union[int, Sequence[int]] = 1, |
| 663 | + mean: float = 0, |
| 664 | + stddev: float = 1, |
| 665 | + dtype: str = "32", |
| 666 | + ) -> Tensor: |
| 667 | + if isinstance(dtype, str): |
| 668 | + dtype = dtype[-2:] |
| 669 | + if isinstance(shape, int): |
| 670 | + shape = (shape,) |
| 671 | + r = g.normal(loc=mean, scale=stddev, size=shape) # type: ignore |
| 672 | + if dtype == "32": |
| 673 | + r = r.astype(np.float32) |
| 674 | + elif dtype == "64": |
| 675 | + r = r.astype(np.float64) |
| 676 | + elif not isinstance(dtype, str): |
| 677 | + r = r.astype(dtype) |
| 678 | + else: |
| 679 | + raise ValueError("unspported `dtype` %s" % dtype) |
| 680 | + return r |
| 681 | + |
583 | 682 | def grad(
|
584 | 683 | self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
|
585 | 684 | ) -> Callable[..., Any]:
|
@@ -735,6 +834,49 @@ def is_tensor(self, a: Any) -> bool:
|
735 | 834 | return True
|
736 | 835 | return False
|
737 | 836 |
|
| 837 | + def set_random_state(self, seed: Optional[int] = None) -> None: |
| 838 | + if seed is None: |
| 839 | + seed = np.random.randint(42) |
| 840 | + g = libjax.random.PRNGKey(seed) |
| 841 | + self.g = g |
| 842 | + |
| 843 | + def implicit_randn( |
| 844 | + self, |
| 845 | + shape: Union[int, Sequence[int]] = 1, |
| 846 | + mean: float = 0, |
| 847 | + stddev: float = 1, |
| 848 | + dtype: str = "32", |
| 849 | + ) -> Tensor: |
| 850 | + g = getattr(self, "g", None) |
| 851 | + if g is None: |
| 852 | + self.set_random_state() |
| 853 | + g = getattr(self, "g", None) |
| 854 | + key, subkey = libjax.random.split(g) |
| 855 | + r = self.stateful_randn(subkey, shape, mean, stddev, dtype) |
| 856 | + self.g = key |
| 857 | + return r |
| 858 | + |
| 859 | + def stateful_randn( |
| 860 | + self, |
| 861 | + g: PRNGKeyArray, |
| 862 | + shape: Union[int, Sequence[int]] = 1, |
| 863 | + mean: float = 0, |
| 864 | + stddev: float = 1, |
| 865 | + dtype: str = "32", |
| 866 | + ) -> Tensor: |
| 867 | + if isinstance(dtype, str): |
| 868 | + dtype = dtype[-2:] |
| 869 | + if isinstance(shape, int): |
| 870 | + shape = (shape,) |
| 871 | + if dtype == "32": |
| 872 | + dtyper = jnp.float32 |
| 873 | + elif dtype == "64": |
| 874 | + dtyper = jnp.float64 |
| 875 | + elif not isinstance(dtype, str): |
| 876 | + dtyper = dtype |
| 877 | + r = libjax.random.normal(g, shape=shape, dtype=dtyper) * stddev + mean |
| 878 | + return r |
| 879 | + |
738 | 880 | def grad(
|
739 | 881 | self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
|
740 | 882 | ) -> Any:
|
@@ -824,6 +966,38 @@ def _outer_product_tf(self: Any, tensor1: Tensor, tensor2: Tensor) -> Tensor:
|
824 | 966 | return tf.tensordot(tensor1, tensor2, 0)
|
825 | 967 |
|
826 | 968 |
|
def _random_choice_tf(
    g: RGenerator,
    a: Union[int, Sequence[int], Tensor],
    shape: Union[int, Sequence[int]],
    p: Optional[Union[Sequence[float], Tensor]] = None,
) -> Tensor:
    """
    Sample entries of ``a`` with replacement using a ``tf.random.Generator``.

    :param g: stateful tf random generator used for the uniform draws
    :param a: 1-D tensor of candidates, or an int n meaning ``tf.range(n)``
    :param shape: shape of the returned sample tensor
    :param p: weights aligned with ``a``; uniform when None (weights need not
        be normalized — the draw is scaled by their total below)
    :return: tensor of shape ``shape`` with entries gathered from ``a``
    """
    # only replace=True support, replace=False is not implemented
    # for stateless random module, tf has corresponding categorical function similar to choice
    # however, such utility is not implemented with ``tf.random.Generator``
    if isinstance(a, int):
        assert a > 0
        a = tf.range(a)
    assert len(a.shape) == 1
    if isinstance(shape, int):
        shape = (shape,)
    if p is None:
        dtype = tf.float32
        p = tf.ones_like(a)
        p = tf.cast(p, dtype=dtype)
        p /= tf.reduce_sum(p)
    else:
        if not isinstance(p, tf.Tensor):
            p = tf.constant(p)
        dtype = p.dtype
    # inverse-CDF sampling: draw u in (0, total] (1 - uniform[0,1) excludes 0),
    # binary-search it in the cumulative weights, then gather the candidates
    shape1 = reduce(mul, shape)
    p_cuml = tf.cumsum(p)
    r = p_cuml[-1] * (1 - g.uniform([shape1], dtype=dtype))
    ind = tf.searchsorted(p_cuml, r)
    res = tf.gather(a, ind)
    return tf.reshape(res, shape)
| 999 | + |
| 1000 | + |
827 | 1001 | # temporary hot replace until new version of tensorflow is released,
|
828 | 1002 | # see issue: https://github.com/google/TensorNetwork/issues/940
|
829 | 1003 | # avoid buggy tensordot2 in tensornetwork
|
@@ -910,6 +1084,35 @@ def real(self, a: Tensor) -> Tensor:
|
910 | 1084 | def cast(self, a: Tensor, dtype: str) -> Tensor:
|
911 | 1085 | return tf.cast(a, dtype=getattr(tf, dtype))
|
912 | 1086 |
|
| 1087 | + def set_random_state(self, seed: Optional[Union[int, RGenerator]] = None) -> None: |
| 1088 | + if seed is None: |
| 1089 | + g = tf.random.Generator.from_non_deterministic_state() |
| 1090 | + elif isinstance(seed, int): |
| 1091 | + g = tf.random.Generator.from_seed(seed) |
| 1092 | + else: |
| 1093 | + g = seed |
| 1094 | + self.g = g |
| 1095 | + |
| 1096 | + def stateful_randn( |
| 1097 | + self, |
| 1098 | + g: RGenerator, |
| 1099 | + shape: Union[int, Sequence[int]] = 1, |
| 1100 | + mean: float = 0, |
| 1101 | + stddev: float = 1, |
| 1102 | + dtype: str = "32", |
| 1103 | + ) -> Tensor: |
| 1104 | + if isinstance(dtype, str): |
| 1105 | + dtype = dtype[-2:] |
| 1106 | + if isinstance(shape, int): |
| 1107 | + shape = (shape,) |
| 1108 | + if dtype == "32": |
| 1109 | + dtyper = tf.float32 |
| 1110 | + elif dtype == "64": |
| 1111 | + dtyper = tf.float64 |
| 1112 | + elif not isinstance(dtype, str): |
| 1113 | + dtyper = dtype |
| 1114 | + return g.normal(shape, mean, stddev, dtype=dtyper) |
| 1115 | + |
913 | 1116 | def grad(
|
914 | 1117 | self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
|
915 | 1118 | ) -> Callable[..., Any]:
|
@@ -1090,6 +1293,8 @@ def wrapper(
|
1090 | 1293 | vvag = vectorized_value_and_grad
|
1091 | 1294 |
|
1092 | 1295 |
|
| 1296 | +# TODO(@refraction-ray): lack stateful random methods implementation for now |
| 1297 | +# To be added once pytorch backend is ready |
1093 | 1298 | class PyTorchBackend(pytorch_backend.PyTorchBackend): # type: ignore
|
1094 | 1299 | def __init__(self) -> None:
|
1095 | 1300 | super(PyTorchBackend, self).__init__()
|
|
0 commit comments