堆排序优化思路

堆排序其实是最不好优化的一个,在数据结构都确定的情况下,优化空间太小,除非优化数据结构本身,但那样就不叫做堆排序了。类似堆排序的还有smooth sortweakheap sort,有兴趣可以自己查找相关资料。

堆排序原理

堆排序其实是选择排序的优化变种,选择排序是把最大或最小的元素放到最边上,然后不断重复以上过程,堆排序也是如此,只不过堆排序通过构建数据结构,让查找最大或最小元素并放到最边上的速度比选择排序快得多。

首先我们先来介绍什么是堆。堆只是个缩写,全名是二叉堆,是一种完全二叉树,它的特点是二叉堆的父节点元素不小于子节点的元素,以下为二叉堆例图

graph TD;
30-->29
30-->28
29-->24
29-->25
28-->26
28-->22
24-->21

根节点是最大值的时候,就叫做最大堆,反之叫做最小堆。之所以使用完全二叉堆,是为了它能直接放到数组里,例如上图放数组里的结果是:

下标 0 1 2 3 4 5 6 7
30 29 28 24 25 26 22 21

可以看出,就是按层序遍历的结果。这样表示的好处是,可以通过下标的简单运算得到子节点的下标,我们通过简单找规律就能发现,下标k的子节点是2k+1和2k+2,所以我们能直接在数组里组织一个二叉堆。而且堆的根节点就是最值,找最值的时间复杂度是O(1)

那么假如堆结构已经组织好了,我们接下来看怎么排序,还是以刚刚的数组作为例子,下标0就是最大值,那么我们就把它拿出来,与最右边元素交换,得到

下标 0 1 2 3 4 5 6 7
21 29 28 24 25 26 22 30
指针 r a b

下标7的是已经排好的,后面不管它。这时候堆的性质破坏了,我们要去修正,指针r的子结点分别是a和b,在r,a,b中找出最大的元素a与r交换,得到

下标 0 1 2 3 4 5 6
29 21 28 24 25 26 22
指针 r a b

再继续操作

下标 0 1 2 3 4 5 6
29 25 28 24 21 26 22
指针 r

这时候,指针r已经没有子结点,这时候堆就修正好了。这个过程由于是从根向叶子的操作,叫做shiftdown,还有一种相反的过程叫做shiftup,就是从叶子向根的方向操作。以上就是一轮完整的操作,包括交换最值,修正堆两步。不断循环以上操作直到所有元素有序即可。

排序方法说完了,我们回头来说怎么构建堆。用shiftup法描述起来比较简单,一开始,只看下标0,那一个元素就肯定是堆。然后添加一个元素在最末,然后shiftup,修改新添加的元素,如此循环。伪代码如下

makeheap(arr, len)
    for i in (1, len)
        shiftup(arr, i)

而使用shiftdown的话,原理一样,只是换一个方向

makeheap(arr, len)
    for i in (len / 2, 1)
        shiftdown(arr, i)

以下为完整动态图演示

堆排序基本实现

以下shiftdown实现的函数名是max_heapify

void max_heapify(sort_element_t arr[], size_t index, size_t length)
{
    size_t child;
    sort_element_t temp = arr[index];
    for (; (child = 2 * index + 1) < length; index = child)
    {
        if (child + 1 < length && arr[child] < arr[child + 1])
            ++child;
        if (temp < arr[child])
            arr[index] = arr[child];
        else
            break;
    }
    arr[index] = temp;
}

void heap_sort(sort_element_t arr[], size_t length)
{
    if (length > 1)
    {
        for (size_t i = length / 2; i-- > 0; )
            max_heapify(arr, i, length);
        for (size_t i = length - 1; i > 0; --i)
        {
            sort_element_swap(&arr[0], &arr[i]);
            max_heapify(arr, 0, i);
        }
    }
}

void heap_sort(sort_element_t* beg, sort_element_t* end)
{
    heap_sort(beg, end - beg);
}

堆排序”优化”

我们来重新看原来的表

下标 0 1 2 3 4 5 6 7
30 29 28 24 25 26 22 21

这里下标k的子节点是2k+1和2k+2,运算起来麻烦,那我们如果把所有的下标加1,得到

下标 1 2 3 4 5 6 7 8
30 29 28 24 25 26 22 21

这时候,下标k的子节点是2k和2k+1,能减少一步加法运算

所以可以得到以下模板代码

template <class RandomAccessIterator, class Comp>
void max_heapify(RandomAccessIterator arr, size_t index, size_t last, Comp compare)
{
    typename std::iterator_traits<RandomAccessIterator>::value_type temp = arr[index];
    size_t child;
    for (; (child = index << 1) <= last; index = child)
    {
        if (child < last && compare(*(arr + child), *(arr + child + 1)))
            ++child;
        if (compare(temp, *(arr + child)))
            *(arr + index) = *(arr + child);
        else
            break;
    }
    *(arr + index) = temp;
}

template <class RandomAccessIterator, class Comp>
void heap_sort(RandomAccessIterator beg, RandomAccessIterator end, Comp compare)
{
    if (end - beg > 1)
    {
        size_t length = (size_t)(end - beg);
        RandomAccessIterator parr = beg - 1;
        for (size_t i = length / 2; i > 0; --i)
            max_heapify_1(parr, i, length, compare);
        for (size_t i = length - 1; i > 0; --i)
        {
            std::swap(*beg, *(beg + i));
            max_heapify_1(parr, 1, i, compare);
        }
    }
}

template <class RandomAccessIterator, class Comp>
void heap_sort(RandomAccessIterator beg, RandomAccessIterator end, Comp compare)
{
    heap_sort(beg, end, compare);
}

但是,还是那句但是,这些所谓的优化到了不同编译器的手上结果可能会和你想象的不一样。这个在VS上是有效果的,但不幸的是这份代码在GCC下跑得并没有比原来的快,不过在GCC上办法还是有的,就是把下标操作全部换成指针,见代码

template <class RandomAccessIterator, class Comp>
void max_heapify_p(RandomAccessIterator first, RandomAccessIterator target, RandomAccessIterator last, Comp compare)
{
    typename std::iterator_traits<RandomAccessIterator>::value_type temp = *target;
    --first;
    RandomAccessIterator son;
    for (; (son = target + (target - first)) <= last; target = son)
    {
        if (son < last && compare(*son, *(son + 1)))
            ++son;
        if (compare(temp, *son))
            *target = *son;
        else
            break;
    }
    *target = temp;
}

template <class RandomAccessIterator, class Comp>
void heap_sort_p(RandomAccessIterator beg, RandomAccessIterator end, Comp compare)
{
    if (end - beg > 1)
    {
        for (RandomAccessIterator i = beg + (end - beg) / 2; i >= beg; --i)
            max_heapify_p(beg, i, end - 1, compare);
        for (; --end > beg; )
        {
            std::swap(*beg, *end);
            max_heapify_p(beg, beg, end - 1, compare);
        }
    }
}

这份代码比GCC STL中的make_heap/sort_heap实现来得快一些。完整实现可以参见我的sort项目

堆排序其它注意问题

堆排序在特殊情况下是能以O(n)复杂度完成的,就是当几乎所有元素或所有元素都相等的时候,是可以很快完成的。但是VS和GCC的STL中的堆排序实现面对这种情形时却花费较多的时间,原因是对相等元素的处理上,我们拿最前面的代码作为例子,我们如果改成这样子:

void max_heapify(sort_element_t arr[], size_t index, size_t length)
{
    size_t child;
    sort_element_t temp = arr[index];
    for (; (child = 2 * index + 1) < length; index = child)
    {
        if (child + 1 < length && arr[child] < arr[child + 1])
            ++child;
        if (temp <= arr[child]) // 小于改小于等于
            arr[index] = arr[child];
        else
            break;
    }
    arr[index] = temp;
}

或者是

void max_heapify(sort_element_t arr[], size_t index, size_t length)
{
    size_t child;
    sort_element_t temp = arr[index];
    for (; (child = 2 * index + 1) < length; index = child)
    {
        if (child + 1 < length && arr[child] < arr[child + 1])
            ++child;
        if (arr[child] < temp) // 小于才跳出,与前一个等价
            break;
        arr[index] = arr[child];
    }
    arr[index] = temp;
}

这样导致的后果是不存在最优情况下O(n)的时间复杂度。最后再来回应开头的问题,什么很少听说smooth sortweakheap sort呢?单从时间复杂度其实看不出来不用它们的理由,但真正的问题是时间常数的问题,对普通的乱序数组,smooth sortweakheap sortheap sort慢,而且实现更复杂。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus