线段树

很多人在初始接触线段树的时候,一看到别人写一大堆代码就直接弃坑了,其实不要被它的外表所欺骗,线段树其实是相当好写的树结构了,而且理解起来其实很简单。要学会这个,你不能光会抄模板就会区间修改和求个区间和,因为实际应用经常会使用它的变形,还是在于理解(理解后背板)

数据结构

首先,回想一下heap的结构,它使用一个数组,同时使用下标本身来表达父子关系,这样的方式能节省大量指针所需要的内存空间,以下也使用这种表示方法来表示一棵线段树,也就是说,这里介绍的,属于狭义线段树。假设我们的数据是以下这样

下标 1 2 3 4 5 6 7 8
数据 1 0 5 2 3 4 0 1

构建线段树后结果如下

graph TD;
1,8:16-->1,4:8
1,8:16-->5,8:8
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2
5,8:8-->5,6:7
5,8:8-->7,8:1
5,6:7-->5,5:3
5,6:7-->6,6:4
7,8:1-->7,7:0
7,8:1-->8,8:1

冒号前面的两个数表示一条线段,冒号后表示的是数据,这个数据表示的是这个区间的和。如此一来,我们查询一个区间的和,可以很快地计算出来,例如求[1,6]的和,那么需要拆分为[1,4][5,6]的和,分别是8和7,所以结果是8+7=15,原理就是这样而已。

单点数据更新

单点更新时,可以参考树状数组,先更新子节点,然后向上找父节点更新即可,也可以递归实现,这不在本节讨论范围。不过如果你确实只需要单点修改,那么可以考虑ZKW线段树,ZKW线段树是先更新子节点,然后向上找父节点更新,由于少了很多递归,常数比递归的线段树要小。后文提供一个简易的模板作为参考。

区间数据更新

例如,我们希望对区间[3,5]上的数都加上2,这时候需要引入懒惰标记,其实就是把操作记录在父节点上,有必要时再向下传递。像刚才的例子,都加上懒惰标记后

graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:7,0
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:7,0-->3,3:5,0
3,4:7,0-->4,4:2,0
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:3,0
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0

然后对区间[3,5]上的数都加上2,那么把这个区间拆分为[3,4][5,5],更新标记

graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:11,2
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:11,2-->3,3:5,0
3,4:11,2-->4,4:2,0
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:5,2
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0

也就是说,[3,3][4,4]都没有更新,更新在[3,4]上了,那么接下来需要查询[3,3]的话,就把标记向下传递一层,变成

graph TD;
1,8:16,0-->1,4:8,0
1,8:16,0-->5,8:8,0
1,4:8,0-->1,2:1,0
1,4:8,0-->3,4:11,0
1,2:1,0-->1,1:1,0
1,2:1,0-->2,2:0,0
3,4:11,0-->3,3:7,2
3,4:11,0-->4,4:4,2
5,8:8,0-->5,6:7,0
5,8:8,0-->7,8:1,0
5,6:7,0-->5,5:5,2
5,6:7,0-->6,6:4,0
7,8:1,0-->7,7:0,0
7,8:1,0-->8,8:1,0

这样,再获取区间[3,3]的结果7,就是所需要的答案

基础模板

以下基础模板只支持区间求和,以及区间整体加上一个数的操作,和树状数组后面提供的模板实现了相同的功能

struct seg_tree_add
{
    struct node
    {
        int sum;
        int lz_add;
    };
    int sz;
    vector<node> d; // 仿heap的形式保存线段树
    inline int lson(int tp) { return tp * 2 + 1; }
    inline int rson(int tp) { return tp * 2 + 2; }
    // 当前tp节点对应的线段区间为[tl,tr],更新区间是[l,r]
    void update_add(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum += (tr - tl + 1) * v;
            d[tp].lz_add += v;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        // 更新左右儿子
        if (l <= tmid) update_add(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_add(l, r, v, tmid + 1, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    int get_sum(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].sum;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        // 统计左右儿子
        int sum = 0;
        if (l <= tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
        if (r > tmid) sum += get_sum(l, r, tmid + 1, tr, rson(tp));
        return sum;
    }
    void init(int size) // 可操作下标范围为0~size-1,如需要从1开始那么要+1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz; // 扩展为满二叉树
        d.resize(sz * 2);
    }
    void update_add(int l, int r, int v)
    {
        update_add(l, r, v, 0, sz - 1, 0);
    }
    int get_sum(int l, int r)
    {
        return get_sum(l, r, 0, sz - 1, 0);
    }
};

用法,调用init初始化范围(注意下标从0到size-1,下标要从1开始的话要size+1,否则如果size正好是2的k次方时操作下标为size时会出问题),然后通过update_addget_sum更新数据即可。

另外一点,这个模板实现没有使用左闭右开区间来写,如果改用左闭右开区间,并添加build实现,则得到如下实现(代码有少许简化且更对称更好读)

struct seg_tree_add
{
    struct node
    {
        int sum;
        int lz_add;
    };
    int sz;
    vector<node> d;
    inline int lson(int tp) { return tp * 2 + 1; }
    inline int rson(int tp) { return tp * 2 + 2; }
    void update_add(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum += (tr - tl) * v;
            d[tp].lz_add += v;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        // 更新左右儿子
        if (l < tmid) update_add(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_add(l, r, v, tmid, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    int get_sum(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].sum;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        // 统计左右儿子
        int sum = 0;
        if (l < tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
        if (r > tmid) sum += get_sum(l, r, tmid, tr, rson(tp));
        return sum;
    }
    void build(int a[], int alen, int tl, int tr, int tp)
    {
        if (tl + 1 == tr)
        {
            if (tl < alen)
                d[tp].sum = a[tl];
            else
                d[tp].sum = 0;
            return;
        }
        int tmid = (tl + tr) / 2;
        build(a, alen, tl, tmid, lson(tp));
        build(a, alen, tmid, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
        d[tp].lz_add = 0;
    }
    void build(int a[], int alen)
    {
        build(a, alen, 0, sz, 0);
    }
    void init(int size) // 可操作下标范围为0~size-1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz;
        d.resize(sz * 2);
    }
    void update_add(int l, int r, int v)
    {
        update_add(l, r + 1, v, 0, sz, 0);
    }
    int get_sum(int l, int r)
    {
        return get_sum(l, r + 1, 0, sz, 0);
    }
};

进阶模板

如果你需要支持区间整体加上某个数,同时支持区间整体设置为指定数,那么就需要多重懒惰标记,模板可以改写如下(闭区间实现)

点击展开

struct seg_tree
{
    static const int lz_mark = 0x80000000;
    struct node
    {
        int sum;
        int lz_set;
        int lz_add;
    };
    int sz;
    vector<node> d;
    inline int lson(int tp) { return tp * 2 + 1; }
    inline int rson(int tp) { return tp * 2 + 2; }
    void update_add(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum += (tr - tl + 1) * v;
            if (d[tp].lz_set != lz_mark) d[tp].lz_set += v;
            else d[tp].lz_add += v;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        if (l <= tmid) update_add(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_add(l, r, v, tmid + 1, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    void update_set(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum = (tr - tl + 1) * v; //区间和
            d[tp].lz_set = v;
            d[tp].lz_add = 0;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        if (l <= tmid) update_set(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_set(l, r, v, tmid + 1, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    int get_sum(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].sum;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid + 1, tr, d[tp].lz_set, tmid + 1, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid + 1, tr, d[tp].lz_add, tmid + 1, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        int sum = 0;
        if (l <= tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
        if (r > tmid) sum += get_sum(l, r, tmid + 1, tr, rson(tp));
        return sum;
    }
    void init(int size) // 可操作下标范围为0~size-1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz;
        d.resize(sz * 2);
    }
    void update_add(int l, int r, int v)
    {
        update_add(l, r, v, 0, sz - 1, 0);
    }
    void update_set(int l, int r, int v)
    {
        update_set(l, r, v, 0, sz - 1, 0);
    }
    int get_sum(int l, int r)
    {
        return get_sum(l, r, 0, sz - 1, 0);
    }
};

左闭右开区间实现(接口为闭区间)

点击展开

struct seg_tree
{
    static const int lz_mark = 0x80000000;
    struct node
    {
        int sum;
        int lz_set;
        int lz_add;
    };
    int sz;
    vector<node> d;
    inline int lson(int tp) { return tp * 2 + 1; }
    inline int rson(int tp) { return tp * 2 + 2; }
    void update_add(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum += (tr - tl) * v;
            if (d[tp].lz_set != lz_mark) d[tp].lz_set += v;
            else d[tp].lz_add += v;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        if (l < tmid) update_add(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_add(l, r, v, tmid, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    void update_set(int l, int r, int v, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            d[tp].sum = (tr - tl) * v; //区间和
            d[tp].lz_set = v;
            d[tp].lz_add = 0;
            return;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        if (l < tmid) update_set(l, r, v, tl, tmid, lson(tp));
        if (r > tmid) update_set(l, r, v, tmid, tr, rson(tp));
        d[tp].sum = d[lson(tp)].sum + d[rson(tp)].sum;
    }
    ll get_sum(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].sum;
        }
        int tmid = (tl + tr) / 2;
        // 下发lazy标志一层
        if (d[tp].lz_set != lz_mark)
        {
            update_set(tl, tmid, d[tp].lz_set, tl, tmid, lson(tp));
            update_set(tmid, tr, d[tp].lz_set, tmid, tr, rson(tp));
            d[tp].lz_set = lz_mark;
        }
        else if (d[tp].lz_add != 0)
        {
            update_add(tl, tmid, d[tp].lz_add, tl, tmid, lson(tp));
            update_add(tmid, tr, d[tp].lz_add, tmid, tr, rson(tp));
            d[tp].lz_add = 0;
        }
        int sum = 0;
        if (l < tmid) sum += get_sum(l, r, tl, tmid, lson(tp));
        if (r > tmid) sum += get_sum(l, r, tmid, tr, rson(tp));
        return sum;
    }
    void init(int size) // 可操作下标范围为0~size-1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz;
        d.resize(sz * 2);
    }
    void update_add(int l, int r, int v)
    {
        update_add(l, r + 1, v, 0, sz, 0);
    }
    void update_set(int l, int r, int v)
    {
        update_set(l, r + 1, v, 0, sz, 0);
    }
    ll get_sum(int l, int r)
    {
        return get_sum(l, r + 1, 0, sz, 0);
    }
};

简易区间最值模板(就是简易得只有查询,如果要支持更新就自己加上)

点击展开

struct seg_tree
{
    struct node
    {
        int max;
        int min;
    };
    int sz;
    vector<node> d;
    inline int lson(int tp) { return tp * 2 + 1; }
    inline int rson(int tp) { return tp * 2 + 2; }
    int get_max(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].max;
        }
        int tmid = (tl + tr) / 2;
        int ret = INT_MIN;
        if (l < tmid) ret = max(ret, get_max(l, r, tl, tmid, lson(tp)));
        if (r > tmid) ret = max(ret, get_max(l, r, tmid, tr, rson(tp)));
        return ret;
    }
    int get_min(int l, int r, int tl, int tr, int tp)
    {
        if (l <= tl && tr <= r)
        {
            return d[tp].min;
        }
        int tmid = (tl + tr) / 2;
        int ret = INT_MAX;
        if (l < tmid) ret = min(ret, get_min(l, r, tl, tmid, lson(tp)));
        if (r > tmid) ret = min(ret, get_min(l, r, tmid, tr, rson(tp)));
        return ret;
    }
    void build(int a[], int alen, int tl, int tr, int tp)
    {
        if (tl + 1 == tr)
        {
            if (tl < alen)
            {
                d[tp].max = a[tl];
                d[tp].min = a[tl];
            }
            else
            {
                d[tp].max = 0;
                d[tp].min = 0;
            }
            return;
        }
        int tmid = (tl + tr) / 2;
        build(a, alen, tl, tmid, lson(tp));
        build(a, alen, tmid, tr, rson(tp));
        d[tp].max = max(d[lson(tp)].max, d[rson(tp)].max);
        d[tp].min = min(d[lson(tp)].min, d[rson(tp)].min);
    }
    void build(int a[], int alen)
    {
        build(a, alen, 0, sz, 0);
    }
    void init(int size) // 可操作下标范围为0~size-1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz;
        d.resize(sz * 2);
    }
    int get_min(int l, int r)
    {
        return get_min(l, r + 1, 0, sz, 0);
    }
    int get_max(int l, int r)
    {
        return get_max(l, r + 1, 0, sz, 0);
    }
};

ZKW线段树模板

这是单点修改求区间和的模板,求区间最值稍微改改就好了,如果需要区间修改,那么可以模仿树状数组的办法做差分,或做永久化标记,适应性比递归实现的线段树差一些,优点是常数小,以下实现比前面的大约快30%左右,代码更简单,就不额外解释了。

struct zkwseg_tree
{
    struct node
    {
        int sum;
    };
    int sz;
    vector<node> d;
    void init(int size) // 可操作下标范围为0~size-1
    {
        sz = size;
        while (sz & (sz - 1)) sz += sz&-sz;
        d.resize(sz * 2);
    }
    void update_add(int p, int v)
    {
        int i = sz + p;
        while (i)
        {
            d[i].sum += v;
            i >>= 1;
        }
    }
    int get_sum(int l, int r)
    {
        int sum = 0;
        l += sz;
        r += sz + 1;
        for (; l < r; l>>=1, r>>=1)
        {
            if (l & 1)
            {
                sum += d[l++].sum;
            }
            if (r & 1)
            {
                sum += d[--r].sum;
            }
        }
        return sum;
    }
};

其它说明

以上模板为了解释简单,有的实现只有updateget_sum操作,并没有build的部分,只使用update即可完成build的操作,时间复杂度也是一样的。除了维护区间和,也可以维护区间最大最小值。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus