线段树
很多人在初始接触线段树的时候,一看到别人写一大堆代码就直接弃坑了,其实不要被它的外表所欺骗,线段树其实是相当好写的树结构了,而且理解起来其实很简单。要学会这个,你不能光会抄模板就会区间修改和求个区间和,因为实际应用经常会使用它的变形,还是在于理解(理解后背板)。
数据结构
首先,回想一下heap的结构,它使用一个数组,同时使用下标本身来表达父子关系,这样的方式能节省大量指针所需要的内存空间,以下也使用这种表示方法来表示一棵线段树,也就是说,这里介绍的,属于狭义线段树。假设我们的数据是以下这样
下标 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
数据 | 1 | 0 | 5 | 2 | 3 | 4 | 0 | 1 |
构建线段树后结果如下
graph TD;
1,8:16-->1,4:8
1,8:16-->5,8:8
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2
5,8:8-->5,6:7
5,8:8-->7,8:1
5,6:7-->5,5:3
5,6:7-->6,6:4
7,8:1-->7,7:0
7,8:1-->8,8:1
冒号前面的两个数表示一条线段,冒号后表示的是数据,这个数据表示的是这个区间的和。如此一来,我们查询一个区间的和,可以很快地计算出来,例如求[1,6]
的和,那么需要拆分为[1,4]
与[5,6]
的和,分别是8和7,所以结果是8+7=15
,原理就是这样而已。
单点数据更新
单点更新时,可以参考树状数组,先更新子节点,然后向上找父节点更新即可,也可以递归实现,这不在本节讨论范围。不过如果你确实只需要单点修改,那么可以考虑ZKW线段树,ZKW线段树是先更新子节点,然后向上找父节点更新,由于少了很多递归,常数比递归的线段树要小。后文提供一个简易的模板作为参考。
区间数据更新
例如,我们希望对区间[3,5]
上的数都加上2,这时候需要引入懒惰标记,其实就是把操作记录在父节点上,有必要时再向下传递。像刚才的例子,都加上懒惰标记后
graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:7,0
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:7,0-->3,3:5,0
3,4:7,0-->4,4:2,0
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:3,0
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0
然后对区间[3,5]
上的数都加上2,那么把这个区间拆分为[3,4]
和[5,5]
,更新标记
graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:11,2
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:11,2-->3,3:5,0
3,4:11,2-->4,4:2,0
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:5,2
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0
也就是说,[3,3]
和[4,4]
都没有更新,更新在[3,4]
上了,那么接下来需要查询[3,3]
的话,就把标记向下传递一层,变成
graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:11,0
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:11,0-->3,3:7,2
3,4:11,0-->4,4:4,2
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:5,2
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0
这样,再获取区间[3,3]
的结果7,就是所需要的答案
基础模板
以下基础模板只支持区间求和,以及区间整体加上一个数的操作,和树状数组后面提供的模板实现了相同的功能
struct seg_tree_add
{
struct node
{
int sum;
int lz_add;
};
int sz;
vector<node> d; // 仿heap的形式保存线段树
inline int lson(int tp) { return tp * 2 + 1; }
inline int rson(int tp) { return tp * 2 + 2; }
// 当前tp节点对应的线段区间为[tl,tr],更新区间是[l,r]
void update_add(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum += (tr - tl + 1) * v;
d[tp].lz_add += v;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
d[tp].lz_add = 0;
}
// 更新左右儿子
if (l <= tmid) update_add(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_add(l, r, v, tmid + 1, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
int get_sum(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].sum;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
d[tp].lz_add = 0;
}
// 统计左右儿子
int sum = 0;
if (l <= tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
if (r > tmid) sum += get_sum(l, r, tmid + 1, tr, rson(tp));
return sum;
}
void init(int size) // 可操作下标范围为0~size-1,如需要从1开始那么要+1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz; // 扩展为满二叉树
d.resize(sz * 2);
}
void update_add(int l, int r, int v)
{
update_add(l, r, v, 0, sz - 1, 0);
}
int get_sum(int l, int r)
{
return get_sum(l, r, 0, sz - 1, 0);
}
};
用法,调用init初始化范围(注意下标从0到size-1,下标要从1开始的话要size+1,否则如果size正好是2的k次方时操作下标为size时会出问题),然后通过update_add
和get_sum
更新数据即可。
另外一点,这个模板实现没有使用左闭右开区间来写,如果改用左闭右开区间,并添加build实现,则得到如下实现(代码有少许简化且更对称更好读)
struct seg_tree_add
{
struct node
{
int sum;
int lz_add;
};
int sz;
vector<node> d;
inline int lson(int tp) { return tp * 2 + 1; }
inline int rson(int tp) { return tp * 2 + 2; }
void update_add(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum += (tr - tl) * v;
d[tp].lz_add += v;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
d[tp].lz_add = 0;
}
// 更新左右儿子
if (l < tmid) update_add(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_add(l, r, v, tmid, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
int get_sum(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].sum;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
d[tp].lz_add = 0;
}
// 统计左右儿子
int sum = 0;
if (l < tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
if (r > tmid) sum += get_sum(l, r, tmid, tr, rson(tp));
return sum;
}
void build(int a[], int alen, int tl, int tr, int tp)
{
if (tl + 1 == tr)
{
if (tl < alen)
d[tp].sum = a[tl];
else
d[tp].sum = 0;
return;
}
int tmid = (tl + tr) / 2;
build(a, alen, tl, tmid, lson(tp));
build(a, alen, tmid, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
d[tp].lz_add = 0;
}
void build(int a[], int alen)
{
build(a, alen, 0, sz, 0);
}
void init(int size) // 可操作下标范围为0~size-1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz;
d.resize(sz * 2);
}
void update_add(int l, int r, int v)
{
update_add(l, r + 1, v, 0, sz, 0);
}
int get_sum(int l, int r)
{
return get_sum(l, r + 1, 0, sz, 0);
}
};
进阶模板
如果你需要支持区间整体加上某个数,同时支持区间整体设置为指定数,那么就需要多重懒惰标记,模板可以改写如下(闭区间实现)
点击展开
struct seg_tree
{
static const int lz_mark = 0x80000000;
struct node
{
int sum;
int lz_set;
int lz_add;
};
int sz;
vector<node> d;
inline int lson(int tp) { return tp * 2 + 1; }
inline int rson(int tp) { return tp * 2 + 2; }
void update_add(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum += (tr - tl + 1) * v;
if (d[tp].lz_set != lz_mark) d[tp].lz_set += v;
else d[tp].lz_add += v;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
d[tp].lz_add = 0;
}
if (l <= tmid) update_add(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_add(l, r, v, tmid + 1, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
void update_set(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum = (tr - tl + 1) * v; //区间和
d[tp].lz_set = v;
d[tp].lz_add = 0;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
d[tp].lz_add = 0;
}
if (l <= tmid) update_set(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_set(l, r, v, tmid + 1, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
int get_sum(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].sum;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
d[tp].lz_add = 0;
}
int sum = 0;
if (l <= tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
if (r > tmid) sum += get_sum(l, r, tmid + 1, tr, rson(tp));
return sum;
}
void init(int size) // 可操作下标范围为0~size-1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz;
d.resize(sz * 2);
}
void update_add(int l, int r, int v)
{
update_add(l, r, v, 0, sz - 1, 0);
}
void update_set(int l, int r, int v)
{
update_set(l, r, v, 0, sz - 1, 0);
}
int get_sum(int l, int r)
{
return get_sum(l, r, 0, sz - 1, 0);
}
};
左闭右开区间实现(接口为闭区间)
点击展开
struct seg_tree
{
static const int lz_mark = 0x80000000;
struct node
{
int sum;
int lz_set;
int lz_add;
};
int sz;
vector<node> d;
inline int lson(int tp) { return tp * 2 + 1; }
inline int rson(int tp) { return tp * 2 + 2; }
void update_add(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum += (tr - tl) * v;
if (d[tp].lz_set != lz_mark) d[tp].lz_set += v;
else d[tp].lz_add += v;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
d[tp].lz_add = 0;
}
if (l < tmid) update_add(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_add(l, r, v, tmid, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
void update_set(int l, int r, int v, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
d[tp].sum = (tr - tl) * v; //区间和
d[tp].lz_set = v;
d[tp].lz_add = 0;
return;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
d[tp].lz_add = 0;
}
if (l < tmid) update_set(l, r, v, tl, tmid, lson(tp));
if (r > tmid) update_set(l, r, v, tmid, tr, rson(tp));
d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
}
ll get_sum(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].sum;
}
int tmid = (tl + tr) / 2;
// 下发lazy标志一层
if (d[tp].lz_set != lz_mark)
{
update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
d[tp].lz_set = lz_mark;
}
else if (d[tp].lz_add != 0)
{
update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
d[tp].lz_add = 0;
}
int sum = 0;
if (l < tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
if (r > tmid) sum += get_sum(l, r, tmid, tr, rson(tp));
return sum;
}
void init(int size) // 可操作下标范围为0~size-1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz;
d.resize(sz * 2);
}
void update_add(int l, int r, int v)
{
update_add(l, r + 1, v, 0, sz, 0);
}
void update_set(int l, int r, int v)
{
update_set(l, r + 1, v, 0, sz, 0);
}
ll get_sum(int l, int r)
{
return get_sum(l, r + 1, 0, sz, 0);
}
};
简易区间最值模板(就是简易得只有查询,如果要支持更新就自己加上)
点击展开
struct seg_tree
{
struct node
{
int max;
int min;
};
int sz;
vector<node> d;
inline int lson(int tp) { return tp * 2 + 1; }
inline int rson(int tp) { return tp * 2 + 2; }
int get_max(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].max;
}
int tmid = (tl + tr) / 2;
int ret = INT_MIN;
if (l < tmid) ret = max(ret, get_max(l, r, tl, tmid, lson(tp)));
if (r > tmid) ret = max(ret, get_max(l, r, tmid, tr, rson(tp)));
return ret;
}
int get_min(int l, int r, int tl, int tr, int tp)
{
if (l <= tl && tr <= r)
{
return d[tp].min;
}
int tmid = (tl + tr) / 2;
int ret = INT_MAX;
if (l < tmid) ret = min(ret, get_min(l, r, tl, tmid, lson(tp)));
if (r > tmid) ret = min(ret, get_min(l, r, tmid, tr, rson(tp)));
return ret;
}
void build(int a[], int alen, int tl, int tr, int tp)
{
if (tl + 1 == tr)
{
if (tl < alen)
{
d[tp].max = a[tl];
d[tp].min = a[tl];
}
else
{
d[tp].max = 0;
d[tp].min = 0;
}
return;
}
int tmid = (tl + tr) / 2;
build(a, alen, tl, tmid, lson(tp));
build(a, alen, tmid, tr, rson(tp));
d[tp].max = max(d[lson(tp)].max, d[rson(tp)].max);
d[tp].min = min(d[lson(tp)].min, d[rson(tp)].min);
}
void build(int a[], int alen)
{
build(a, alen, 0, sz, 0);
}
void init(int size) // 可操作下标范围为0~size-1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz;
d.resize(sz * 2);
}
int get_min(int l, int r)
{
return get_min(l, r + 1, 0, sz, 0);
}
int get_max(int l, int r)
{
return get_max(l, r + 1, 0, sz, 0);
}
};
ZKW线段树模板
这是单点修改求区间和的模板,求区间最值稍微改改就好了,如果需要区间修改,那么可以模仿树状数组的办法做差分,或做永久化标记,适应性比递归实现的线段树差一些,优点是常数小,以下实现比前面的大约快30%左右,代码更简单,就不额外解释了。
struct zkwseg_tree
{
struct node
{
int sum;
};
int sz;
vector<node> d;
void init(int size) // 可操作下标范围为0~size-1
{
sz = size;
while (sz & (sz - 1)) sz += sz&-sz;
d.resize(sz * 2);
}
void update_add(int p, int v)
{
int i = sz + p;
while (i)
{
d[i].sum += v;
i >>= 1;
}
}
int get_sum(int l, int r)
{
int sum = 0;
l += sz;
r += sz + 1;
for (; l < r; l>>=1, r>>=1)
{
if (l & 1)
{
sum += d[l++].sum;
}
if (r & 1)
{
sum += d[--r].sum;
}
}
return sum;
}
};
其它说明
以上模板为了解释简单,有的实现只有update
和get_sum
操作,并没有build
的部分,只使用update
即可完成build的操作,时间复杂度也是一样的。除了维护区间和,也可以维护区间最大最小值。