Quick sort(快速排序)杂谈 2

上一篇我们介绍了四种不同的划分算法,现在我们来针对Hoare partition scheme来讲解一些优化和注意的问题。

最坏时间复杂度的优化

在前一篇的示例代码里面,只是最简单地选择的最开头或最后面的元素作为划分,这对于乱序的数据还好,对于有序的数据这么做,时间复杂度就直接变成$O(n^2)$了,那么怎么解决?第一个要解决的反而不是划分元素的选择上,而是改成intro sort,记录递归深度或类似的办法,到达限制条件的时候改而使用堆排序,这属于混合排序,先让排序的最坏时间复杂度降下来是第一要务。所以可以改写出以下代码:

sort_element_t* intro_sort_partition(sort_element_t * beg, sort_element_t * end)
{
    sort_element_t *l = beg, *r = end - 1;
    sort_element_t pivot = *r;
    while (1)
    {
        while (l < r && *l < pivot)
            ++l;
        while (l < r && !(*r < pivot))
            --r;
        if (l >= r)
            break;
        sort_element_swap(l++, r);
    }
    sort_element_swap(l, end - 1);
    return l;
}

void intro_sort_recursive(sort_element_t * beg, sort_element_t * end, int depth)
{
    if (end - beg <= 1)
        return;

    if (depth <= 0)
    {
        heap_sort(beg, end);
        return;
    }

    sort_element_t *p = intro_sort_partition(beg, end);

    intro_sort_recursive(beg, p, depth - 1);
    intro_sort_recursive(p + 1, end, depth - 1);
}

void intro_sort(sort_element_t * beg, sort_element_t * end)
{
    int depth = (int)(log(end - beg) * 2.5);
    intro_sort_recursive(beg, end, depth);
}

同时,这份代码使用的分割结果是

< = >=

至于为什么选择这种划分结果,与后面的优化有关系。

划分元素的选择

好,我们从这个代码开始重新谈一下分割元素的选取,网上常见有以下方案:

  • 固定选择最开头或最后一个
  • 固定选择中间的元素
  • 随机选择元素
  • 在开头、中间、末尾固定三个数里选择一个中位数
  • 在更多的数里面找中位数,同样是固定选择

首先,我们能显然地排除第一个,它对已排序或接近排序完成的数据效率特别低,所以,接下来我们的思路就是分别针对这些选择方案看看有哪些情况让它们特别低效。

那么按这个思路,我样再来考虑一下,如果要排序的数据重复元素特别的多,甚至于所有元素都相等呢?这时候你会发现不管哪种选取方案都没有用,都会陷入最坏情况,所以我们接下来在继续研究分割元素的选取之前,先考虑重复元素的问题。

重复元素的过滤

之所以分割方式使用上文那种,为的就是能更方便地过滤重复元素,大于等于分割元素的,都一定在分割位置的右边,所以intro_sort_recursive函数可以这样改

void intro_sort_recursive(sort_element_t * beg, sort_element_t * end, int depth)
{
    if (end - beg <= 1)
        return;

    if (depth <= 0)
    {
        heap_sort(beg, end);
        return;
    }

    sort_element_t *p = intro_sort_partition(beg, end);

    intro_sort_recursive(beg, p, depth - 1);
    for (++p; p < end && !(p[-1] < *p || *p < p[-1]); )
        ++p;
    intro_sort_recursive(p, end, depth - 1);
}

其中!(p[-1] < *p || *p < p[-1])就是判断p[-1] == *p

这样修改后,在所有元素相等的情况下,时间复杂度能达到$O(n)$

再次考虑划分元素的选择

这次我们把方案一个个来列举优缺点

  1. 固定选择中间的元素,速度快,对有序或逆序的数据分割效率最高,但在中间附近元素都是数据里最大的数据的时候,即遇到最坏情况,会快速退化成堆排序
  2. 在开头、中间、末尾固定三个数里选择一个中位数,这个是也gcc的STL实现,但能构造出让它遇到最坏情况的数据,比乱序数据慢4倍甚至更多
  3. 在更多的数里面找中位数,同样是固定选择,VS的STL实现就是在9个数里面取,但同样能构造出让它遇到最坏情况的数据,比乱序数据慢4倍甚至更多
  4. 随机选择元素,遇到最坏情况机率极低,但对于有序或逆序的数据时间比前三种都会慢一些

事实上,任何的固定选取方案理论上都能构造出固定的使其遇到最坏的情况,而我们的目标当然是让它减少遇到的机率,所以我推荐增加随机选择在里面,但单纯的随机选择虽然并不太优秀,但已足够实用,而且我们并不需要真的搞什么复杂的随机数,所以我们先来实现这个:

sort_element_t* intro_sort_partition(sort_element_t * beg, sort_element_t * end)
{
    static int s_rnd = 0x123456;
    sort_element_t *l = beg, *r = end - 1;
    sort_element_swap(r, l + (++s_rnd % (end - beg)));
    sort_element_t pivot = *r;
    while (1)
    {
        while (l < r && *l < pivot)
            ++l;
        while (l < r && !(*r < pivot))
            --r;
        if (l >= r)
            break;
        sort_element_swap(l++, r);
    }
    sort_element_swap(l, end - 1);
    return l;
}

只要一个静态的计数器作为随机数其实就足够了,把选中的元素交换到最右边即可。

那我们还是要考虑优化的话怎么弄呢?那我们结合着来,在三个数里取中位数,不过还要带上随机,代码如下

inline void make_mid_pivot(sort_element_t* l, sort_element_t* mid, sort_element_t* r)
{
    if (*r < *mid)
    {
        if (*mid < *l)
            return;
        sort_element_swap(mid, r);
    }
    if (*mid < *l)
    {
        sort_element_swap(mid, l);
        if (*r < *mid)
            sort_element_swap(mid, r);
    }
}

sort_element_t* intro_sort_partition(sort_element_t * beg, sort_element_t * end)
{
    static int s_rnd = 0x123456;
    sort_element_t *l = beg, *r = end - 1;
    int half = (end - beg) / 2;
    make_mid_pivot(l + s_rnd % half, l + half, r - s_rnd % half);
    ++s_rnd;
    sort_element_swap(r, l + half);
    sort_element_t pivot = *r;
    while (1)
    {
        while (l < r && *l < pivot)
            ++l;
        while (l < r && !(*r < pivot))
            --r;
        if (l >= r)
            break;
        sort_element_swap(l++, r);
    }
    sort_element_swap(l, end - 1);
    return l;
}

到这个阶段,这个实现已经能和VS版本的std::sort平起平坐了,接下我们就要超越它

小数据优化

当需要排序的数据长度较小的时候,快速排序其实并不快,我们改用插入排序,在16个元素或以下的时候使用。修改的那部分代码如下

void intro_sort_recursive(sort_element_t * beg, sort_element_t * end, int depth)
{
    if (end - beg <= 16)
    {
        insert_sort(beg, end);
        return;
    }

    if (depth <= 0)
    {
        heap_sort(beg, end);
        return;
    }

    sort_element_t *p = intro_sort_partition(beg, end);

    intro_sort_recursive(beg, p, depth - 1);
    for (++p; p < end && !(p[-1] < *p || *p < p[-1]); )
        ++p;
    intro_sort_recursive(p, end, depth - 1);
}

至此,已经能比VS的std::sort快了

运行测试

我们来看在VS2005下对4500000个int排序的测试结果

int 1 2 3 4 5 6 7 8 9 10 Avg
intro_sort 3 38 43 282 281 71 18 60 103 269 116
std::sort 2 53 57 327 331 72 40 85 114 317 139

再看mingw64-gcc9下使用-O3参数编译,对4500000个int排序的测试结果

int 1 2 3 4 5 6 7 8 9 10 Avg
intro_sort 4 41 46 274 281 67 18 80 105 266 118
std::sort 45 52 55 254 255 98 82 64 106 248 125

可以看到几乎在所有的测试数据里都比VS的实现快。不过尽管这个比VS的实现快,但比起gcc的实现还是稍差一点点,但已经足够好。

本篇就介绍到这里了,最后把优化后的完整代码给出,更多优化请看下一篇文章

点击展开

inline void make_mid_pivot(sort_element_t* l, sort_element_t* mid, sort_element_t* r)
{
    if (*r < *mid)
    {
        if (*mid < *l)
        {
            return;
        }
        sort_element_swap(mid, r);
    }
    if (*mid < *l)
    {
        sort_element_swap(mid, l);
        if (*r < *mid)
            sort_element_swap(mid, r);
    }
}

sort_element_t* intro_sort_partition(sort_element_t * beg, sort_element_t * end)
{
    static int s_rnd = 0x123456;
    sort_element_t *l = beg, *r = end - 1;
    int half = (end - beg) / 2;
    make_mid_pivot(l + s_rnd % half, l + half, r - s_rnd % half);
    ++s_rnd;
    sort_element_swap(r, l + half);
    sort_element_t pivot = *r;
    while (1)
    {
        while (l < r && *l < pivot)
            ++l;
        while (l < r && !(*r < pivot))
            --r;
        if (l >= r)
            break;
        sort_element_swap(l++, r);
    }
    sort_element_swap(l, end - 1);
    return l;
}

void intro_sort_recursive(sort_element_t * beg, sort_element_t * end, int depth)
{
    if (end - beg <= 16)
    {
        insert_sort(beg, end);
        return;
    }

    if (depth <= 0)
    {
        heap_sort(beg, end);
        return;
    }

    sort_element_t *p = intro_sort_partition(beg, end);

    intro_sort_recursive(beg, p, depth - 1);
    for (++p; p < end && !(p[-1] < *p || *p < p[-1]); )
        ++p;
    intro_sort_recursive(p, end, depth - 1);
}

void intro_sort(sort_element_t * beg, sort_element_t * end)
{
    int depth = (int)(log(end - beg) * 2.5);
    intro_sort_recursive(beg, end, depth);
}

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus