二分搜索杂谈

二分搜索虽然基本思想简单,但其细节却令人意外的抓狂(Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky)。这里我们来分析一下常见写法的坑。

两类不同的二分

二分有两类,一是找值,即判断某个值存不存在,二是找边界,前者比后者简单很多,以下是找值的典型写法(来自中文wiki百科)

int binary_search(const int arr[], int start, int end, int key) {
	int ret = -1;       // 未搜索到数据返回-1下标

	int mid;
	while (start <= end) {
		mid = start + (end - start) / 2; //直接平均可能會溢位,所以用此算法
		if (arr[mid] < key)
			start = mid + 1;
		else if (arr[mid] > key)
			end = mid - 1;
		else {            // 最後檢測相等是因為多數搜尋狀況不是大於要不就小於
			ret = mid;  
			break;
		}
	}

	return ret;     // 单一出口
}

对于找边界,有四种情况,如以下例子,查找数值5的四种边界

1 1 1 2 2 5 5 5 5 7 9
        ^ 1 小于的最右元素
          ^ 2 大于等于的最左元素
                ^ 3 小于等于的最右元素
                  ^ 4 大于的最左元素

既然有四种边界,那就有四种写法。

常见基本写法

为了好理解,这里直接用int写,且使用闭区间写法,而不使用模板

int bin_search_1(int arr[], int len, int val) {
  int l = 0, r = len - 1;
  while (l < r) {
    int m = l + ((r - l + 1) / 2);
    if (arr[m] < val) {
      l = m;
    } else {
      r = m - 1;
    }
  }
  return l;
}

int bin_search_2(int arr[], int len, int val) {
  int l = 0, r = len - 1;
  while (l < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] < val) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l;
}

int bin_search_3(int arr[], int len, int val) {
  int l = 0, r = len - 1;
  while (l < r) {
    int m = l + ((r - l + 1) / 2);
    if (arr[m] <= val) {
      l = m;
    } else {
      r = m - 1;
    }
  }
  return l;
}

int bin_search_4(int arr[], int len, int val) {
  int l = 0, r = len - 1;
  while (l < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] <= val) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l;
}

int main() {
  int a[] = {1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10};
  int len = 16, val = 5;
  cout << bin_search_1(a, len, val) << endl;
  cout << bin_search_2(a, len, val) << endl;
  cout << bin_search_3(a, len, val) << endl;
  cout << bin_search_4(a, len, val) << endl;
  return 0;
}

以上写法的坑点

坑点主要有以下4个,当这4个你搞明白了以后,你就可以肉眼debug出一个二分写法的问题了。

1. 区间表示

如果采用闭区间,那么循环的条件就是l < r,当这个条件满足时,这个区间就表示了两个或以上的元素

同理,如果采用半开半闭区间,那么循环的条件就是l + 1 < r,如果采用开区间,那么循环条件就是l + 2 < r

2. 比较

比较的方式取决于你的边界是左边界还是右边界,如果是左边界,那么边界左边的数均为小于,而右边的数是大于等于,那么你应该用 arr[m] < val作为判断条件;同理地,如果是右边界,那么边界左边的数均为小于等于,而右边的数是大于,这时应该用arr[m] <= val作为判断条件

3. 更新区间

在比较后,更新区间时,有的+1有的-1有的不需要加减,这是怎么决定要不要加呢?这个由你所查找的区间是否包含这个数决定。举个具体例子,在4 5 5 6中,给出要查找的数值5,要查找到小于5的最大那个数,即4,那么比较方式是arr[m] < val。如果这个小于号满足,比如说arr[m]是4,那么这个数是可能在查找区间内,所以不应该+1或-1,但如果arr[m]是5,这个数不应该在查找区间内,那么就果断要+1或-1。在边界查找写法里,必然有一边是需要+1或-1,而另一边不需要,这个会影响下面要介绍的中间数选择。

4. 中间数的选择

中间数选择,即以上代码中的变量m,有的时候(r - l) / 2,有的时候(r - l + 1) / 2,这个写法取决于当r - l == 1时,那么m要么等于l,要么等于r,为了这个循环能结束,如果下方代码是l = m,那么m要取r,即要+1,如果下方代码是r = m,那么m要取l,不需要+1。这个如果写错就会导致死循环的发生。

另外,还有一个更多人犯的错误,很多人会写m = (l + r) / 2,当然,这个其实在不少情况下确实也不会怎么样,但是,如果我们要做的事情并不是在数组中查找,而是在一个区间里面,比如找一个方程的整数解,或者求3次方根的整数部分,这就会产生问题了,当 l和r是负数的时候,与它们是正数的时候,除以2的含义是不一样的,C语言中的除法实际上是向0取整,比如说,-7/2 == -3,但是,在二分搜索时,我们必须要么向上取整,要么向下取整,但存在负数时,这样做除法可能会导致二分死循环。采用(r - l) / 2可以避免这个问题,或者,你还可以使用位运算技巧(r - l) >> 1。当然,为了减少中坑起见,写m = l + (r - l) / 2最保险。

避免进坑的写法

以上写法坑太多(其实这些坑的有一部分的本质是你采用了闭区间),怎么样避免掉以上种种问题搞一个坑最少最不容易出错的写法呢?来看这个写法

int bin_search_r1(int arr[], int len, int val) {
  int l = -1, r = len;
  while (l + 1 < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] < val) {
      l = m;
    } else {
      r = m;
    }
  }
  return l;
}

int bin_search_r2(int arr[], int len, int val) {
  int l = -1, r = len;
  while (l + 1 < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] < val) {
      l = m;
    } else {
      r = m;
    }
  }
  return r;
}

int bin_search_r3(int arr[], int len, int val) {
  int l = -1, r = len;
  while (l + 1 < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] <= val) {
      l = m;
    } else {
      r = m;
    }
  }
  return l;
}

int bin_search_r4(int arr[], int len, int val) {
  int l = -1, r = len;
  while (l + 1 < r) {
    int m = l + ((r - l) / 2);
    if (arr[m] <= val) {
      l = m;
    } else {
      r = m;
    }
  }
  return r;
}

以上写法看初值似乎是在采用开区间,其实不然,实际上是半开半闭区间表示(看循环条件),表示要查找的值在[l, r)内或(l, r]内,返回值如果是l那区间就是[l, r),反之就是(l, r]。在循环结束前,区间[l, r]内至少有3个元素,这样m肯定与l或r不相等,不存在死循环的可能性,也不需要关心向上取整还是向下取整的问题,这也是半开半闭区间表示的优点。另外,不论区间是3个还是4个元素,最终都会缩短为刚好2个元素,l指向满足比较条件的最右的元素,r指向不满足比较条件的最左元素。还有一点,由于m永远不会等于l或r,所以这样初始化并不会产生越界访问,而且这个初始化保证了m有可能取到[l+1, r-1]中任何一个。所以,这个二分写法,我们只需要管比较条件的写法来决定找左边界或右边界,以及返回l或r决定具体边界元素即可,几乎就是个无坑版本写法,非常建议在比赛里采用这个写法。

简单的参考模板

// 等同于std::lower_bound
template <typename ITER, typename V>
ITER bin_search_lower(ITER begin, ITER end, const V &val) {
  ITER l = begin - 1, r = end;
  while (l + 1 < r) {
    ITER m = l + (r - l) / 2;
    if (*m < val) {
      l = m;
    } else {
      r = m;
    }
  }
  return l;
}

// 等同于std::upper_bound
template <typename ITER, typename V>
ITER bin_search_upper(ITER begin, ITER end, const V &val) {
  ITER l = begin - 1, r = end;
  while (l + 1 < r) {
    ITER m = l + (r - l) / 2;
    if (*m <= val) {
      l = m;
    } else {
      r = m;
    }
  }
  return r;
}

// cmp为true表示这个数在val的左边,否则在右边
template <typename ITER, typename V, class COMP>
ITER bin_search(ITER begin, ITER end, const V &val, COMP cmp) {
  ITER l = begin - 1, r = end;
  while (l + 1 < r) {
    ITER m = l + (r - l) / 2;
    if (cmp(*m, val)) {
      l = m;
    } else {
      r = m;
    }
  }
  return l;
}

扩展:更通用的写法

前面的写法只能使用随机迭代器,如果现在只有前向迭代器,那这写法就不行了。以下写法来自cppreference

template<class ForwardIt, class T, class Compare>
ForwardIt lower_bound(ForwardIt first, ForwardIt last, const T& value, Compare comp)
{
  ForwardIt it;
  typename std::iterator_traits<ForwardIt>::difference_type count, step;
  count = std::distance(first, last);
 
  while (count > 0) {
    it = first;
    step = count / 2;
    std::advance(it, step);
    if (comp(*it, value)) {
      first = ++it;
      count -= step + 1;
    }
    else
      count = step;
  }
  return first;
}

当然,更通用意味着自己写时更容易出问题,尽量直接用STL不自己手写那是最好的。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus