二分搜索杂谈
二分搜索虽然基本思想简单,但其细节却令人意外的抓狂(Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky)。这里我们来分析一下常见写法的坑。
两类不同的二分
二分有两类,一是找值,即判断某个值存不存在,二是找边界,前者比后者简单很多,以下是找值的典型写法(来自中文wiki百科)
int binary_search(const int arr[], int start, int end, int key) {
int ret = -1; // 未搜索到数据返回-1下标
int mid;
while (start <= end) {
mid = start + (end - start) / 2; //直接平均可能會溢位,所以用此算法
if (arr[mid] < key)
start = mid + 1;
else if (arr[mid] > key)
end = mid - 1;
else { // 最後檢測相等是因為多數搜尋狀況不是大於要不就小於
ret = mid;
break;
}
}
return ret; // 单一出口
}
对于找边界,有四种情况,如以下例子,查找数值5的四种边界
1 1 1 2 2 5 5 5 5 7 9
^ 1 小于的最右元素
^ 2 大于等于的最左元素
^ 3 小于等于的最右元素
^ 4 大于的最左元素
既然有四种边界,那就有四种写法。
常见基本写法
为了好理解,这里直接用int写,且使用闭区间写法,而不使用模板
int bin_search_1(int arr[], int len, int val) {
int l = 0, r = len - 1;
while (l < r) {
int m = l + ((r - l + 1) / 2);
if (arr[m] < val) {
l = m;
} else {
r = m - 1;
}
}
return l;
}
int bin_search_2(int arr[], int len, int val) {
int l = 0, r = len - 1;
while (l < r) {
int m = l + ((r - l) / 2);
if (arr[m] < val) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
int bin_search_3(int arr[], int len, int val) {
int l = 0, r = len - 1;
while (l < r) {
int m = l + ((r - l + 1) / 2);
if (arr[m] <= val) {
l = m;
} else {
r = m - 1;
}
}
return l;
}
int bin_search_4(int arr[], int len, int val) {
int l = 0, r = len - 1;
while (l < r) {
int m = l + ((r - l) / 2);
if (arr[m] <= val) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
int main() {
int a[] = {1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10};
int len = 16, val = 5;
cout << bin_search_1(a, len, val) << endl;
cout << bin_search_2(a, len, val) << endl;
cout << bin_search_3(a, len, val) << endl;
cout << bin_search_4(a, len, val) << endl;
return 0;
}
以上写法的坑点
坑点主要有以下4个,当这4个你搞明白了以后,你就可以肉眼debug出一个二分写法的问题了。
1. 区间表示
如果采用闭区间,那么循环的条件就是l < r
,当这个条件满足时,这个区间就表示了两个或以上的元素
同理,如果采用半开半闭区间,那么循环的条件就是l + 1 < r
,如果采用开区间,那么循环条件就是l + 2 < r
2. 比较
比较的方式取决于你的边界是左边界还是右边界,如果是左边界,那么边界左边的数均为小于,而右边的数是大于等于,那么你应该用 arr[m] < val
作为判断条件;同理地,如果是右边界,那么边界左边的数均为小于等于,而右边的数是大于,这时应该用arr[m] <= val
作为判断条件
3. 更新区间
在比较后,更新区间时,有的+1有的-1有的不需要加减,这是怎么决定要不要加呢?这个由你所查找的区间是否包含这个数决定。举个具体例子,在4 5 5 6
中,给出要查找的数值5
,要查找到小于5的最大那个数,即4,那么比较方式是arr[m] < val
。如果这个小于号满足,比如说arr[m]
是4,那么这个数是可能在查找区间内,所以不应该+1或-1,但如果arr[m]
是5,这个数不应该在查找区间内,那么就果断要+1或-1。在边界查找写法里,必然有一边是需要+1或-1,而另一边不需要,这个会影响下面要介绍的中间数选择。
4. 中间数的选择
中间数选择,即以上代码中的变量m,有的时候(r - l) / 2
,有的时候(r - l + 1) / 2
,这个写法取决于当r - l == 1
时,那么m要么等于l,要么等于r,为了这个循环能结束,如果下方代码是l = m
,那么m要取r,即要+1,如果下方代码是r = m
,那么m要取l,不需要+1。这个如果写错就会导致死循环的发生。
另外,还有一个更多人犯的错误,很多人会写m = (l + r) / 2
,当然,这个其实在不少情况下确实也不会怎么样,但是,如果我们要做的事情并不是在数组中查找,而是在一个区间里面,比如找一个方程的整数解,或者求3次方根的整数部分,这就会产生问题了,当 l和r是负数的时候,与它们是正数的时候,除以2的含义是不一样的,C语言中的除法实际上是向0取整,比如说,-7/2 == -3
,但是,在二分搜索时,我们必须要么向上取整,要么向下取整,但存在负数时,这样做除法可能会导致二分死循环。采用(r - l) / 2
可以避免这个问题,或者,你还可以使用位运算技巧(r - l) >> 1
。当然,为了减少中坑起见,写m = l + (r - l) / 2
最保险。
避免进坑的写法
以上写法坑太多(其实这些坑的有一部分的本质是你采用了闭区间),怎么样避免掉以上种种问题搞一个坑最少最不容易出错的写法呢?来看这个写法
int bin_search_r1(int arr[], int len, int val) {
int l = -1, r = len;
while (l + 1 < r) {
int m = l + ((r - l) / 2);
if (arr[m] < val) {
l = m;
} else {
r = m;
}
}
return l;
}
int bin_search_r2(int arr[], int len, int val) {
int l = -1, r = len;
while (l + 1 < r) {
int m = l + ((r - l) / 2);
if (arr[m] < val) {
l = m;
} else {
r = m;
}
}
return r;
}
int bin_search_r3(int arr[], int len, int val) {
int l = -1, r = len;
while (l + 1 < r) {
int m = l + ((r - l) / 2);
if (arr[m] <= val) {
l = m;
} else {
r = m;
}
}
return l;
}
int bin_search_r4(int arr[], int len, int val) {
int l = -1, r = len;
while (l + 1 < r) {
int m = l + ((r - l) / 2);
if (arr[m] <= val) {
l = m;
} else {
r = m;
}
}
return r;
}
以上写法看初值似乎是在采用开区间,其实不然,实际上是半开半闭区间表示(看循环条件),表示要查找的值在[l, r)
内或(l, r]
内,返回值如果是l那区间就是[l, r)
,反之就是(l, r]
。在循环结束前,区间[l, r]
内至少有3个元素,这样m肯定与l或r不相等,不存在死循环的可能性,也不需要关心向上取整还是向下取整的问题,这也是半开半闭区间表示的优点。另外,不论区间是3个还是4个元素,最终都会缩短为刚好2个元素,l指向满足比较条件的最右的元素,r指向不满足比较条件的最左元素。还有一点,由于m永远不会等于l或r,所以这样初始化并不会产生越界访问,而且这个初始化保证了m有可能取到[l+1, r-1]
中任何一个。所以,这个二分写法,我们只需要管比较条件的写法来决定找左边界或右边界,以及返回l或r决定具体边界元素即可,几乎就是个无坑版本写法,非常建议在比赛里采用这个写法。
简单的参考模板
// 等同于std::lower_bound
template <typename ITER, typename V>
ITER bin_search_lower(ITER begin, ITER end, const V &val) {
ITER l = begin - 1, r = end;
while (l + 1 < r) {
ITER m = l + (r - l) / 2;
if (*m < val) {
l = m;
} else {
r = m;
}
}
return l;
}
// 等同于std::upper_bound
template <typename ITER, typename V>
ITER bin_search_upper(ITER begin, ITER end, const V &val) {
ITER l = begin - 1, r = end;
while (l + 1 < r) {
ITER m = l + (r - l) / 2;
if (*m <= val) {
l = m;
} else {
r = m;
}
}
return r;
}
// cmp为true表示这个数在val的左边,否则在右边
template <typename ITER, typename V, class COMP>
ITER bin_search(ITER begin, ITER end, const V &val, COMP cmp) {
ITER l = begin - 1, r = end;
while (l + 1 < r) {
ITER m = l + (r - l) / 2;
if (cmp(*m, val)) {
l = m;
} else {
r = m;
}
}
return l;
}
扩展:更通用的写法
前面的写法只能使用随机迭代器,如果现在只有前向迭代器,那这写法就不行了。以下写法来自cppreference
template<class ForwardIt, class T, class Compare>
ForwardIt lower_bound(ForwardIt first, ForwardIt last, const T& value, Compare comp)
{
ForwardIt it;
typename std::iterator_traits<ForwardIt>::difference_type count, step;
count = std::distance(first, last);
while (count > 0) {
it = first;
step = count / 2;
std::advance(it, step);
if (comp(*it, value)) {
first = ++it;
count -= step + 1;
}
else
count = step;
}
return first;
}
当然,更通用意味着自己写时更容易出问题,尽量直接用STL不自己手写那是最好的。