大整数高精度计算1——基础算法

最近在编写大整数库的过程中,踩到不少的坑,于是把一些有用的细节准备写成文章做整理。如果你只是想直接查找并使用一个大整数库,那直接上GMP即可,如果是用在比赛,那直接用我的MiniBigInteger项目,如果你是想学习个中细节,那你可以坐下来细品。

所谓大整数,又叫高精度运算,就是运算对象是上千位甚至到百万位,总之远远超过内置数据类型的表示范围,这类数字都叫大整数。而C/C++的标准库里目前并没有大整数库,于是这个轮子被反复制造了无数个,不过在github上比较有质量的轮子并没有很多。本文除了介绍基础实现,主要还是介绍优化方法。

大整数的表示

大整数的表示方法最常见的有4种:

  1. 直接使用string
  2. 使用定长数组(仅适用于竞赛)
  3. 使用链表
  4. 使用变长线性表例如vector

直接用string的方式适合初学者,输入输出直观,但缺点也非常明显,因为计算时需要在字符与数值之间来回转换,浪费太多不必要的时间,效率会非常差。不过如果你是初学者,先用string表示法来写未尝不是个好主意。但有个细节就是,如果想要效率高,最好把string前后倒置调整为低位在前再做运算,这样速度和实现难度都会低一些。

至于使用链表,好处是变长容易,变短也不难,但性能比用string的更差还更难写,这里就不谈了,以下介绍使用数组的表示法

为了在数组里表示一个大整数,如果我们采用10进制,表示123456789,那很简单,例如这样:

int a[] = {9, 8, 7, 6, 5, 4, 3, 2, 1};

a[0]表示个位,a[1]表示十位,如此类推。之所以这样做,是希望同一个位置的元素的含义是固定的,这样能简化后面的算法编写。

但是,对于计算机来说,这样表示实在是太浪费空间和时间,我们还可以这样:

int a[] = {6789, 2345, 1};

也就是说,采用10000进制,那么这个10000叫做基数base。我们还可以视情况使用其它的基数,例如使用$2^{16}$或$2^{32}$等等。对于这种我们原始进制是b的情况,但通过更大的基数$b^n$来表示的方法,叫做压位高精度,n就是压的位数。压位高精度的运算效率远超非压位。

以下我们用万进制来做一步步的演示,首先定义出这个类,在这里我用vector来保存(没有必要采用定长数组,因为效率没啥区别,除非你的OJ就不给加O2优化例如某谷或参加某些信息竞赛):

struct BigIntSimple {
    static const int BIGINT_BASE = 10000;
    static const int BIGINT_DIGITS = 4;

    int sign; //1表示正数,-1表示负数
    std::vector<int> v;
};

于是这里有一个问题,0怎么表示?在这里规定0的sign为1,且v的长度为1,这种约定可以在后期简化一些代码。当然你也可以规定v的长度为0,你怎么约定就怎么写代码即可。

大整数的基本运算

实现一个功能还算得上完整的大整数,需要的基本运算有:

  1. 比较运算
  2. 字符串输入
  3. 字符串输出
  4. 加法
  5. 减法
  6. 乘法
  7. 除法及求余

1. 比较运算

大整数的比较很简单,先判断符号,符号不同再看位数,位数不同再从最高位一位一位比就行。但是我们其实更需要一个按绝对值比较的函数(后面的运算会需要它),那么先写一个无视符号的比较版本,再在运算符重载处判断符号即可。

2. 字符串输入输出

对于输入输出的字符串进制与该大整数的基相同时,输入输出直接映射转换即可,对于不同进制的输入输出在后面再做介绍,以下是实现了输入输出及一些基础功能的版本

struct BigIntSimple {
    static const int BIGINT_BASE = 10000;
    static const int BIGINT_DIGITS = 4;

    int sign; //1表示正数,-1表示负数
    std::vector<int> v;

