-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathtest_transforms_axioms.py
103 lines (90 loc) · 3.39 KB
/
test_transforms_axioms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Tests for group axioms.
https://proofwiki.org/wiki/Definition:Group_Axioms
"""
from typing import Tuple, Type
import numpy as np
import numpy.typing as onpt
from utils import (
assert_arrays_close,
assert_transforms_close,
general_group_test,
sample_transform,
)
import viser.transforms as vtf
@general_group_test
def test_closure(
    Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike
):
    """Verify the closure axiom for products of two sampled transforms.

    Two transforms are sampled and composed in both operand orders, once
    through the ``@`` operator and once through ``Group.multiply``. Each
    product is required to be close to its own ``normalize()`` output.

    Args:
        Group: Lie group class under test.
        batch_axes: Batch axes forwarded to ``sample_transform``.
        dtype: Data type forwarded to ``sample_transform``.
    """
    first = sample_transform(Group, batch_axes, dtype)
    second = sample_transform(Group, batch_axes, dtype)

    # Operator form first, then the explicit method; each in both operand
    # orders, so the sequence of checks is: a @ b, b @ a, multiply(a, b),
    # multiply(b, a).
    for compose in (lambda x, y: x @ y, Group.multiply):
        for lhs, rhs in ((first, second), (second, first)):
            product = compose(lhs, rhs)
            assert_transforms_close(product, product.normalize())
@general_group_test
def test_identity(
    Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike
):
    """Verify the identity axiom on a sampled transform.

    Composing the sampled transform with ``Group.identity(...)`` on either
    side must give back the transform, both at the group level (``@``) and
    at the matrix level (batched matrix product of ``as_matrix()`` outputs).

    Args:
        Group: Lie group class under test.
        batch_axes: Batch axes used for both the sample and the identity.
        dtype: Data type used for both the sample and the identity.
    """
    element = sample_transform(Group, batch_axes, dtype)
    neutral = Group.identity(batch_axes, dtype=dtype)

    # Group-level check: left identity, then right identity.
    assert_transforms_close(element, neutral @ element)
    assert_transforms_close(element, element @ neutral)

    # Matrix-level check with the same left/right ordering.
    # NOTE(review): assumes as_matrix() is a pure conversion, so its result
    # can be computed once and reused.
    element_matrix = element.as_matrix()
    neutral_matrix = neutral.as_matrix()
    for left, right in (
        (neutral_matrix, element_matrix),
        (element_matrix, neutral_matrix),
    ):
        assert_arrays_close(
            element_matrix,
            np.einsum("...ij,...jk->...ik", left, right),
        )
@general_group_test
def test_inverse(
    Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike
):
    """Verify the inverse axiom on a sampled transform.

    The sampled transform composed with its ``inverse()`` (on either side)
    must be close to ``Group.identity(...)``, using both ``@`` and
    ``Group.multiply``. The same property is then checked on the matrix
    forms against a stack of identity matrices of shape
    ``(*batch_axes, Group.matrix_dim, Group.matrix_dim)``.

    Args:
        Group: Lie group class under test.
        batch_axes: Batch axes used for the sample, the identity, and the
            expected matrix stack.
        dtype: Data type used for the sample, the identity, and ``np.eye``.
    """
    element = sample_transform(Group, batch_axes, dtype)
    neutral = Group.identity(batch_axes, dtype=dtype)
    # NOTE(review): assumes inverse() and as_matrix() are deterministic and
    # side-effect free, so computing them once is equivalent to recomputing.
    inverted = element.inverse()

    # Group-level checks: operator form, then explicit method form.
    assert_transforms_close(neutral, element @ inverted)
    assert_transforms_close(neutral, inverted @ element)
    assert_transforms_close(neutral, Group.multiply(element, inverted))
    assert_transforms_close(neutral, Group.multiply(inverted, element))

    # Matrix-level checks against a broadcast identity matrix.
    dim = Group.matrix_dim
    expected = np.broadcast_to(
        np.eye(dim, dtype=dtype),
        (*batch_axes, dim, dim),
    )
    element_matrix = element.as_matrix()
    inverse_matrix = inverted.as_matrix()
    for left, right in (
        (element_matrix, inverse_matrix),
        (inverse_matrix, element_matrix),
    ):
        assert_arrays_close(
            expected,
            np.einsum("...ij,...jk->...ik", left, right),
        )
@general_group_test
def test_associative(
    Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike
):
    """Check associative property.

    Three transforms are sampled and ``(a * b) * c`` is compared against
    ``a * (b * c)``. The comparison is made through the ``@`` operator and
    through ``Group.multiply``, matching how the closure and inverse tests
    in this module exercise both composition forms.

    Args:
        Group: Lie group class under test.
        batch_axes: Batch axes forwarded to ``sample_transform``.
        dtype: Data type forwarded to ``sample_transform``.
    """
    transform_a = sample_transform(Group, batch_axes, dtype)
    transform_b = sample_transform(Group, batch_axes, dtype)
    transform_c = sample_transform(Group, batch_axes, dtype)

    # Operator form.
    assert_transforms_close(
        (transform_a @ transform_b) @ transform_c,
        transform_a @ (transform_b @ transform_c),
    )

    # Explicit-method form, for parity with the other axiom tests.
    assert_transforms_close(
        Group.multiply(Group.multiply(transform_a, transform_b), transform_c),
        Group.multiply(transform_a, Group.multiply(transform_b, transform_c)),
    )