-
-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathtest_fit_shape.py
106 lines (91 loc) · 1.92 KB
/
test_fit_shape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Test model parameter shapes."""
import pytest
import stan
# Stan program under test. Each parameter is a container whose extent comes
# from one of the integer data variables K..S: a 1-D array (K), a 2-D array
# (L x M), a vector (N), a matrix (O x P) and an array of matrices (Q x R x S).
# Every element gets an independent std_normal() prior via explicit loops, so
# when a size is 0 the corresponding loop simply runs zero iterations. That
# zero-size case is what the tests below exercise.
program_code = """
data {
int K;
int L;
int M;
int N;
int O;
int P;
int Q;
int R;
int S;
}
parameters {
array[K] real a;
array[L, M] real B;
vector[N] c;
matrix[O, P] D;
array[Q] matrix[R, S] E;
}
model {
for (k in 1 : K) {
a[k] ~ std_normal();
}
for (l in 1 : L) {
for (m in 1 : M) {
B[l, m] ~ std_normal();
}
}
for (n in 1 : N) {
c[n] ~ std_normal();
}
for (o in 1 : O) {
for (p in 1 : P) {
D[o, p] ~ std_normal();
}
}
for (q in 1 : Q) {
for (r in 1 : R) {
for (s in 1 : S) {
E[q, r, s] ~ std_normal();
}
}
}
}
"""
# Sampler settings shared by every test. The test multiplies these two values
# to obtain the number of draws stacked along the last axis of each parameter.
num_samples = 100
num_chains = 3

# For each Stan parameter, the data variables that size it, in axis order.
dims = dict(
    a=("K",),
    B=("L", "M"),
    c=("N",),
    D=("O", "P"),
    E=("Q", "R", "S"),
)
def get_posterior(data):
    """Build the module's Stan program against ``data`` and return the result of ``stan.build``."""
    posterior = stan.build(program_code, data=data)
    return posterior
def get_fit(data):
    """Build the model for ``data`` and sample it using the module-level
    ``num_samples`` and ``num_chains`` settings; return the fit object."""
    return get_posterior(data).sample(num_chains=num_chains, num_samples=num_samples)
def get_data(zero_dims):
    """Return a fresh Stan data dict with selected size variables set to 0.

    Every size variable K..S starts at a small positive default. Each element
    of ``zero_dims`` names one variable to force to 0. A string is iterated per
    character, so ``"LM"`` zeroes both ``L`` and ``M``.

    Raises AssertionError if an element is not one of the known size variables.
    """
    sizes = dict(K=2, L=3, M=2, N=2, O=3, P=2, Q=4, R=3, S=2)
    # Materialize once so one-shot iterables are consumed a single time.
    requested = tuple(zero_dims)
    for name in requested:
        assert name in sizes
    sizes.update(dict.fromkeys(requested, 0))
    return sizes
@pytest.mark.parametrize(
    "zero_dims",
    ["K", "L", "M", "LM", "N", "O", "P", "OP", "Q", "R", "S", "QR", "QS", "RS", "QRS", "LMNOPQRS"],
)
def test_fit_empty_array_shape(zero_dims):
    """
    Check draw shapes when some container sizes are zero.

    The size variables named in ``zero_dims`` are set to 0 and the model is
    sampled. Every parameter listed in ``dims`` must then report a shape equal
    to its data sizes (in axis order) followed by the total number of draws,
    ``num_samples * num_chains``.
    """
    data = get_data(zero_dims)
    fit = get_fit(data)
    total_draws = num_samples * num_chains
    for name, dim_names in dims.items():
        expected = (*(data.get(d) for d in dim_names), total_draws)
        assert fit[name].shape == expected