Skip to content

Instantly share code, notes, and snippets.

@leftrk
Created May 19, 2019 17:20
Show Gist options
  • Save leftrk/768595420dcb64a1d634c36324d43941 to your computer and use it in GitHub Desktop.
Save leftrk/768595420dcb64a1d634c36324d43941 to your computer and use it in GitHub Desktop.
#include "leetcode.h"
// Recursive sum segment tree stored in a flat array.
//
// Layout: the node at index `pos` covers the closed range [left, right] of the
// original array; its left child lives at `2 * pos + 1` and covers
// [left, mid], its right child lives at `2 * pos + 2` and covers
// [mid + 1, right], where mid = left + (right - left) / 2.
// Every method takes the node index and the range that node covers, so the
// caller starts each operation with pos = 0, left = 0, right = n - 1.
class Solution {
public:
    // Builds the tree for nums[left..right] into `value`, rooted at `pos`.
    //
    // value : flat node array holding the range sums; grown on demand if it
    //         is too small to hold index `pos`
    // nums  : original array
    // pos   : index in `value` of the node being built
    // left  : left end of the range covered by this node
    // right : right end of the range covered by this node
    //
    // Precondition: every index in [left, right] is a valid index of `nums`.
    // An empty range (left > right) builds nothing.
    void build_segment_tree(std::vector<int> &value,
                            std::vector<int> &nums,
                            int pos,
                            int left,
                            int right) {
        if (left > right)
            return; // empty range, e.g. caller passed nums.size() - 1 for an empty nums
        // The array layout needs far more than n slots (4 * n is the usual safe
        // bound); grow instead of writing past the end of an undersized buffer.
        if (static_cast<int>(value.size()) <= pos)
            value.resize(static_cast<std::vector<int>::size_type>(pos) + 1, 0);
        if (left == right) {
            // Leaf: copy the element of nums into the tree.
            value[pos] = nums[left];
            return;
        }
        const int mid = left + (right - left) / 2; // same as (left + right) / 2, without overflow
        build_segment_tree(value, nums, pos * 2 + 1, left, mid);      // left subtree
        build_segment_tree(value, nums, pos * 2 + 2, mid + 1, right); // right subtree
        // An inner node stores the sum of the ranges of its two children.
        value[pos] = value[pos * 2 + 1] + value[pos * 2 + 2];
    }

    // Prints the subtree rooted at `pos` in pre-order, one node per line as
    // "[left right][pos] : sum", prefixed with "---" once per depth level.
    // `layer` is the depth of this node (0 for the root).
    // An empty range (left > right) prints nothing.
    void print_segment_tree(std::vector<int> &value, int pos, int left, int right, int layer) {
        if (left > right)
            return; // no node exists for an empty range
        for (int i = 0; i < layer; ++i)
            printf("---");
        printf("[%d %d][%d] : %d\n", left, right, pos, value[pos]);
        if (left == right)
            return;
        const int mid = left + (right - left) / 2;
        print_segment_tree(value, pos * 2 + 1, left, mid, layer + 1);
        print_segment_tree(value, pos * 2 + 2, mid + 1, right, layer + 1);
    }

    // Returns the sum of the elements whose index lies in both [left, right]
    // (the range covered by node `pos`) and the query range [qleft, qright].
    // Returns 0 when the two ranges do not intersect or the node range is empty.
    int sum_range_segment_tree(std::vector<int> &value, int pos, int left, int right, int qleft, int qright) {
        if (left > right)
            return 0; // empty node range: nothing stored at `pos`
        if (qleft > right || qright < left)
            return 0; // disjoint from the query
        if (qleft <= left && qright >= right)
            return value[pos]; // fully covered: the stored sum is the answer
        const int mid = left + (right - left) / 2;
        return sum_range_segment_tree(value, pos * 2 + 1, left, mid, qleft, qright) +
               sum_range_segment_tree(value, pos * 2 + 2, mid + 1, right, qleft, qright);
    }

    // Sets the element at `index` to `new_value` and refreshes the sums stored
    // on the path from that leaf back up to node `pos`.
    // An `index` outside [left, right] leaves the tree unchanged.
    void update_segment_tree(std::vector<int> &value, int pos, int left, int right, int index, int new_value) {
        // Without this guard an out-of-range index never reaches a matching
        // leaf and the recursion does not terminate.
        if (index < left || index > right)
            return;
        if (left == right) {
            value[pos] = new_value;
            return;
        }
        const int mid = left + (right - left) / 2;
        if (index <= mid)
            update_segment_tree(value, pos * 2 + 1, left, mid, index, new_value);
        else
            update_segment_tree(value, pos * 2 + 2, mid + 1, right, index, new_value);
        value[pos] = value[pos * 2 + 1] + value[pos * 2 + 2];
    }
};
// Demo driver: builds a sum segment tree over the values 0..23, prints it,
// queries one range sum, overwrites one element, and prints the tree again.
int main() {
    Solution solution;
    vector<int> nums;
    for (int i = 0; i < 24; ++i)
        nums.push_back(i);
    // The tree is addressed with children at 2*pos+1 / 2*pos+2, so it needs
    // more slots than there are elements; 4 * n is the standard safe bound.
    // Allocating only n slots lets the build write past the end of the buffer.
    vector<int> value(4 * nums.size(), 0);
    // Index of the last element, computed once instead of narrowing
    // nums.size() - 1 from an unsigned type at every call.
    const int last = static_cast<int>(nums.size()) - 1;
    solution.build_segment_tree(value, nums, 0, 0, last);
    printf("segment_tree:\n");
    solution.print_segment_tree(value, 0, 0, last, 0);
    // Print the label from the same bounds that are queried so the two
    // cannot drift apart.
    const int qleft = 2;
    const int qright = 4;
    int sum_range = solution.sum_range_segment_tree(value, 0, 0, last, qleft, qright);
    printf("sum range [%d, %d] = %d\n", qleft, qright, sum_range);
    solution.update_segment_tree(value, 0, 0, last, 2, 10);
    printf("segment_tree:\n");
    solution.print_segment_tree(value, 0, 0, last, 0);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment