1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
/// Segment tree over an int array supporting range add, range multiply and
/// range-sum queries, all performed modulo a fixed modulus.
///
/// Indices are 0-based and every range [l, r] is inclusive. Every stored sum
/// and lazy tag is kept in [0, mod); products are formed in 64-bit arithmetic
/// so any positive int modulus (e.g. 1e9+7) is safe from signed overflow.
/// Precondition: the modulus passed to the constructor must be positive.
class SegmentTree {
private:
    struct Node { int sum; };
    /// Pending affine transform x -> x * mul + add, not yet pushed to children.
    struct LazyNode { int add, mul; };

    std::vector<Node> tree;
    std::vector<LazyNode> lazy;
    int n;
    int mod;

    /// Maps any integer into [0, mod); C++ `%` keeps the sign of the dividend,
    /// so negative inputs need the extra `+ mod`.
    int normalize(long long x) const {
        long long r = x % mod;
        if (r < 0) r += mod;
        return static_cast<int>(r);
    }

    /// Combines two child sums (64-bit add: two values near INT_MAX overflow int).
    Node merge(const Node& l, const Node& r) const {
        return Node{static_cast<int>((static_cast<long long>(l.sum) + r.sum) % mod)};
    }

    /// Applies x -> x * mul + add to all `len` elements covered by node u:
    /// rescales its sum and composes the transform into its lazy tag
    /// (old tag (m, a) followed by (mul, add) is (m*mul, a*mul + add)).
    /// mul and add must already be reduced into [0, mod).
    void apply(int u, int len, long long mul, long long add) {
        const long long m = mod;
        tree[u].sum = static_cast<int>((tree[u].sum * mul % m + add * (len % m) % m) % m);
        lazy[u].mul = static_cast<int>(lazy[u].mul * mul % m);
        lazy[u].add = static_cast<int>((lazy[u].add * mul + add) % m);
    }

    /// Builds node u covering nums[s..e] and resets its lazy tag to identity.
    void build(const std::vector<int>& nums, int u, int s, int e) {
        lazy[u] = LazyNode{0, 1};
        if (s == e) {
            tree[u] = Node{normalize(nums[s])};
            return;
        }
        const int mid = s + (e - s) / 2;
        build(nums, u * 2 + 1, s, mid);
        build(nums, u * 2 + 2, mid + 1, e);
        tree[u] = merge(tree[u * 2 + 1], tree[u * 2 + 2]);
    }

    /// Pushes node u's pending transform to its two children ([s, mid] and
    /// [mid+1, e]) and resets u's tag to the identity.
    void pushdown(int u, int s, int e, int mid) {
        if (lazy[u].mul == 1 && lazy[u].add == 0) return;  // nothing pending
        apply(u * 2 + 1, mid - s + 1, lazy[u].mul, lazy[u].add);
        apply(u * 2 + 2, e - mid, lazy[u].mul, lazy[u].add);
        lazy[u] = LazyNode{0, 1};
    }

    /// Applies x -> x * mul + add to positions [l, r] within node u's span [s, e].
    void updateRange(int u, int s, int e, int l, int r, int add, int mul) {
        if (l > e || r < s) return;
        if (l <= s && e <= r) {
            apply(u, e - s + 1, mul, add);
            return;
        }
        const int mid = s + (e - s) / 2;
        pushdown(u, s, e, mid);
        updateRange(u * 2 + 1, s, mid, l, r, add, mul);
        updateRange(u * 2 + 2, mid + 1, e, l, r, add, mul);
        tree[u] = merge(tree[u * 2 + 1], tree[u * 2 + 2]);
    }

    /// Returns the sum of positions [l, r] within node u's span [s, e], modulo mod.
    int queryRange(int u, int s, int e, int l, int r) {
        if (l > e || r < s) return 0;
        if (l <= s && e <= r) return tree[u].sum;  // already in [0, mod)
        const int mid = s + (e - s) / 2;
        pushdown(u, s, e, mid);
        const long long total = static_cast<long long>(queryRange(u * 2 + 1, s, mid, l, r))
                              + queryRange(u * 2 + 2, mid + 1, e, l, r);
        return static_cast<int>(total % mod);
    }

public:
    /// Builds the tree from `nums` (values may be negative; they are reduced
    /// into [0, p)). `p` must be positive. An empty `nums` is allowed: updates
    /// are then no-ops and queries return 0.
    SegmentTree(const std::vector<int>& nums, int p) {
        n = static_cast<int>(nums.size());
        mod = p;
        tree.resize(4 * nums.size());
        lazy.resize(4 * nums.size());
        if (n > 0) build(nums, 0, 0, n - 1);  // n == 0 would recurse forever
    }

    /// Adds `val` (any sign) to every element in [l, r].
    void updateadd(int l, int r, int val) {
        if (n == 0) return;
        updateRange(0, 0, n - 1, l, r, normalize(val), normalize(1));
    }

    /// Multiplies every element in [l, r] by `val` (any sign).
    void updatemul(int l, int r, int val) {
        if (n == 0) return;
        updateRange(0, 0, n - 1, l, r, 0, normalize(val));
    }

    /// Returns the sum of elements in [l, r], reduced into [0, mod).
    int query(int l, int r) {
        if (n == 0) return 0;
        return queryRange(0, 0, n - 1, l, r);
    }
};
|