Skip to content

Commit d5754bf

Browse files
Prince GuptaPrince Gupta
Prince Gupta
authored and
Prince Gupta
committed
Update segtree
1 parent 7eae560 commit d5754bf

File tree

5 files changed

+94
-26
lines changed

5 files changed

+94
-26
lines changed

.vscode/settings.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"files.associations": {
3-
"iostream": "cpp"
4-
}
3+
"iostream": "cpp",
4+
"tuple": "cpp"
5+
}
56
}

DataStructures/Segment Tree/a.out

-17.8 KB
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
#include "segtree.h"
5+
6+
int32_t main()
7+
{
8+
vector<int> a = {1, 2, 3, 4, 5, 6};
9+
int n = a.size();
10+
11+
// SegTree<int, int> st(a, [](int x, int y) { return x + y; });
12+
auto mergeFn = [](int x, int y) { return max(x, y); };
13+
SegTree<int, int, decltype(mergeFn)&> st(a, INT_MIN, mergeFn);
14+
15+
// Test to check if seg tree works correctly.
16+
random_device dev;
17+
mt19937 rng(dev());
18+
19+
uniform_int_distribution<mt19937::result_type> dist(0,1000);
20+
21+
int numTests = 1000; // Test thousand times.
22+
int remTest = numTests;
23+
int valueUpdateCount = 0;
24+
while (remTest--)
25+
{
26+
int type = dist(rng) % 2;
27+
if (!type)
28+
{
29+
// Do a random update
30+
int idx = dist(rng) % n;
31+
int newVal = dist(rng) - 500;
32+
33+
st.pointUpdate(idx, newVal);
34+
a[idx] = newVal;
35+
valueUpdateCount++;
36+
}else
37+
{
38+
// Do a random range query
39+
int l = dist(rng) % n, r = dist(rng) % n;
40+
if (l > r) swap(l, r);
41+
42+
int stAns = st.query(l ,r);
43+
int actualAns = INT_MIN;
44+
for (int i = l; i <= r; i++)
45+
actualAns = max(actualAns, a[i]);
46+
47+
if (actualAns != stAns)
48+
{
49+
cout << "Test failed\n";
50+
cout << stAns << ": " << actualAns << "\n";
51+
return 0;
52+
}
53+
}
54+
}
55+
cout << "Test Passed\n";
56+
cout << "valueUpdateCount: " << valueUpdateCount << "\n";
57+
cout << "query count: " << numTests - valueUpdateCount << "\n";
58+
return 0;
59+
}

DataStructures/Segment Tree/seg_tree_sum_query.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@ int32_t main()
99
int n = a.size();
1010

1111
// SegTree<int, int> st(a, [](int x, int y) { return x + y; });
12-
SegTree<int, int> st(a, [](int x, int y) { return x + y; });
12+
auto mergeFn = [](int x, int y) { return x + y; };
13+
SegTree<int, int, decltype(mergeFn)&> st(a, 0, mergeFn);
1314

1415
// Test to check if seg tree works correctly.
1516
random_device dev;
1617
mt19937 rng(dev());
1718

18-
uniform_int_distribution<mt19937::result_type> dist(0,1000); // distribution in range [1, 6]
19+
uniform_int_distribution<mt19937::result_type> dist(0,1000);
1920

2021
int numTests = 1000; // Test thousand times.
2122
int remTest = numTests;

DataStructures/Segment Tree/segtree.h

+29-22
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,27 @@ using namespace std;
99
* @Data Data type of actual values
1010
* @Node Data type of what to store in segment tree nodes
1111
*
12-
* Node must have a default constructor and a constructor Node(const Data &d)
12+
* identity is Node indentity element on mergeFn i.e mergeFn(x, identity) = x
13+
* Node(const Data& data) should be possible
1314
*/
14-
template<typename Data, typename Node>
15+
template<typename Data, typename Node, typename MergeFnType>
1516
class SegTree
1617
{
1718
int n;
1819
vector<Node> st;
19-
function<Node(Node, Node)> mergeFn;
20+
MergeFnType mergeFn;
21+
Node identity;
2022
public:
2123
SegTree() = default;
22-
SegTree(const vector<Data> &data, const function<Node(Node, Node)> &fn)
24+
SegTree(const vector<Data> &data, const Node& identity, MergeFnType &&fn)
25+
: mergeFn(forward<MergeFnType>(fn)), identity(identity), n(data.size())
2326
{
24-
n = data.size();
25-
st = vector<Node>(4 * n);
26-
mergeFn = fn;
27-
build(data, 1, 0, n - 1);
27+
st = vector<Node>(2 * n - 1);
28+
build(data, 0, 0, n - 1);
2829
}
2930

30-
Node query(int l, int r) { return queryImpl(l, r, 1, 0, n - 1); };
31-
void pointUpdate(int pos, const Data &val) { return pointUpdateImpl(pos, val, 1, 0, n - 1); }
31+
Node query(int l, int r) { return queryImpl(l, r, 0, 0, n - 1); };
32+
void pointUpdate(int pos, const Data &val) { return pointUpdateImpl(pos, val, 0, 0, n - 1); }
3233

3334
private:
3435
void build(const vector<Data> &data, int v, int l, int r)
@@ -37,22 +38,22 @@ class SegTree
3738
st[v] = data[l];
3839
else
3940
{
40-
int mid = (l + r) / 2;
41-
build(data, 2 * v, l, mid);
42-
build(data, 2 * v + 1, mid + 1, r);
43-
st[v] = mergeFn(st[v * 2], st[v * 2 + 1]);
41+
auto [lv, rv, mid] = getChildren(v, l, r);
42+
build(data, lv, l, mid);
43+
build(data, rv, mid + 1, r);
44+
st[v] = mergeFn(st[lv], st[rv]);
4445
}
4546
}
4647

4748
Node queryImpl(int l, int r, int v, int nL, int nR)
4849
{
49-
if (l > r) return Node{};
50+
if (l > r) return identity;
5051
if (nL == l && nR == r)
5152
return st[v];
52-
int mid = (nL + nR) / 2;
53+
auto [lv, rv, mid] = getChildren(v, nL, nR);
5354
return mergeFn(
54-
queryImpl(l, min(r, mid), 2 * v, nL, mid),
55-
queryImpl(max(l, mid + 1), r, 2 * v + 1, mid + 1, nR)
55+
queryImpl(l, min(r, mid), lv, nL, mid),
56+
queryImpl(max(l, mid + 1), r, rv, mid + 1, nR)
5657
);
5758
}
5859

@@ -63,11 +64,17 @@ class SegTree
6364
st[v] = val;
6465
return;
6566
}
66-
int mid = (nL + nR) / 2;
67+
auto [lv, rv, mid] = getChildren(v, nL, nR);
6768
if (pos <= mid)
68-
pointUpdateImpl(pos, val, 2 * v, nL, mid);
69+
pointUpdateImpl(pos, val, lv, nL, mid);
6970
else
70-
pointUpdateImpl(pos, val, 2 * v + 1, mid + 1, nR);
71-
st[v] = mergeFn(st[2 * v], st[2 * v + 1]);
71+
pointUpdateImpl(pos, val, rv, mid + 1, nR);
72+
st[v] = mergeFn(st[lv], st[rv]);
7273
}
74+
75+
tuple<int, int, int> getChildren(int v, int nL, int nR) const
76+
{
77+
int mid = (nL + nR) / 2;
78+
return make_tuple(v + 1, v + 2 * (mid - nL + 1), mid);
79+
}
7380
};

0 commit comments

Comments
 (0)