Range Sum Query - Mutable - Segment tree in C++ and Java
Problem Statement:
Given an integer array `nums`, handle multiple queries of the following types:
1. Update the value of an element in `nums`.
2. Calculate the sum of the elements of `nums` between indices `left` and `right` inclusive, where `left <= right`.
Implement the `NumArray` class:
- `NumArray(int[] nums)` Initializes the object with the integer array `nums`.
- `void update(int index, int val)` Updates the value of `nums[index]` to be `val`.
- `int sumRange(int left, int right)` Returns the sum of the elements of `nums` between indices `left` and `right` inclusive (i.e. `nums[left] + nums[left + 1] + ... + nums[right]`).
Example 1:
Input
["NumArray", "sumRange", "update", "sumRange"]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
Output
[null, 9, null, 8]
Explanation
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // return 1 + 3 + 5 = 9
numArray.update(1, 2);   // nums = [1, 2, 5]
numArray.sumRange(0, 2); // return 1 + 2 + 5 = 8
Constraints:
- `1 <= nums.length <= 3 * 10^4`
- `-100 <= nums[i] <= 100`
- `0 <= index < nums.length`
- `-100 <= val <= 100`
- `0 <= left <= right < nums.length`
- At most `3 * 10^4` calls will be made to `update` and `sumRange`.
Solution:
SegmentTree in C++:
/// Segment tree over an integer array supporting point assignment and
/// inclusive range-sum queries.
///
/// Nodes are stored implicitly in `tree`: the root is index 0 and node `i`
/// has children `2*i+1` and `2*i+2`; 4*n slots are allocated for this layout.
/// Callers pass the root bounds explicitly: node 0 covers [0, n-1].
/// An empty input array produces an empty tree on which every query returns 0
/// and every update is a no-op.
class SegmentTree
{
    vector<int> tree;
public:
    /// Builds the tree from `nums`.
    SegmentTree(const vector<int> &nums)
    {
        const int n = static_cast<int>(nums.size());
        tree = vector<int>(4*n);
        build(nums, 0, 0, n-1);
    }

    /// Fills `node`, which covers A[lo..hi], and every node below it.
    void build(const vector<int> &A, int node, int lo, int hi)
    {
        // Empty range (only when A is empty): without this guard the
        // recursion would write tree[1] = A[0] past the end of both vectors.
        if (lo > hi) return;
        if (lo == hi) { tree[node] = A[lo]; return; }
        const int mid = lo + (hi-lo)/2;
        build(A, 2*node+1, lo, mid);
        build(A, 2*node+2, mid+1, hi);
        tree[node] = tree[2*node+1] + tree[2*node+2];
    }

    /// Returns the sum of the elements in [qlo, qhi] intersected with the
    /// range [lo, hi] covered by `node`; a disjoint or empty range yields 0.
    int query(int node, int lo, int hi, int qlo, int qhi)
    {
        if (lo > hi || qhi < lo || qlo > hi) return 0;
        if (qlo <= lo && qhi >= hi) return tree[node];
        const int mid = lo + (hi-lo)/2;
        return query(2*node+1, lo, mid, qlo, qhi) +
               query(2*node+2, mid+1, hi, qlo, qhi);
    }

    /// Assigns `value` to position `index` and refreshes the sums on the
    /// root-to-leaf path. `index` is expected to lie in [lo, hi]; not checked.
    void update(int node, int lo, int hi, int index, int value)
    {
        if (hi < lo) return;  // empty tree: nothing to update
        if (lo == hi) { tree[node] = value; return; }
        const int mid = lo + (hi-lo)/2;
        if (index <= mid) update(2*node+1, lo, mid, index, value);
        else update(2*node+2, mid+1, hi, index, value);
        tree[node] = tree[2*node+1] + tree[2*node+2];
    }
};
/// LeetCode-facing wrapper around SegmentTree; the tree's root (node 0)
/// covers positions [0, N-1] of the original array.
class NumArray {
    int N;                 // length of the input array
    SegmentTree segTree;   // segment tree over nums[0..N-1]
public:
    // Members are initialized in declaration order (N, then segTree)
    // regardless of how the list is written, so list them in that order
    // (avoids -Wreorder and a misleading reading of construction order).
    NumArray(vector<int>& nums): N(static_cast<int>(nums.size())), segTree(nums) {}

    /// Sets nums[index] to val.
    void update(int index, int val) {
        segTree.update(0, 0, N-1, index, val);
    }

    /// Returns nums[left] + ... + nums[right], both ends inclusive.
    int sumRange(int left, int right) {
        return segTree.query(0, 0, N-1, left, right);
    }
};
Segment tree in Java:
/**
 * Segment tree over an int array supporting point assignment and inclusive
 * range-sum queries.
 *
 * Nodes live in {@code tree} with the root at index 0 and the children of
 * node i at 2*i+1 and 2*i+2; 4*n slots are allocated for this layout.
 * Callers pass the root bounds explicitly: node 0 covers [0, n-1].
 * An empty input array produces an empty tree on which every query returns 0
 * and every update is a no-op.
 */
class SegmentTree
{
    int[] tree;

    /** Builds the tree from {@code A}. */
    SegmentTree(int[] A)
    {
        int n = A.length;
        tree = new int[4*n];
        build(A, 0, 0, n-1);
    }

    /** Fills {@code node}, which covers A[lo..hi], and every node below it. */
    void build(int[] A, int node, int lo, int hi)
    {
        // Empty range (only when A is empty): without this guard the
        // recursion reaches tree[1] = A[0] and throws ArrayIndexOutOfBounds.
        if (lo > hi) return;
        if (lo == hi) { tree[node] = A[lo]; return; }
        int mid = lo + (hi-lo)/2;
        build(A, 2*node+1, lo, mid);
        build(A, 2*node+2, mid+1, hi);
        tree[node] = tree[2*node+1] + tree[2*node+2];
    }

    /**
     * Returns the sum of the elements in [qlo, qhi] intersected with the
     * range [lo, hi] covered by {@code node}; a disjoint or empty range
     * yields 0.
     */
    int query(int node, int lo, int hi, int qlo, int qhi)
    {
        if (lo > hi || qhi < lo || qlo > hi) return 0;
        if (qlo <= lo && qhi >= hi) return tree[node];
        int mid = lo + (hi-lo)/2;
        return query(2*node+1, lo, mid, qlo, qhi) +
               query(2*node+2, mid+1, hi, qlo, qhi);
    }

    /**
     * Assigns {@code value} to position {@code index} and refreshes the sums
     * on the root-to-leaf path. {@code index} is expected to lie in
     * [lo, hi]; this is not checked.
     */
    void update(int node, int lo, int hi, int index, int value)
    {
        if (lo > hi) return;  // empty tree: nothing to update
        if (lo == hi) { tree[node] = value; return; }
        int mid = lo + (hi-lo)/2;
        if (index <= mid) update(2*node+1, lo, mid, index, value);
        else update(2*node+2, mid+1, hi, index, value);
        tree[node] = tree[2*node+1] + tree[2*node+2];
    }
}
/**
 * LeetCode-facing wrapper: forwards index-based calls to a {@link SegmentTree}
 * whose root covers positions [0, N-1].
 */
class NumArray {
    int N;                 // length of the wrapped array
    SegmentTree segTree;   // segment tree over nums[0..N-1]

    public NumArray(int[] nums) {
        this.N = nums.length;
        this.segTree = new SegmentTree(nums);
    }

    /** Sets nums[index] to val. */
    public void update(int index, int val) {
        final int last = N - 1;
        segTree.update(0, 0, last, index, val);
    }

    /** Returns nums[left] + ... + nums[right], both ends inclusive. */
    public int sumRange(int left, int right) {
        final int last = N - 1;
        return segTree.query(0, 0, last, left, right);
    }
}
We can also implement the segment tree without a separate class, using free functions that take the tree vector as a parameter. This is slightly less efficient, but it is still accepted (AC).
/// Fills `tree[node]`, which covers A[lo..hi], and every node below it.
/// Children of node i are 2*i+1 and 2*i+2; the root call is
/// build(A, tree, 0, 0, n-1) with `tree` pre-sized by the caller.
/// An empty range (lo > hi, i.e. an empty A) is a no-op.
void build(const vector<int>&A, vector<int>&tree, int node, int lo, int hi)
{
    // Without this guard an empty A recurses to tree[1] = A[0] (out of bounds).
    if (lo > hi) return;
    if (lo == hi) { tree[node] = A[lo]; return; }
    const int mid = lo + (hi-lo)/2;
    build(A, tree, 2*node+1, lo, mid);
    build(A, tree, 2*node+2, mid+1, hi);
    tree[node] = tree[2*node+1] + tree[2*node+2];
}
/// Sums the elements in [qlo, qhi] that fall inside the range [lo, hi]
/// covered by `tree[node]`; a query range disjoint from [lo, hi] contributes 0.
int query(vector<int>&tree, int node, int lo, int hi, int qlo, int qhi)
{
    const bool disjoint = qhi < lo || hi < qlo;
    if (disjoint)
        return 0;

    const bool covered = lo >= qlo && hi <= qhi;
    if (covered)
        return tree[node];

    // Partial overlap: split at the midpoint and combine both halves.
    const int split = lo + (hi - lo) / 2;
    const int leftSum = query(tree, 2 * node + 1, lo, split, qlo, qhi);
    const int rightSum = query(tree, 2 * node + 2, split + 1, hi, qlo, qhi);
    return leftSum + rightSum;
}
/// Assigns `value` to position `index` and refreshes the sums on the
/// root-to-leaf path; `tree[node]` covers [lo, hi]. `index` is expected to
/// lie in [lo, hi] (not checked). An empty range (hi < lo) is a no-op,
/// matching SegmentTree::update.
void supdate(vector<int>&tree, int node, int lo, int hi, int index, int value)
{
    // Without this guard an empty tree recurses to tree[1] = value (out of bounds).
    if (hi < lo) return;
    if (lo == hi) { tree[node] = value; return; }
    const int mid = lo + (hi-lo)/2;
    if (index <= mid) supdate(tree, 2*node+1, lo, mid, index, value);
    else supdate(tree, 2*node+2, mid+1, hi, index, value);
    tree[node] = tree[2*node+1] + tree[2*node+2];
}
/// LeetCode-facing wrapper built on the free functions build / supdate /
/// query. The root (node 0) of `tree` covers positions [0, n-1].
class NumArray {
    vector<int> tree;  // implicit segment tree, 4*n slots
    int n;             // length of the input array
public:
    NumArray(vector<int>& nums)
        : n(static_cast<int>(nums.size()))
    {
        tree.assign(4 * n, 0);
        build(nums, tree, 0, 0, n - 1);
    }

    /// Sets nums[index] to val.
    void update(int index, int val) {
        const int last = n - 1;
        supdate(tree, 0, 0, last, index, val);
    }

    /// Returns nums[left] + ... + nums[right], both ends inclusive.
    int sumRange(int left, int right) {
        const int last = n - 1;
        return query(tree, 0, 0, last, left, right);
    }
};