1
+ #include < bits/stdc++.h>
2
+ using namespace std ;
3
+
4
+ constexpr int MAX_SIZE = 10 ;
5
+
6
+ int st[4 * MAX_SIZE];
7
+
8
+ int combine (int x, int y)
9
+ {
10
+ return x + y; // sum for sum queries
11
+ }
12
+
13
+ void buildST (int a[], int v, int l, int r)
14
+ {
15
+ if (l == r)
16
+ st[v] = a[l];
17
+ else
18
+ {
19
+ int mid = (l + r) / 2 ;
20
+ buildST (a, 2 * v, l, mid);
21
+ buildST (a, 2 * v + 1 , mid + 1 , r);
22
+ st[v] = combine (st[v * 2 ], st[v * 2 + 1 ]);
23
+ }
24
+ }
25
+
26
+ int sumQuery (int v, int nL, int nR, int l, int r)
27
+ {
28
+ if (l > r)
29
+ return 0 ;
30
+ if (nL == l && nR == r)
31
+ return st[v];
32
+ int mid = (nL + nR) / 2 ;
33
+ return combine (
34
+ sumQuery (v * 2 , nL, mid, l, min (r, mid)),
35
+ sumQuery (v * 2 + 1 , mid + 1 , nR, max (l, mid + 1 ), r)
36
+ );
37
+ }
38
+
39
+ void updateST (int v, int nL, int nR, int pos, int newVal)
40
+ {
41
+ if (nL == nR)
42
+ {
43
+ st[v] = newVal;
44
+ return ;
45
+ }
46
+ int mid = (nL + nR) / 2 ;
47
+ if (pos <= mid)
48
+ updateST (2 * v, nL, mid, pos, newVal);
49
+ else
50
+ updateST (2 * v + 1 , mid + 1 , nR, pos, newVal);
51
+ st[v] = combine (st[2 * v], st[2 * v + 1 ]);
52
+ }
53
+
54
+ int32_t main ()
55
+ {
56
+ int a[] = {1 , 2 , 3 , 4 , 5 , 6 };
57
+ int n = sizeof (a) / sizeof (int );
58
+
59
+ buildST (a, 1 , 0 , n - 1 );
60
+
61
+ // Test to check if seg tree works correctly.
62
+ random_device dev;
63
+ mt19937 rng (dev ());
64
+
65
+ uniform_int_distribution<mt19937::result_type> dist (0 ,1000 ); // distribution in range [1, 6]
66
+
67
+ int numTests = 1000 ; // Test thousand times.
68
+ int remTest = numTests;
69
+ int valueUpdateCount = 0 ;
70
+ while (remTest--)
71
+ {
72
+ int type = dist (rng) % 2 ;
73
+ if (!type)
74
+ {
75
+ // Do a random update
76
+ int idx = dist (rng) % n;
77
+ int newVal = dist (rng);
78
+
79
+ updateST (1 , 0 , n - 1 , idx, newVal);
80
+ a[idx] = newVal;
81
+ valueUpdateCount++;
82
+ }else
83
+ {
84
+ // Do a random range query
85
+ int l = dist (rng) % n, r = dist (rng) % n;
86
+ if (l > r) swap (l, r);
87
+
88
+ int stAns = sumQuery (1 , 0 , n - 1 , l, r);
89
+ int actualAns = 0 ;
90
+ for (int i = l; i <= r; i++)
91
+ actualAns += a[i];
92
+
93
+ if (actualAns != stAns)
94
+ {
95
+ cout << " Test failed\n " ;
96
+ return 0 ;
97
+ }
98
+ }
99
+ }
100
+ cout << " Test Passed\n " ;
101
+ cout << " valueUpdateCount: " << valueUpdateCount << " \n " ;
102
+ cout << " query count: " << numTests - valueUpdateCount << " \n " ;
103
+ return 0 ;
104
+ }
0 commit comments