    //定义0也需要长度1
    BigIntSimple() {
        sign = 1;
        v.push_back(0);
    }
    BigIntSimple(int n) {
        *this = n;
    }
    //判断是否为0
    bool iszero() const {
        return v.size() == 1 && v.back() == 0;
    }
    //消除前导0并修正符号
    void trim() {
        while (v.back() == 0 && v.size() > 1)
            v.pop_back();
        if (iszero())
            sign = 1;
    }
    //获取pos位置上的数值,用于防越界,简化输入处理
    int get(unsigned pos) const {
        if (pos >= v.size())
            return 0;
        return v[pos];
    }
    //绝对值大小比较
    bool absless(const BigIntSimple &b) const {
        if (v.size() == b.v.size()) {
            for (size_t i = v.size() - 1; i < v.size(); --i)
                if (v[i] != b.v[i])
                    return v[i] < b.v[i];
            return false;
        } else {
            return v.size() < b.v.size();
        }
    }
    //字符串输入
    void set(const char *s) {
        v.clear();
        sign = 1;
        //处理负号
        while (*s == '-')
            sign = -sign, ++s;
        //先按数位直接存入数组里
        for (size_t i = 0; s[i]; ++i)
            v.push_back(s[i] - '0');
        std::reverse(v.begin(), v.end());
        //压位处理,e是压位后的长度
        size_t e = (v.size() + BIGINT_DIGITS - 1) / BIGINT_DIGITS;
        for (size_t i = 0, j = 0; i < e; ++i, j += BIGINT_DIGITS) {
            v[i] = v[j]; //设置压位的最低位
            //高位的按每一位上的数值乘以m,m是该位的权值
            for (size_t k = 1, m = 10; k < BIGINT_DIGITS; ++k, m *= 10)
                v[i] += get(j + k) * m;
        }
        //修正压位后的长度
        if (e) {
            v.resize(e);
            trim();
        } else {
            v.resize(1);
        }
    }
    //字符串输出
    std::string to_str() const {
        std::string s;
        for (size_t i = 0; i < v.size(); ++i) {
            int d = v[i];
            //拆开压位
            for (size_t k = 0; k < BIGINT_DIGITS; ++k) {
                s += d % 10 + '0';
                d /= 10;
            }
        }
        //去除前导0
        while (s.size() > 1 && s.back() == '0')
            s.pop_back();
        //补符号
        if (sign < 0)
            s += '-';
        //不要忘记要逆序
        std::reverse(s.begin(), s.end());
        return s;
    }

    BigIntSimple &operator=(int n) {
        v.clear();
        sign = n >= 0 ? 1 : -1;
        for (n = abs(n); n; n /= BIGINT_BASE)
            v.push_back(n % BIGINT_BASE);
        if (v.empty())
            v.push_back(0);
        return *this;
    }

    BigIntSimple &operator=(const std::string &s) {
        set(s.c_str());
        return *this;
    }

};

下文的介绍为了不重复,就不带上以上的代码了

3. 加法和减法

加法和减法都挺简单,核心思想就是模拟手工竖式,手工怎么算它就怎么算。要注意的点就是符号的处理。

另外,还有一些小优化,加法进位的时候,这个if是可以简单省略掉的,用求模和除法运算即可。即当前位是sum % base,进位是sum / base。示例代码如下

        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += r.v[i] + b.v[i];
            r.v[i] = carry % BIGINT_BASE;
            carry /= BIGINT_BASE;
        }

但减法就没这么简单了,这时候还可以利用位运算来区分正负,比如,我们用当前位是sum - (sum >> 31) * base,因为如果sum是负数,那么sum >> 31在sum是int时就等于-1,相当于sum + base;而如果sum是非负数,那么sum >> 31就是0,结果就相当于sum,这样就成功实现了sum<0 ? sum+base : sum的逻辑。而这种方法同样可以用在加法上。另外,在较新的CPU上除法的性能比较高的时候,加法这种改进写法就会被淘汰。

其它的细节在以下代码的注释中有标注

    BigIntSimple operator-() const {
        BigIntSimple r = *this;
        r.sign = -r.sign;
        return r;
    }

    BigIntSimple operator+(const BigIntSimple &b) const {
        //符号不同时转换为减法
        if (sign != b.sign)
            return *this - -b;
        BigIntSimple r = *this;
        //填充高位
        if (r.v.size() < b.v.size())
            r.v.resize(b.v.size());
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += r.v[i] + b.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size(); carry && i < r.v.size(); ++i) {
            carry += r.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry)
            r.v.push_back(carry);
        return r;
    }

    BigIntSimple &subtract(const BigIntSimple &b) {
        int borrow = 0;
        //先处理b的长度
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i] - b.v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //减法可能会出现前导0需要消去
        trim();
        return *this;
    }

    BigIntSimple operator-(const BigIntSimple &b) const {
        //符号不同时转换为加法
        if (sign != b.sign)
            return (*this) + -b;
        if (absless(b)) { //保证大数减小数
            BigIntSimple r = b;
            return -r.subtract(*this);
        } else {
            BigIntSimple r = *this;
            return r.subtract(b);
        }
    }

4. 乘法

本文不区分高精度乘以低精度,和高精度乘以高精度,下文的除法也一样,因为实在没有这个必要分开写,以下直接介绍的是高精度乘以高精度。

乘法可以看成在加法外面再套一层循环,内循环相比加法多了一个偏移和一个乘法。但是由于进位的值不会只是1,所以那个位运算方法在这里不能使用,只能用求模了。以下直接上代码

    BigIntSimple operator*(const BigIntSimple &b) const {
        //r记录相加结果
        BigIntSimple r;
        r.v.resize(v.size() + b.v.size()); //初始化长度
        for (size_t j = 0; j < v.size(); ++j) {
            int carry = 0, m = v[j]; //m用来缓存乘数
            //carry可能很大,只能使用求模的办法,此循环与加法部分几乎相同,就多乘了个m
            for (size_t i = 0; i < b.v.size(); ++i) {
                carry += r.v[i + j] + b.v[i] * m;
                r.v[i + j] = carry % BIGINT_BASE;
                carry /= BIGINT_BASE;
            }
            r.v[j + b.v.size()] += carry;
        }
        r.trim();
        r.sign = sign * b.sign;
        return r;
    }

5. 除法和求余

除法是高精度的基础算法里面变化最多的,也是基础算法里面最难的,网上也有很多不同的写法,顺带说一说一些误区,同时这里提供一个我自己的写法,不过此法有限制条件,但在限制条件内应该是模拟手工的方法里面速度较快的。

第一种就是暴力整体二分,然后做乘法验证(或者利用二进制一位一位来确定,但却不使用移位减法,而使用整体相乘),这时候,设是2n位除以n位,那么二分的次数就是n,然后一次乘法是n^2,所以整体复杂度是$O(n^3)$,这是一个非常糟糕的方法,虽然写起来似乎更简单,但时间上还不如直接模拟,千万不要做整体二分。类似地还有二分开方,也是$O(n^3)$。

在模拟手工除法时,最关键的就是试商的部分,试商方法有很多,最简单的方式是用减法,先判断余数是不是大于等于除数,如果是,就做一次减法。这个写法确实容易,但问题是效率低下,除非你用的基数特别小,比如不压位10进制,这样速度还能看,甚至乎你用的是$2^n$进制,直接每个位枚举。

那我们来个二分呢?如果在每一位上分别二分,那在基数较大的时候比做减法好一些,但还是不够好,我们还需要更快的方案,进一步减少试商次数。

假设被除数有4个位,是$a_4,a_3,a_2,a_1$,除数有3个位,是$b_3,b_2,b_1$,那么我们只要试一位的商(多个位就是一位的写法加个循环),假如我们用$\dfrac{a_4 base+a_3}{b_3}$来试商,结果一定大于等于实际的商,但同时这会有一个问题,就是假如b2等于base-1,那会导致试商与实际的商误差非常大,例如9999,0000除以1,9999,直接用高位除得到9999,但这远超过实际商了,修正商的代价也不小。于是,我们想到,在base不太大的情况下,我们可以通过增加位数来估商,这样误差就会小得多。比如说我们用$\dfrac{a_4 base^2 + a_3 base + a_2}{b_3 base + b_2}$来试商,精度确实会大为提高,而且商的误差最多只有1,但缺点是,$a_4 base^2$的结果超出int的范围了,不过我们还可以用double。注意到,上式中a2对结果并没有任何影响,所以可以变形为$\dfrac{a_4 base + a_3}{b_3 + b_2/base}$。另外,我不希望这个估商总比实际商大,我们希望是小于等于实际商,这样在试后一位的时候,这个结果能自然得到修正,就增加了试商的效率,于是可以把式子改为$\dfrac{a_4 base + a_3}{b_3 + (b_2+1)base^{-1}}$,但这个+1导致误差增大,与实际商的误差最大达到2,那解决方法很简单,我们再增加1位的精度,得到式子$\dfrac{a_4 base + a_3}{b_3 + b_2 base^{-1} + (b_1+1)base^{-2}}$,于是便得到接下来在代码中所使用的算法。由于误差不超过1,如果估小了,那在下一位的估商时候就会产生补回去的效果,于是不必重复试商。这个方法要求base<=32768以避免各种溢出,在满足此条件下,因为每个位均只需要估一次,那么其时间常数与乘法相比,和减法与加法常数比是几乎相同的。

但是,有一个非常极端的情况,既然商最大误差是1,那么相当于余数的最大误差就等于除数,那如果除数特别大呢?举个例子,求9999,9999,9999,9999/9999,9999,9999,用上面的方法,高一位的试商结果是0,次一位的试商是9999(实际商是10000),相减得到余数1,0000,0000,9998,结果最高位并没有减到0,于是不能继续移位,需要在同一位再一次试商。但虽然在同一位试商两次,但后一位就不用做减法了,均摊还是n次。

算法有了,该处理细节了,除法需要一个减法函数,不过这个函数对借位的处理和之前的减法可不一样,因为可能一次借n个,那就产生了一个问题:对负数求模。负数求模的结果和正数很不一样,所以如果还是要避免if做判断处理,那就要再换个方法,我承认我很菜,花了很久时间才想到这个法子,实现如下

    //对b乘以mul再左移offset的结果相减,为除法服务
    BigIntSimple &sub_mul(const BigIntSimple &b, int mul, int offset) {
        if (mul == 0)
            return *this;
        int borrow = 0;
        //与减法不同的是,borrow可能很大,不能使用减法的写法
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i + offset] - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        return *this;
    }

以上代码关键点就是这两行

borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;

核心思想是通过减去base统一在负数段求模后,再加上base回来,这样也达到避免if,避免if的写法比带if的写法时间上可以节省一半,不清楚有没有性能更好的做法。以下为除法及求余的具体实现。

    BigIntSimple div_mod(const BigIntSimple &b, BigIntSimple &r) const {
        BigIntSimple d;
        r = *this;
        if (absless(b)) return d;
        d.v.resize(v.size() - b.v.size() + 1);
        //提前算好除数的最高三位+1的倒数,若最高三位是a3,a2,a1
        //那么db是a3+a2/base+(a1+1)/base^2的倒数,最后用乘法估商的每一位
        //此法在BIGINT_BASE<=32768时可在int32范围内用
        //但即使使用int64,那么也只有BIGINT_BASE<=131072时可用(受double的精度限制)
        //能保证估计结果q'与实际结果q的关系满足q'<=q<=q'+1
        //所以每一位的试商平均只需要一次,只要后面再统一处理进位即可
        //如果要使用更大的base,那么需要更换其它试商方案
        double t = (b.get((unsigned)b.v.size() - 2) +
                    (b.get((unsigned)b.v.size() - 3) + 1.0) / BIGINT_BASE);
        double db = 1.0 / (b.v.back() + t / BIGINT_BASE);
        for (size_t i = v.size() - 1, j = d.v.size() - 1; j <= v.size();) {
            int rm = r.get(i + 1) * BIGINT_BASE + r.get(i);
            int m = std::max((int)(db * rm), r.get(i + 1));
            r.sub_mul(b, m, j);
            d.v[j] += m;
            if (!r.get(i + 1)) //检查最高位是否已为0,避免极端情况
                --i, --j;
        }
        r.trim();
        //修正结果的个位
        int carry = 0;
        while (!r.absless(b)) {
            r.subtract(b);
            ++carry;
        }
        //修正每一位的进位
        for (size_t i = 0; i < d.v.size(); ++i) {
            carry += d.v[i];
            d.v[i] = carry % BIGINT_BASE;
            carry /= BIGINT_BASE;
        }
        d.trim();
        d.sign = sign * b.sign;
        return d;
    }

    BigIntSimple operator/(const BigIntSimple &b) const {
        BigIntSimple r;
        return div_mod(b, r);
    }

    BigIntSimple operator%(const BigIntSimple &b) const {
        BigIntSimple r;
        div_mod(b, r);
        return r;
        //return *this - *this / b * b;
    }

求余就不说了,有求商之后代码仅一行return *this - *this / b * b,而且这样写与C语言规则一致,模与被除数符号相同。当然以上代码实现的div_mod函数本身就把商和余数同时求出,就可以直接调用而少了乘法和减法。

但是,现在系统几乎都是64位的,于是在64位下我们就应该考虑更大的base,例如说使用亿进制,最大限度利用64位带来的性能提升。但是,如果我们使用大的base,那么考虑到double的精度,前面的试商法就失效了,甚至于连base的平方都超出double能精确表示的范围,那这时候应该怎么办呢?为了方便描述,这里我们假设用的是$2^{n}$进制,且n>=18,而double的尾数是53位,于是$log_2base^3=3n=54>53$,这就是令原问题失效的边界。但是,我们真的需要3n长度的尾数吗?确实并不需要,我们的base的精度是n位,只需要再增加k位,只要n+k到53还有一些距离就行,把这个数看成是浮点数,同样地,把被除数和除数都看成浮点数,那我们的除数如何进行+1操作呢?比如说除数表示为$b_3b_2b_1$,那我们让$b_3$右移n+k位加上去,不就相当于+1了。这个方法需要考虑到下一个数估商超过base大小的情况,所以要注意base不能超过$2^{30}$,如果你要使用$2^{32}$,那又得换一个办法了。在实际操作中,我令k=n/2,这样操作起来更为方便,具体请参阅我的MiniBigInteger项目。

完整模板

点击展开

struct BigIntSimple {
    static const int BIGINT_BASE = 10000;
    static const int BIGINT_DIGITS = 4;

    int sign; // 1表示正数,-1表示负数
    std::vector<int> v;

    //定义0也需要长度1
    BigIntSimple() {
        sign = 1;
        v.push_back(0);
    }
    BigIntSimple(int n) { *this = n; }
    //判断是否为0
    bool iszero() const { return v.size() == 1 && v.back() == 0; }
    //消除前导0并修正符号
    void trim() {
        while (v.back() == 0 && v.size() > 1)
            v.pop_back();
        if (iszero()) sign = 1;
    }
    //获取pos位置上的数值,用于防越界,简化输入处理
    int get(unsigned pos) const {
        if (pos >= v.size()) return 0;
        return v[pos];
    }
    //绝对值大小比较
    bool absless(const BigIntSimple &b) const {
        if (v.size() == b.v.size()) {
            for (size_t i = v.size() - 1; i < v.size(); --i)
                if (v[i] != b.v[i]) return v[i] < b.v[i];
            return false;
        } else {
            return v.size() < b.v.size();
        }
    }
    //字符串输入
    void set(const char *s) {
        v.clear();
        sign = 1;
        //处理负号
        while (*s == '-')
            sign = -sign, ++s;
        //先按数位直接存入数组里
        for (size_t i = 0; s[i]; ++i)
            v.push_back(s[i] - '0');
        std::reverse(v.begin(), v.end());
        //压位处理,e是压位后的长度
        size_t e = (v.size() + BIGINT_DIGITS - 1) / BIGINT_DIGITS;
        for (size_t i = 0, j = 0; i < e; ++i, j += BIGINT_DIGITS) {
            v[i] = v[j]; //设置压位的最低位
            //高位的按每一位上的数值乘以m,m是该位的权值
            for (size_t k = 1, m = 10; k < BIGINT_DIGITS; ++k, m *= 10)
                v[i] += get(j + k) * m;
        }
        //修正压位后的长度
        if (e) {
            v.resize(e);
            trim();
        } else {
            v.resize(1);
        }
    }
    //字符串输出
    std::string to_str() const {
        std::string s;
        for (size_t i = 0; i < v.size(); ++i) {
            int d = v[i];
            //拆开压位
            for (size_t k = 0; k < BIGINT_DIGITS; ++k) {
                s += d % 10 + '0';
                d /= 10;
            }
        }
        //去除前导0
        while (s.size() > 1 && s.back() == '0')
            s.pop_back();
        //补符号
        if (sign < 0) s += '-';
        //不要忘记要逆序
        std::reverse(s.begin(), s.end());
        return s;
    }

    bool operator<(const BigIntSimple &b) const {
        if (sign == b.sign) {
            return sign > 0 ? absless(b) : b.absless(*this);
        } else {
            return sign < 0;
        }
    }

    BigIntSimple &operator=(int n) {
        v.clear();
        sign = n >= 0 ? 1 : -1;
        for (n = abs(n); n; n /= BIGINT_BASE)
            v.push_back(n % BIGINT_BASE);
        if (v.empty()) v.push_back(0);
        return *this;
    }

    BigIntSimple &operator=(const std::string &s) {
        set(s.c_str());
        return *this;
    }

    BigIntSimple operator-() const {
        BigIntSimple r = *this;
        r.sign = -r.sign;
        return r;
    }

    BigIntSimple operator+(const BigIntSimple &b) const {
        //符号不同时转换为减法
        if (sign != b.sign) return *this - -b;
        BigIntSimple r = *this;
        //填充高位
        if (r.v.size() < b.v.size()) r.v.resize(b.v.size());
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += r.v[i] + b.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size(); carry && i < r.v.size(); ++i) {
            carry += r.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) r.v.push_back(carry);
        return r;
    }

    BigIntSimple &subtract(const BigIntSimple &b) {
        int borrow = 0;
        //先处理b的长度
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i] - b.v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //减法可能会出现前导0需要消去
        trim();
        return *this;
    }

    BigIntSimple operator-(const BigIntSimple &b) const {
        //符号不同时转换为加法
        if (sign != b.sign) return (*this) + -b;
        if (absless(b)) { //保证大数减小数
            BigIntSimple r = b;
            return -r.subtract(*this);
        } else {
            BigIntSimple r = *this;
            return r.subtract(b);
        }
    }

    BigIntSimple operator*(const BigIntSimple &b) const {
        // r记录相加结果
        BigIntSimple r;
        r.v.resize(v.size() + b.v.size()); //初始化长度
        for (size_t j = 0; j < v.size(); ++j) {
            int carry = 0, m = v[j]; // m用来缓存乘数
            // carry可能很大,只能使用求模的办法,此循环与加法部分几乎相同,就多乘了个m
            for (size_t i = 0; i < b.v.size(); ++i) {
                carry += r.v[i + j] + b.v[i] * m;
                r.v[i + j] = carry % BIGINT_BASE;
                carry /= BIGINT_BASE;
            }
            r.v[j + b.v.size()] += carry;
        }
        r.trim();
        r.sign = sign * b.sign;
        return r;
    }

    //对b乘以mul再左移offset的结果相减,为除法服务
    BigIntSimple &sub_mul(const BigIntSimple &b, int mul, int offset) {
        if (mul == 0) return *this;
        int borrow = 0;
        //与减法不同的是,borrow可能很大,不能使用减法的写法
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i + offset] - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        return *this;
    }

    BigIntSimple div_mod(const BigIntSimple &b, BigIntSimple &r) const {
        BigIntSimple d;
        r = *this;
        if (absless(b)) return d;
        d.v.resize(v.size() - b.v.size() + 1);
        //提前算好除数的最高三位+1的倒数,若最高三位是a3,a2,a1
        //那么db是a3+a2/base+(a1+1)/base^2的倒数,最后用乘法估商的每一位
        //此法在BIGINT_BASE<=32768时可在int32范围内用
        //但即使使用int64,那么也只有BIGINT_BASE<=131072时可用(受double的精度限制)
        //能保证估计结果q'与实际结果q的关系满足q'<=q<=q'+1
        //所以每一位的试商平均只需要一次,只要后面再统一处理进位即可
        //如果要使用更大的base,那么需要更换其它试商方案
        double t = (b.get((unsigned)b.v.size() - 2) +
                   (b.get((unsigned)b.v.size() - 3) + 1.0) / BIGINT_BASE);
        double db = 1.0 / (b.v.back() + t / BIGINT_BASE);
        for (size_t i = v.size() - 1, j = d.v.size() - 1; j <= v.size();) {
            int rm = r.get(i + 1) * BIGINT_BASE + r.get(i);
            int m = std::max((int)(db * rm), r.get(i + 1));
            r.sub_mul(b, m, j);
            d.v[j] += m;
            if (!r.get(i + 1)) //检查最高位是否已为0,避免极端情况
                --i, --j;
        }
        r.trim();
        //修正结果的个位
        int carry = 0;
        while (!r.absless(b)) {
            r.subtract(b);
            ++carry;
        }
        //修正每一位的进位
        for (size_t i = 0; i < d.v.size(); ++i) {
            carry += d.v[i];
            d.v[i] = carry % BIGINT_BASE;
            carry /= BIGINT_BASE;
        }
        d.trim();
        d.sign = sign * b.sign;
        return d;
    }

    BigIntSimple operator/(const BigIntSimple &b) const {
        BigIntSimple r;
        return div_mod(b, r);
    }

    BigIntSimple operator%(const BigIntSimple &b) const {
        return *this - *this / b * b;
    }
};

经测试,这个实现在1秒内可以计算出30000阶乘,而计算10000阶乘不到0.1秒,附上求阶乘的实现如下:

BigIntSimple fac(int start, int n) {
    if (n < 16) {
        BigIntSimple s = 1;
        for (int i = start; i < start + n; ++i)
            s = BigIntSimple(i) * s;
        return s;
    }
    int m = (n + 1) / 2;
    return fac(start, m) * fac(start + m, n - m);
}

int main() {
    int n;
    while (cin >> n) {
        cout << fac(1, n).to_str() << endl;
    }
    return 0;
}

在hdu oj 1042上,以171ms通过,在提交记录里面,超过了大部分其它提交。当然这还不是极限。

对于除法,2n/n所需要时间与n*n所需时间比大约是1.5(n是位数),不同编译器这个比值稍有不同。

本基础篇就介绍到这里,后文将介绍乘法和除法的优化。

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus