@@ -9,26 +9,27 @@ using namespace std;
9
9
* @Data Data type of actual values
10
10
* @Node Data type of what to store in segment tree nodes
11
11
*
12
- * Node must have a default constructor and a constructor Node(const Data &d)
12
+ * identity is Node indentity element on mergeFn i.e mergeFn(x, identity) = x
13
+ * Node(const Data& data) should be possible
13
14
*/
14
- template <typename Data, typename Node>
15
+ template <typename Data, typename Node, typename MergeFnType >
15
16
class SegTree
16
17
{
17
18
int n;
18
19
vector<Node> st;
19
- function<Node(Node, Node)> mergeFn;
20
+ MergeFnType mergeFn;
21
+ Node identity;
20
22
public:
21
23
SegTree () = default ;
22
- SegTree (const vector<Data> &data, const function<Node(Node, Node)> &fn)
24
+ SegTree (const vector<Data> &data, const Node& identity, MergeFnType &&fn)
25
+ : mergeFn(forward<MergeFnType>(fn)), identity(identity), n(data.size())
23
26
{
24
- n = data.size ();
25
- st = vector<Node>(4 * n);
26
- mergeFn = fn;
27
- build (data, 1 , 0 , n - 1 );
27
+ st = vector<Node>(2 * n - 1 );
28
+ build (data, 0 , 0 , n - 1 );
28
29
}
29
30
30
- Node query (int l, int r) { return queryImpl (l, r, 1 , 0 , n - 1 ); };
31
- void pointUpdate (int pos, const Data &val) { return pointUpdateImpl (pos, val, 1 , 0 , n - 1 ); }
31
+ Node query (int l, int r) { return queryImpl (l, r, 0 , 0 , n - 1 ); };
32
+ void pointUpdate (int pos, const Data &val) { return pointUpdateImpl (pos, val, 0 , 0 , n - 1 ); }
32
33
33
34
private:
34
35
void build (const vector<Data> &data, int v, int l, int r)
@@ -37,22 +38,22 @@ class SegTree
37
38
st[v] = data[l];
38
39
else
39
40
{
40
- int mid = (l + r) / 2 ;
41
- build (data, 2 * v , l, mid);
42
- build (data, 2 * v + 1 , mid + 1 , r);
43
- st[v] = mergeFn (st[v * 2 ], st[v * 2 + 1 ]);
41
+ auto [lv, rv, mid] = getChildren (v, l, r) ;
42
+ build (data, lv , l, mid);
43
+ build (data, rv , mid + 1 , r);
44
+ st[v] = mergeFn (st[lv ], st[rv ]);
44
45
}
45
46
}
46
47
47
48
Node queryImpl (int l, int r, int v, int nL, int nR)
48
49
{
49
- if (l > r) return Node{} ;
50
+ if (l > r) return identity ;
50
51
if (nL == l && nR == r)
51
52
return st[v];
52
- int mid = (nL + nR) / 2 ;
53
+ auto [lv, rv, mid] = getChildren (v, nL, nR);
53
54
return mergeFn (
54
- queryImpl (l, min (r, mid), 2 * v , nL, mid),
55
- queryImpl (max (l, mid + 1 ), r, 2 * v + 1 , mid + 1 , nR)
55
+ queryImpl (l, min (r, mid), lv , nL, mid),
56
+ queryImpl (max (l, mid + 1 ), r, rv , mid + 1 , nR)
56
57
);
57
58
}
58
59
@@ -63,11 +64,17 @@ class SegTree
63
64
st[v] = val;
64
65
return ;
65
66
}
66
- int mid = (nL + nR) / 2 ;
67
+ auto [lv, rv, mid] = getChildren (v, nL, nR);
67
68
if (pos <= mid)
68
- pointUpdateImpl (pos, val, 2 * v , nL, mid);
69
+ pointUpdateImpl (pos, val, lv , nL, mid);
69
70
else
70
- pointUpdateImpl (pos, val, 2 * v + 1 , mid + 1 , nR);
71
- st[v] = mergeFn (st[2 * v ], st[2 * v + 1 ]);
71
+ pointUpdateImpl (pos, val, rv , mid + 1 , nR);
72
+ st[v] = mergeFn (st[lv ], st[rv ]);
72
73
}
74
+
75
+ tuple<int , int , int > getChildren (int v, int nL, int nR) const
76
+ {
77
+ int mid = (nL + nR) / 2 ;
78
+ return make_tuple (v + 1 , v + 2 * (mid - nL + 1 ), mid);
79
+ }
73
80
};
0 commit comments