可持久化线段树

可持久化权值线段树,wiki上指出引入者名字叫黃嘉泰,名字缩写正好是某位主席名字,所以又叫做主席树。而本篇先介绍可持久化线段树,阅读本篇前你需要先了解线段树

概念

所谓的可持久化,意思是你能得到所有的历史版本,为了达到这个效果,当然可以每次修改的时候,先整体复制再修改,结果自然就是会爆内存。而事实上,由于每次修改最多改一条链,而其它分支可以重用。我们先拿链表做例子,如果有个链表内容是 1->2->3->4->5 ,现在我们把3修改成6,得到 1->2->6->4->5 ,但是后面的元素没有改动,所以我们可以把后面的元素直接重叠在一起使用,如下图:

graph LR;
1-->2
2-->3
3-->4
4-->5
1'-->2'
2'-->6
6-->4

这样,完全可以当成两条不同的链表使用,同时节省空间。而可持久化线段树做法与这一样,就是没变的部分还使用原来节点,所以这个实现不能使用之前介绍的堆式储存,要和平衡树一样动态开节点。

数据结构

假设我们的数据是以下这样

下标 1 2 3 4
数据 1 0 5 2

构建线段树后结果如下

graph TD;
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2

冒号前面的两个数表示一条线段,冒号后表示的是数据,这个数据表示的是这个区间的和。

然后我们要把第3个元素从5改为1,构造第二棵线段树,首先复制一个root,包括儿子的指向也复制,得到

graph TD;
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2
1,4':8-->1,2:1
1,4':8-->3,4:7

然后,要更新的节点在右儿子那,所以把右儿子复制出来,得到

graph TD;
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2
1,4':8-->1,2:1
1,4':8-->3,4':7
3,4':7-->3,3:5
3,4':7-->4,4:2

最后,在区间$[3,4]$要更新的节点在左儿子那,所以把左儿子复制出来,同时由于这是最后的节点,再从底向上更新sum,得到

graph TD;
1,4:8-->1,2:1
1,4:8-->3,4:7
1,2:1-->1,1:1
1,2:1-->2,2:0
3,4:7-->3,3:5
3,4:7-->4,4:2
1,4':4-->1,2:1
1,4':4-->3,4':3
3,4':3-->L3,3':1
3,4':3-->4,4:2

上图中L3,3':13,4':3的左儿子。这样就是可持久化线段树的构造过程

静态区间范围查询

现在给出区间$[L,R]$和范围$[a,b]$,求数组中在区间$[L,R]$里有多少个元素在范围$[a,b]$里。这种查询普通的线段树并不好办,那可持久化线段树有什么方法来解呢,首先我们先构造一棵空线段树,然后对数组元素做离散化,按大小映射到$[0,n-1]$,然后对离散化后的数组,按下标次序,一个一个加入到可持久化线段树里,例如数字2,那么我们就要在线段树里对2号元素+1,所以这就是可持久化权值线段树,即主席树。如此这般加入后,由于我们是按下标次序加入的,所以我们非常容易地得到表示区间$[0,R]$的线段树,那么在范围$[a,b]$里的元素数量,正好就是$[a,b]$区间和。但如果要求的是区间$[L,R]$里有多少个元素在范围$[a,b]$里,那我们除了要求出区间$[0,R]$,还要求出区间$[0,L-1]$,然后两者的$[a,b]$区间和的差,就是我们所要的答案

基础模板

以下基础模板只支持区间求和,即求区间$[0,R]$里有多少个元素在范围$[a,b]$里

struct persistent_seg_tree
{
    struct data
    {
        int sum;
        data() :sum(0) {}
    };
    struct node
    {
        int l, r;
        data d;
        node() :l(0), r(0) {}
    };
    vector<node> nodes;
    vector<int> roots;
    int sz;

    void up(int id)
    {
        nodes[id].d.sum = nodes[nodes[id].l].d.sum + nodes[nodes[id].r].d.sum;
    }
    int newnode(int cpy)
    {
        int id = (int)nodes.size();
        node tmp = nodes[cpy];
        nodes.push_back(tmp);
        return id;
    }
    int add(int tp, int tl, int tr, int i, int v)
    {
        int id = newnode(tp);
        if (tl + 1 >= tr)
        {
            nodes[id].d.sum += v;
            return id;
        }
        int tmid = (tl + tr) / 2;
        if (i < tmid)
        {
            int nid = add(nodes[id].l, tl, tmid, i, v);
            nodes[id].l = nid;
        }
        else
        {
            int nid = add(nodes[id].r, tmid, tr, i, v);
            nodes[id].r = nid;
        }
        up(id);
        return id;
    }
    int getsum(int tp, int tl, int tr, int l, int r)
    {
        if (l <= tl && tr <= r)
        {
            return nodes[tp].d.sum;
        }
        int tmid = (tl + tr) / 2;
        int sum = 0;
        if (l < tmid)
        {
            sum += getsum(nodes[tp].l, tl, tmid, l, r);
        }
        if (r > tmid)
        {
            sum += getsum(nodes[tp].r, tmid, tr, l, r);
        }
        return sum;
    }
    // interface
    void init(int range, int root_size) // 数组大小[0, range),操作次数
    {
        sz = range;
        nodes.clear();
        roots.clear();
        nodes.reserve(root_size * (int)(log(sz * 2.0) / log(2.0) + 1.01));
        nodes.push_back(node());
        roots.reserve(root_size + 1);
        roots.push_back(0);
    }
    void add(int pos, int v)
    {
        int last = roots.back();
        roots.push_back(add(last, 0, sz, pos, v));
    }
    int getsum(int t, int l, int r)
    {
        if (t <= 0) return 0;
        if (r < l) return 0;
        if (t >= (int)roots.size()) t = (int)roots.size() - 1;
        return getsum(roots[t], 0, sz, l, r + 1);
    }
};

使用说明,先调用init,参数分别是离散化后的值域大小,和数组大小(对应的就是操作完后根的个数,所以名字是root_size),然后循环 add(pos, 1),最后查询时,调用getsum(R, a, b) - getsum(L - 1, a, b),LR就是区间,ab就是值域范围。

静态区间第k大

此问题解法较多,本篇主要介绍使用主席树的解法,同样也是先建立一棵可持久化权值线段树,对于查询区间为$[0,R]$的第k大,这个问题很简单,就是找出前缀和大于等于k的区间$[0,m]$所对应的最小的m值,所以只要对$[0,R]$所对应的线段树做查找,如果左子树的sum小于等于k,那么进入左子树查询k,否则进入右子树查询k-sum即可。但对于查询区间$[L,R]$,我们需要找出$[0,R]$和$[0,L-1]$这两棵线段树,它们的$[a,b]$区间和表示在$[L,R]$里有多少个数的值域在$[a,b]$之间,所以我们就同时对这两棵线段树做查找,设 $[0,R]$左子树的sum 减去 $[0,L-1]$左子树的sum 为S,那么如果S小于等于k,那么进入左子树查询k,否则进入右子树查询k-S,实现代码如下

int kth(int tpl, int tpr, int tl, int tr, int k)
{
    if (tl + 1 >= tr) return tl;
    int tmid = (tl + tr) / 2;
    int num = nodes[nodes[tpr].l].d.sum - nodes[nodes[tpl].l].d.sum;
    if (k <= num) return kth(nodes[tpl].l, nodes[tpr].l, tl, tmid, k);
    else return kth(nodes[tpl].r, nodes[tpr].r, tmid, tr, k - num);
}

区间第k大模板

struct persistent_seg_tree
{
    struct data
    {
        int sum;
        data() :sum(0) {}
    };
    struct node
    {
        int l, r;
        data d;
        node() :l(0), r(0) {}
    };
    vector<node> nodes;
    vector<int> roots;
    int sz;

    void up(int id)
    {
        nodes[id].d.sum = nodes[nodes[id].l].d.sum + nodes[nodes[id].r].d.sum;
    }
    int newnode(int cpy)
    {
        int id = (int)nodes.size();
        node tmp = nodes[cpy];
        nodes.push_back(tmp);
        return id;
    }
    int add(int tp, int tl, int tr, int i, int v)
    {
        int id = newnode(tp);
        if (tl + 1 >= tr)
        {
            nodes[id].d.sum += v;
            return id;
        }
        int tmid = (tl + tr) / 2;
        if (i < tmid)
        {
            int nid = add(nodes[id].l, tl, tmid, i, v);
            nodes[id].l = nid;
        }
        else
        {
            int nid = add(nodes[id].r, tmid, tr, i, v);
            nodes[id].r = nid;
        }
        up(id);
        return id;
    }
    int kth(int tpl, int tpr, int tl, int tr, int k)
    {
        if (tl + 1 >= tr) return tl;
        int tmid = (tl + tr) / 2;
        int num = nodes[nodes[tpr].l].d.sum - nodes[nodes[tpl].l].d.sum;
        if (k <= num) return kth(nodes[tpl].l, nodes[tpr].l, tl, tmid, k);
        else return kth(nodes[tpl].r, nodes[tpr].r, tmid, tr, k - num);
    }
    // interface
    void init(int range, int root_size) // 数组大小[0, range),操作次数
    {
        sz = range;
        nodes.clear();
        roots.clear();
        nodes.reserve(root_size * (int)(log(sz * 2.0) / log(2.0) + 1.01));
        nodes.push_back(node());
        roots.reserve(root_size + 1);
        roots.push_back(0);
    }
    void add(int i, int v)
    {
        int last = roots.back();
        roots.push_back(add(last, 0, sz, i, v));
    }
    int kth(int tpl, int tpr, int k)
    {
        return kth(roots[tpl - 1], roots[tpr], 0, sz, k);
    }
};

其它说明

其它的可持久化数据结构大同小异,如可持久化的trie,构造方法也是一样的

以上只介绍了静态区间的范围查询和第k大查询,还不支持动态修改并查询,这个会在之后再做介绍。

习题:静态区间范围查询hdu 4417,静态区间第k大POJ 2104

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus