Treap与SBT

这里之所以把这两个放在一起讲,是因为它们实在是相似度很高(至少在竞赛领域),都需要求kth和指定元素的rank(Treap的话可有可无,但通常会需要)。不过如果你没有写过树,强烈建议你自己通过理解来写一遍。

BST

首先,Treap和SBT都属于BST的一种,BST就是二叉搜索树,它满足的特点是:

  • 二叉树
  • 没有两个节点的值相等
  • 任意子树的根节点的值都比左子树所有节点的值要大
  • 任意子树的根节点的值都比右子树所有节点的值要小
  • 任意子树均为二叉搜索树

如果我们实在需要支持多个相同值放在树里面,那么有两种情况,如果那些相同值是确实完全没有区别(例如int),那么只需要在每个节点多加一个字段记录这个值出现的次数就可以了,但如果这些值只有偏序关系,可能不是严格相等,存在其它非比较字段,那么我们就再在每个节点增加一个next域做成一个链表即可。

基本操作

  • 插入(insert):对比子树的根节点的值r与插入的值v,如果v与r相等,根节点重复数量+1,如果v小于r,插入到左子树,v大于r则插入到右子树
  • 查找(find):和插入相似,值相等时返回其id
  • 删除(erase):先查找,找到的时候,再查找它的后继(右子树的最小),用后继元素替换后再删除原后继

支持简单重复元素的BST模板

这个模板还添加了size域,用于求第k小元素和元素排名

struct bst
{
    struct data
    {
        int v;
        data(int _v = 0) :v(_v) {}
        bool operator==(const data& d) const
        {
            return v == d.v;
        }
        bool operator<(const data& d) const
        {
            return v < d.v;
        }
    };
    struct node
    {
        int ch[2], sz, dup;
        data d;
        node(int z = 1) :sz(z), dup(z) { ch[0] = ch[1] = 0; }
    };
    vector<node> nodes;
    vector<int> recycle;
    int root;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    int new_node()
    {
        int id = (int)nodes.size();
        if (!recycle.empty())
        {
            id = recycle.back();
            recycle.pop_back();
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    int insert(int& tp, const data& d)
    {
        if (tp == 0)
        {
            tp = new_node();
            nodes[tp].d = d;
            return tp;
        }
        node& n = nodes[tp];
        ++n.sz;
        if (d == n.d)
        {
            ++n.dup;
            return tp;
        }
        int r = d < n.d;
        int& s = n.ch[r ^ 1];
        int ret = insert(s, d);
        return ret;
    }
    int find(int tp, const data& d) // return id
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d) return tp;
        return find(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d);
    }
    int erasefind(int& tp, const data& d) // return deleted
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d)
        {
            --nodes[tp].sz;
            if (--nodes[tp].dup <= 0) remove(tp);
            return 1;
        }
        if (erasefind(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d))
        {
            --nodes[tp].sz;
            return 1;
        }
        return 0;
    }
    void remove(int& tp)
    {
        if (tp == 0) return;
        if (!nodes[tp].ch[0] || !nodes[tp].ch[1])
        {
            recycle.push_back(tp);
            tp = nodes[tp].ch[!nodes[tp].ch[0]];
        }
        else
        {
            int nxt = nodes[tp].ch[1];
            while (nodes[nxt].ch[0])
                nxt = nodes[nxt].ch[0];
            int dup = nodes[nxt].dup;
            nodes[tp].d = nodes[nxt].d;
            nodes[tp].dup = nodes[nxt].dup;
            recycle.push_back(nxt);
            int* tmp = &nodes[tp].ch[1];
            while (nodes[*tmp].ch[0])
            {
                nodes[*tmp].sz -= dup;
                tmp = &nodes[*tmp].ch[0];
            }
            *tmp = nodes[*tmp].ch[1];
        }
    }
    int kth(int tp, int k) // return id
    {
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz;
        if (sz >= k) return kth(n.ch[0], k);
        if (sz + n.dup >= k) return tp;
        return kth(n.ch[1], k - sz - n.dup);
    }
    int rank(int tp, const data& d, int dup)
    {
        if (tp == 0) return 1;
        node& n = nodes[tp];
        if (d == n.d) return nodes[n.ch[0]].sz + 1 + dup * n.dup;
        else if (d < n.d) return rank(n.ch[0], d, dup);
        return rank(n.ch[1], d, dup) + nodes[n.ch[0]].sz + n.dup;
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        recycle.clear();
        nodes.reserve(size + 1);
        nodes.push_back(node(0));
        root = 0; reserve_size = size;
    }
    int get(int id) { return nodes[id].d.v; }
    int size() { return nodes[root].sz; }
    int insert(data v) { reserve(); return insert(root, v); }
    int erase(data v) { return erasefind(root, v); }
    int find(data v) { return find(root, v); } // return id
    int kth(int k) { return kth(root, k); } // return id
    // upperbound when upper = 1
    int rank(data v, int upper = 0) { return rank(root, v, upper); }
};

一些说明

子节点用的节段是ch数组(child的缩写),不使用left和right的原因是为了节省代码,例如在insert函数里,通过计算d < n.d的值是0或1决定下一步是递归调用左还是右子节点的时候,就不需要分别针对left和right写代码,后面的find和erase同理。

rank函数在upper为0的时候,找到的是相同元素里面最小的排名,如果v不存在树里面,那么就是v假如要插入到树里的排名。upper为1的时候,找到的是大于v的最小的元素的排名,即v的后继。

kth函数的参数如果非法,会导致无限循环,如果你想避免那么你可以在函数里添加检查,例如加一句if (tp == 0) return 0;即可。

优化BST

单纯的BST最大的问题是,它最坏的情况是可能成为一条链表,例如你按从小到大插入到树里面的时候,缺乏让它缩减树高的机制,所以接下来要讲两个非常重要的操作,就是树的旋转

图A

graph TD;
4-->2
2-->1
2-->3
4-->6

图B

graph TD;
2-->4
4-->3
2-->1
4-->6

以上两图,从A到B叫做zig,把左儿子旋转到root的位置,也叫右旋,B到A叫做zag,把右儿子旋转到root的位置,也叫左旋,旋转代码也很简单

void rotate(int& tp, int r)
{
    node& n = nodes[tp];
    int s = n.ch[r];
    n.ch[r] = nodes[s].ch[r ^ 1];
    nodes[s].ch[r ^ 1] = tp;
    tp = s;
}

以上参数r如果是0,就是zig,r是1就是zag。有了旋转操作,我们就可以开始看自平衡树了。

Treap

Treap其实炒鸡简单,在BST的基础上多一个随机数生成的字段,这个字段用于决定树要怎么旋转。这个字段就是个优先级,父节点的优先级不大于两个子节点的优先级,这其实就是heap,所以,Treap就是树堆(Tree-heap)。维护Treap,我们只需要在insert的时候,检查是不是满足heap,如果不满足就旋转,相对BST只加了非常少的代码,也就加了rotate函数,rnd函数(直接用rand也行),insert加了维护,以及旋转时需要的update函数维护size字段,也就是说Treap是最经济实惠的平衡树。

操作:

  • 更新(update): 累加左右子树的size
  • 删除(erase): 找到要删除的元素后,对比两子树的根的优先级,把较的小旋转到原来要删除的元素的位置,使要删除的元素的深度+1,直到删除的元素没有两个子树为止

    struct treap
    {
    struct data
    {
        int v;
        data(int _v = 0) :v(_v) {}
        bool operator==(const data& d) const
        {
            return v == d.v;
        }
        bool operator<(const data& d) const
        {
            return v < d.v;
        }
    };
    struct node
    {
        int ch[2], sz, dup;
        unsigned k;
        data d;
        node(int z = 1) :sz(z), dup(z), k(rnd()) { ch[0] = ch[1] = 0; }
        static unsigned rnd()
        {
            static unsigned r = 0x123;
            r = r * 69069 + 1;
            return r;
        }
    };
    vector<node> nodes;
    vector<int> recycle;
    int root;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    int new_node()
    {
        int id = (int)nodes.size();
        if (!recycle.empty())
        {
            id = recycle.back();
            recycle.pop_back();
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    void update(int tp)
    {
        node& n = nodes[tp];
        n.sz = n.dup + nodes[n.ch[0]].sz + nodes[n.ch[1]].sz;
    }
    int insert(int& tp, const data& d)
    {
        if (tp == 0)
        {
            tp = new_node();
            nodes[tp].d = d;
            return tp;
        }
        node& n = nodes[tp];
        ++n.sz;
        if (d == n.d)
        {
            ++n.dup;
            return tp;
        }
        int r = d < n.d;
        int& s = n.ch[r ^ 1];
        int ret = insert(s, d);
        if (nodes[s].k < n.k) rotate(tp, r ^ 1), update(tp);
        return ret;
    }
    void rotate(int& tp, int r)
    {
        node& n = nodes[tp];
        int s = n.ch[r];
        n.ch[r] = nodes[s].ch[r ^ 1];
        nodes[s].ch[r ^ 1] = tp;
        update(tp);
        tp = s;
    }
    int find(int tp, const data& d) // return id
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d) return tp;
        return find(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d);
    }
    int erasefind(int& tp, const data& d) // return deleted
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d)
        {
            --nodes[tp].sz;
            if (--nodes[tp].dup <= 0) remove(tp);
            return 1;
        }
        if (erasefind(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d))
        {
            --nodes[tp].sz;
            return 1;
        }
        return 0;
    }
    void remove(int& tp)
    {
        if (tp == 0) return;
        if (!nodes[tp].ch[0] || !nodes[tp].ch[1])
        {
            recycle.push_back(tp);
            tp = nodes[tp].ch[!nodes[tp].ch[0]];
        }
        else
        {
            int r = nodes[nodes[tp].ch[0]].k < nodes[nodes[tp].ch[1]].k;
            rotate(tp, r ^ 1);
            remove(nodes[tp].ch[r]);
            update(tp);
        }
    }
    int kth(int tp, int k) // return id
    {
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz;
        if (sz >= k) return kth(n.ch[0], k);
        if (sz + n.dup >= k) return tp;
        return kth(n.ch[1], k - sz - n.dup);
    }
    int rank(int tp, const data& d, int dup)
    {
        if (tp == 0) return 1;
        node& n = nodes[tp];
        if (d == n.d) return nodes[n.ch[0]].sz + 1 + dup * n.dup;
        else if (d < n.d) return rank(n.ch[0], d, dup);
        return rank(n.ch[1], d, dup) + nodes[n.ch[0]].sz + n.dup;
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        recycle.clear();
        nodes.reserve(size + 1);
        nodes.push_back(node(0));
        root = 0; reserve_size = size;
    }
    int get(int id) { return nodes[id].d.v; }
    int size() { return nodes[root].sz; }
    int insert(data v) { reserve(); return insert(root, v); }
    int erase(data v) { return erasefind(root, v); }
    int find(data v) { return find(root, v); } // return id
    int kth(int k) { return kth(root, k); } // return id
    // upperbound when upper = 1
    int rank(data v, int upper = 0) { return rank(root, v, upper); }
    };
    

SBT

这个名字和BST特别像,但它全名是Size Balanced Tree,作者是CQF,它通过size字段来进行树平衡。它的关键操作叫做maintain,这个操作的平摊复杂度是 $O(1)$ ,这货具体解释可以看CQF的论文,SBT的时间常数比Treap更小一些,内存也更小,不过代码也稍长一些(就是因为maintain

struct sbt
{
    struct data
    {
        int v;
        data(int _v = 0) :v(_v) {}
        bool operator==(const data& d) const
        {
            return v == d.v;
        }
        bool operator<(const data& d) const
        {
            return v < d.v;
        }
    };
    struct node
    {
        int ch[2], sz, dup;
        data d;
        node(int z = 1) :sz(z), dup(z) { ch[0] = ch[1] = 0; }
    };
    vector<node> nodes;
    vector<int> recycle;
    int root;
    int reserve_size;
    void reserve()
    {
        if (size() >= reserve_size)
            nodes.reserve((reserve_size *= 2) + 1);
    }
    int new_node()
    {
        int id = (int)nodes.size();
        if (!recycle.empty())
        {
            id = recycle.back();
            recycle.pop_back();
            nodes[id] = node();
        }
        else nodes.push_back(node());
        return id;
    }
    void update(int tp)
    {
        node& n = nodes[tp];
        n.sz = n.dup + nodes[n.ch[0]].sz + nodes[n.ch[1]].sz;
    }
    int insert(int& tp, const data& d)
    {
        if (tp == 0)
        {
            tp = new_node();
            nodes[tp].d = d;
            return tp;
        }
        node& n = nodes[tp];
        ++n.sz;
        if (d == n.d)
        {
            ++n.dup;
            return tp;
        }
        int r = d < n.d;
        int& s = n.ch[r ^ 1];
        int ret = insert(s, d);
        maintain(tp, r ^ 1);
        return ret;
    }
    void rotate(int& tp, int r)
    {
        node& n = nodes[tp];
        int s = n.ch[r];
        n.ch[r] = nodes[s].ch[r ^ 1];
        nodes[s].ch[r ^ 1] = tp;
        update(tp); update(s);
        tp = s;
    }
    void maintain(int& tp, int s)
    {
        if (tp == 0) return;
        if (nodes[nodes[nodes[tp].ch[s]].ch[s]].sz > nodes[nodes[tp].ch[s ^ 1]].sz)
            rotate(tp, s);
        else if (nodes[nodes[nodes[tp].ch[s]].ch[s ^ 1]].sz > nodes[nodes[tp].ch[s ^ 1]].sz)
        {
            rotate(nodes[tp].ch[s], s ^ 1);
            rotate(tp, s);
        }
        else return;

        maintain(nodes[tp].ch[s], s);
        maintain(nodes[tp].ch[s ^ 1], s ^ 1);
        maintain(tp, s);
        maintain(tp, s ^ 1);
    }

    int find(int tp, const data& d) // return id
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d) return tp;
        return find(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d);
    }
    int erasefind(int& tp, const data& d) // return deleted
    {
        if (tp == 0) return 0;
        if (d == nodes[tp].d)
        {
            --nodes[tp].sz;
            if (--nodes[tp].dup <= 0) remove(tp);
            return 1;
        }
        if (erasefind(nodes[tp].ch[(d < nodes[tp].d) ^ 1], d))
        {
            --nodes[tp].sz;
            return 1;
        }
        return 0;
    }
    void remove(int& tp)
    {
        if (tp == 0) return;
        if (!nodes[tp].ch[0] || !nodes[tp].ch[1])
        {
            recycle.push_back(tp);
            tp = nodes[tp].ch[!nodes[tp].ch[0]];
        }
        else
        {
            int r = nodes[nodes[tp].ch[0]].sz >= nodes[nodes[tp].ch[1]].sz;
            rotate(tp, r ^ 1);
            remove(nodes[tp].ch[r]);
            update(tp);
        }
    }
    int kth(int tp, int k) // return id
    {
        node& n = nodes[tp];
        int sz = nodes[n.ch[0]].sz;
        if (sz >= k) return kth(n.ch[0], k);
        if (sz + n.dup >= k) return tp;
        return kth(n.ch[1], k - sz - n.dup);
    }
    int rank(int tp, const data& d, int dup)
    {
        if (tp == 0) return 1;
        node& n = nodes[tp];
        if (d == n.d) return nodes[n.ch[0]].sz + 1 + dup * n.dup;
        else if (d < n.d) return rank(n.ch[0], d, dup);
        return rank(n.ch[1], d, dup) + nodes[n.ch[0]].sz + n.dup;
    }
    // interface
    void init(int size)
    {
        nodes.clear();
        recycle.clear();
        nodes.reserve(size + 1);
        nodes.push_back(node(0));
        root = 0; reserve_size = size;
    }
    int get(int id) { return nodes[id].d.v; }
    int size() { return nodes[root].sz; }
    int insert(data v) { reserve(); return insert(root, v); }
    int erase(data v) { return erasefind(root, v); }
    int find(data v) { return find(root, v); } // return id
    int kth(int k) { return kth(root, k); } // return id
    // upperbound when upper = 1
    int rank(data v, int upper = 0) { return rank(root, v, upper); }
};

温馨提示

别看代码长,如果你还没自己实现过,硬着头皮写一次,你就懂了,核心其实只有那么点。以上没有实现插入相等元素后用链表串起来,实在有需要时自己加一下就好了。又或者,如果你不需要处理相同元素,那么代码也有不少的简化,特别是删除元素的地方。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus