平衡树与序列维护

平衡树除了用来对存在偏序关系的数据进行维护,还能用于对序列维护,相当于一个数组。阅读本文你需要先看完上一篇关于treap的文章。

序列维护

在之前的文章,我们介绍过使用树状数组,以及线段树来维护一个序列,可以做区间操作及区间求和,但它们都存在一个缺点,不能动态插入数据。那我们怎么样才能通过平衡树来维护序列呢,之前我们有一个size字段能快速找第k大(或树的中序遍历第k个元素),而旋转操作并不会改变元素之间的相对顺序,那么我们就通过它直接插入到第k个元素的前面,这样我们插入的时候就不再通过要插入的值本身的大小关系,而需要多加一个参数k决定插入的位置。当平衡树用于维护序列的时候,就不用考虑元素相等的问题了。这样我们把元素相等处理的代码删除并修改基本操作的代码就能得到第一个能维护序列的基本模板,以下模板使用Treap修改得来。

基本模板

以下模板我实现成真·模板,就几乎可以作为数组使用了

template <typename T>
struct treap_seq
{
    struct data
    {
        T v;
        data(T _v = 0) :v(_v) {}
        operator bool() const { return v != 0; }
        operator T() const { return v; }
    };
    struct node
    {
        int ch[2], sz;
        unsigned k;
        data d;
        node(int z = 1) :sz(z), k(rnd()) { d = ch[0] = ch[1] = 0; }
        void clone(const node& n) { d = n.d; }
        static unsigned rnd()
        {
            static unsigned r = 0x123;
            r = r * 69069 + 1;
            return r;
        }
    };
    vector<node> nodes;
    vector<int> recycle;
    int root;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    int new_node()
    {
        int id = (int)nodes.size();
        if (!recycle.empty())
        {
            id = recycle.back();
            recycle.pop_back();
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    void update(int tp)
    {
        node& n = nodes[tp];
        n.sz = 1 + nodes[n.ch[0]].sz + nodes[n.ch[1]].sz;
    }
    int insert(int& tp, int k, const data& d)
    {
        if (tp == 0)
        {
            tp = new_node();
            nodes[tp].d = d;
            return tp;
        }
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz + 1;
        int r = sz < k;
        int& s = n.ch[r];
        int ret = insert(s, k - sz * r, d);
        update(s);
        if (nodes[s].k < n.k) rotate(tp, r);
        else update(tp);
        return ret;
    }
    void rotate(int& tp, int r)
    {
        node& n = nodes[tp];
        int s = n.ch[r];
        n.ch[r] = nodes[s].ch[r ^ 1];
        nodes[s].ch[r ^ 1] = tp;
        update(tp); update(s);
        tp = s;
    }
    int erasefind(int& tp, int k) // return deleted
    {
        if (tp == 0) return 0;
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz + 1;
        if (sz == k)
        {
            remove(tp);
            return 1;
        }
        int r = sz < k;
        int& s = n.ch[r];
        int ret = erasefind(s, k - sz * r);
        if (ret)
        {
            update(tp);
            return 1;
        }
        return 0;
    }
    void remove(int& tp)
    {
        if (tp == 0) return;
        if (!nodes[tp].ch[0] || !nodes[tp].ch[1])
        {
            recycle.push_back(tp);
            tp = nodes[tp].ch[!nodes[tp].ch[0]];
        }
        else
        {
            int r = nodes[nodes[tp].ch[0]].k < nodes[nodes[tp].ch[1]].k;
            rotate(tp, r ^ 1);
            remove(nodes[tp].ch[r]);
            update(tp);
        }
    }
    int kth(int tp, int k) // return id
    {
        if (tp == 0) return tp;
        node n = nodes[tp];
        int sz = nodes[n.ch[0]].sz;
        if (sz >= k) return kth(n.ch[0], k);
        if (sz + 1 >= k) return tp;
        return kth(n.ch[1], k - sz - 1);
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        recycle.clear();
        nodes.reserve(size + 1);
        nodes.push_back(node(0));
        root = 0; reserve_size = size;
    }
    T get(int id) { return nodes[id].d; }
    int size() { return nodes[root].sz; }
    int insert(int k, data v) { if (size() >= reserve_size) nodes.reserve((reserve_size *= 2) + 1); return insert(root, k, v); }
    int erase(int k) { return erasefind(root, k); }
    int kth(int k) { return kth(root, k); } // return id
};

动态版本线段树

有了这个,我们就可以把它改成动态版本的线段树,就是每个根节点多维护一个sum字段,再加一个懒惰标记,就能实现区间加和区间求和。不过和线段树不同的是,线段树的子树的根只维护区间的结果,而平衡树的根本身就是一个元素,所以代码和线段树略有差别。以下我们实现一个支持区间加和区间求和且能动态增减数据的平衡树,实测与前面讲线段树文章的普通线段树模板,在解决同一问题的执行时间非常接近

template <typename T>
struct treap_seq
{
    struct data
    {
        T v;
        data(T _v = 0) :v(_v) {}
        data operator + (const data& d) const
        {
            data r;
            r.v = v + d.v;
            return r;
        }
        data operator * (int t) const
        {
            data r;
            r.v = v * t;
            return r;
        }
        operator bool() const { return v != 0; }
        operator T() const { return v; }
    };
    struct node
    {
        int ch[2], sz;
        unsigned k;
        data d;
        data sum;
        data lz_add;
        node(int z = 1) :sz(z), k(rnd()) { sum = lz_add = d = ch[0] = ch[1] = 0; }
        void clone(const node& n) { d = n.d; sum = n.sum; }
        static unsigned rnd()
        {
            static unsigned r = 0x123;
            r = r * 69069 + 1;
            return r;
        }
    };
    vector<node> nodes;
    vector<int> recycle;
    int root;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    int new_node()
    {
        int id = (int)nodes.size();
        if (!recycle.empty())
        {
            id = recycle.back();
            recycle.pop_back();
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    void _add(int tp, const data& d)
    {
        node& n = nodes[tp];
        n.lz_add = n.lz_add + d;
        n.d = n.d + d;
        n.sum = n.sum + d * n.sz;
    }
    void pushdown(int tp)
    {
        node& n = nodes[tp];
        if (n.lz_add)
        {
            _add(n.ch[0], n.lz_add);
            _add(n.ch[1], n.lz_add);
            n.lz_add = 0;
        }
    }
    void update(int tp)
    {
        node& n = nodes[tp];
        n.sz = 1 + nodes[n.ch[0]].sz + nodes[n.ch[1]].sz;
        n.sum = n.d + nodes[n.ch[0]].sum + nodes[n.ch[1]].sum;
    }
    int insert(int& tp, int k, const data& d)
    {
        if (tp == 0)
        {
            tp = new_node();
            nodes[tp].d = d;
            nodes[tp].sum = d;
            return tp;
        }
        node& n = nodes[tp];
        pushdown(tp);
        int sz = nodes[n.ch[0]].sz + 1;
        int r = sz < k;
        int& s = n.ch[r];
        int ret = insert(s, k - sz * r, d);
        update(s);
        if (nodes[s].k < n.k) rotate(tp, r);
        else update(tp);
        return ret;
    }
    void rotate(int& tp, int r)
    {
        node& n = nodes[tp];
        pushdown(tp);
        int s = n.ch[r];
        pushdown(s);
        n.ch[r] = nodes[s].ch[r ^ 1];
        nodes[s].ch[r ^ 1] = tp;
        update(tp); update(s);
        tp = s;
    }
    int erasefind(int& tp, int k) // return deleted
    {
        if (tp == 0) return 0;
        node& n = nodes[tp];
        pushdown(tp);
        int sz = nodes[n.ch[0]].sz + 1;
        if (sz == k)
        {
            remove(tp);
            return 1;
        }
        int r = sz < k;
        int& s = n.ch[r];
        int ret = erasefind(s, k - sz * r);
        if (ret)
        {
            update(tp);
            return 1;
        }
        return 0;
    }
    void remove(int& tp)
    {
        if (tp == 0) return;
        if (!nodes[tp].ch[0] || !nodes[tp].ch[1])
        {
            recycle.push_back(tp);
            tp = nodes[tp].ch[!nodes[tp].ch[0]];
        }
        else
        {
            int r = nodes[nodes[tp].ch[0]].k < nodes[nodes[tp].ch[1]].k;
            rotate(tp, r ^ 1);
            remove(nodes[tp].ch[r]);
            update(tp);
        }
    }
    int kth(int tp, int k) // return id
    {
        if (tp == 0) return tp;
        node n = nodes[tp];
        pushdown(tp);
        int sz = nodes[n.ch[0]].sz;
        if (sz >= k) return kth(n.ch[0], k);
        if (sz + 1 >= k) return tp;
        return kth(n.ch[1], k - sz - 1);
    }
    data getsum(int& tp, int l, int r)
    {
        if (tp == 0 || l >= r) return 0;
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz + 1;
        if (l <= 1 && r > n.sz)
        {
            return n.sum;
        }
        else
        {
            pushdown(tp);
            data sum = 0;
            if (l <= sz && sz < r)
            {
                sum = nodes[tp].d;
            }
            sum = sum + getsum(n.ch[0], l, min(sz, r));
            sum = sum + getsum(n.ch[1], max(1, l - sz), r - sz);
            return sum;
        }
    }
    void range_add(int& tp, int l, int r, const data& d)
    {
        if (tp == 0 || l >= r) return;
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz + 1;
        if (l <= 1 && r > n.sz)
        {
            _add(tp, d);
        }
        else
        {
            //pushdown(tp);
            if (l <= sz && sz < r)
            {
                nodes[tp].d = nodes[tp].d + d;
            }
            nodes[tp].sum = nodes[tp].sum + d * (r - l);
            range_add(n.ch[0], l, min(sz, r), d);
            range_add(n.ch[1], max(1, l - sz), r - sz, d);
        }
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        recycle.clear();
        nodes.reserve(size + 1);
        nodes.push_back(node(0));
        root = 0; reserve_size = size;
    }
    T get(int id) { return nodes[id].d; }
    int size() { return nodes[root].sz; }
    int insert(int k, data v) { reserve(); return insert(root, k, v); }
    int erase(int k) { return erasefind(root, k); }
    int kth(int k) { return kth(root, k); } // return id
    T getsum(int l, int r) { return getsum(root, l, r + 1); }
    void range_add(int l, int r, data v) { range_add(root, l, r + 1, v); }
};

Splay tree

使用以上这个Treap的自由度还是不够好,有些操作还是不容易做,例如区间翻转,或者区间删除。所以我们需要一个功能更为强大的树,因为那个随机数的限制,不能任意节点都能当树根,而没有那个随机数字段的树,就是伸展树Splay tree,区别主要是三个地方,一是需要父节点字段,维护关系时常数更大,二是旋转,使用双旋,三是splay操作,作用是把节点提升到树根。这个splay操作就是神器,能把很多区间操作写得非常简单,代码也确实是目前介绍的树里面代码最少的。不过伸展树的缺点是编码理解难度稍大。

和其它树的不同点是,为了保证区间操作代码简短,初始化的时候直接插入两个元素作为序列的一头一尾,于是实际操作区间是2到n+1,这个细节要注意,有了这两个元素可以减少很多特判操作。例如说,要找区间[l,r],那么只要让位置r+1的元素splay到根,然后再让位置l-1的元素splay到根的左边,那么l-1位的元素的右子树就是整个操作区间了,而为了让这个总是能做,所以才要预先加两个元素。这个技巧用在了几乎所有操作里面,包括插入,删除,所有的区间操作。splay操作的时间复杂度 $O(logn)$

以下是序列维护用的基本splaytree模板,要改成支持区间求和什么的就自己改吧。

template<typename T>
struct splaytree_seq
{
    struct data
    {
        T v;
        data(int _v = 0) :v(_v) {}
        operator T() const { return v; }
    };
    struct node
    {
        int ch[2], fa, sz;
        data d;
        node(int z = 1) :sz(z) { ch[0] = ch[1] = fa = 0; }
    };
    vector<node> nodes;
    int root;
    int recyc;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    inline int& ch(int tp, int r) { return nodes[tp].ch[r]; }
    inline int& fa(int tp) { return nodes[tp].fa; }
    inline int rch(int tp) { return ch(fa(tp), 1) == tp; }
    int new_node()
    {
        int id = (int)nodes.size();
        if (recyc)
        {
            id = recyc;
            if (ch(recyc, 0) && ch(recyc, 1))
                recyc = merge(ch(recyc, 0), ch(recyc, 1));
            else
                recyc = ch(recyc, 0) ? ch(recyc, 0) : ch(recyc, 1);
            fa(recyc) = 0;
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    void update(int tp)
    {
        node& n = nodes[tp];
        n.sz = 1 + nodes[n.ch[0]].sz + nodes[n.ch[1]].sz;
    }
    void add(int tp, const data& d)
    {
        node& n = nodes[tp];
        n.d = n.d + d;
    }
    void rotate(int s)
    {
        int f1 = fa(s), f2 = fa(f1);
        int d1 = rch(s), d2 = rch(f1);
        ch(f2, d2) = s; fa(s) = f2;
        fa(ch(s, d1 ^ 1)) = f1; ch(f1, d1) = ch(s, d1 ^ 1);
        fa(f1) = s; ch(s, d1 ^ 1) = f1;
        update(f1); update(s);
    }
    void splay(int tp, int goal = 0)
    {
        for (int f; (f = fa(tp)) != goal; rotate(tp))
            if (fa(f) != goal) rotate(rch(tp) == rch(f) ? f : tp);
        if (!goal) root = tp;
    }
    int find_m(int tp, int r)
    {
        int p = tp;
        while (ch(p, r)) p = ch(p, r);
        if (p != tp) splay(p, tp);
        return p;
    }
    int merge(int tl, int tr)
    {
        if (!tl) { fa(tr) = 0; return tr; }
        if (!tr) { fa(tl) = 0; return tl; }
        int p = find_m(tl, 1);
        ch(p, 1) = tr; fa(tr) = p;
        return tl;
    }
    void insert(int k, const data& d)
    {
        int tp = new_node();
        splay(kth(root, k + 1)); splay(kth(root, k), root);
        int c = ch(root, 0);
        nodes[c].ch[1] = tp;
        nodes[tp].fa = c;
        nodes[tp].d = d;
        update(c); update(root);
    }
    void remove(int& tp)
    {
        fa(tp) = 0;
        if (recyc == 0) recyc = tp;
        else recyc = merge(recyc, tp);
        tp = 0;
    }
    int kth(int tp, int k)
    {
        if (tp == 0) return tp;
        node& n = nodes[tp];
        //pushdown(tp);
        int sz = nodes[n.ch[0]].sz + 1;
        if (sz > k) return kth(n.ch[0], k);
        if (sz >= k) return tp;
        return kth(n.ch[1], k - sz);
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        nodes.reserve((size = max(size, 15)) + 1);
        nodes.push_back(node(0));
        nodes.push_back(node()); nodes.push_back(node());
        nodes[1].ch[0] = 2; nodes[1].sz = 2; nodes[2].fa = 1;
        root = 1; // be the bound
        recyc = 0; reserve_size = size + 1;
    }
    T get(int id) { return nodes[id].d; }
    int size() { return nodes[root].sz - 2; }
    int kth(int k) { int id = kth(root, k + 1); splay(id); return id; }
    void erase(int l, int r)
    {
        splay(kth(root, r + 2)); splay(kth(root, l), root);
        remove(ch(ch(root, 0), 1));
        update(ch(root, 0)); update(root);
    }
    void range_add(int l, int r, data v)
    {
        splay(kth(root, r + 2)); splay(kth(root, l), root);
        add(ch(ch(root, 0), 1), v);
        update(ch(root, 0)); update(root);
    }
};

以上已经直接写好了区间删除,对于区间反转等操作,可以模仿线段树加懒惰标记即可。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus