大整数高精度计算2——乘法优化及进制转换

在上一篇介绍了基础算法后,本篇介绍算法级别的优化。本篇的三大内容:乘法优化,任意进制输入,任意进制输出。

乘法优化

这里要介绍的,是Karatsuba发现的分治法,我们假设要相乘的两个数,都有2n位,那么这两个数就可以分别表示为$a_1base^n+a_2, b_1base^n+b_2$,其中,$a_1,a_2,b_1,b_2$是n位的大整数,那么,它们的积就是

$$\begin{align} & (a_1base^n+a_2) \times (b_1base^n+b_2) \\
& = a_1b_1base^{2n} + (a_1b_2+a_2b_1)base^n + a_2b_2 \\
& = a_1b_1base^{2n} + ((a_1+a_2)(b_1+b_2)-a_1b_1-a_2b_2)base^n + a_2b_2 \end{align}$$

如果这样不够明显的话,我们用$c_1$代替$a_1b_1$,用$c_3$代替$a_2b_2$,得到

$c_1base^{2n} + ((a_1+a_2)(b_1+b_2)-c_1-c_3)base^n + c_3$

于是,这里一共有3次乘法,比起原来的4次暴力乘法减少了1次。而里面的乘法又可以进行递归优化,时间复杂度从$O(n^2)$下降到$O(n^{log_23})$约$O(n^{1.585})$

当然,在实际应用时,这两个数不可能都正好一样的位,不过也不要紧,这个算法对分割点位置也没有要求,但当然分割位置尽可能在中间,效率越高,于是实际实现就会有一堆细节,这里就不做介绍了。只要在原有的模板里面增加移位加法操作,这个优化算法就能套上去用,用起来非常简单。以下就是在上文中对BigIntSimple修改的例子

    BigIntSimple &offset_add(const BigIntSimple &b, int offset) {
        //填充高位
        if (v.size() < b.v.size() + offset) v.resize(b.v.size() + offset);
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += v[i + offset] + b.v[i] - BIGINT_BASE;
            v[i + offset] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size() + offset; carry && i < v.size(); ++i) {
            carry += v[i] - BIGINT_BASE;
            v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) v.push_back(carry);
        return *this;
    }

    BigIntSimple mul(const BigIntSimple &b) const {
        // r记录相加结果
        BigIntSimple r;
        r.v.resize(v.size() + b.v.size()); //初始化长度
        for (size_t j = 0; j < v.size(); ++j) {
            int carry = 0, m = v[j]; // m用来缓存乘数
            // carry可能很大,只能使用求模的办法,此循环与加法部分几乎相同,就多乘了个m
            for (size_t i = 0; i < b.v.size(); ++i) {
                carry += r.v[i + j] + b.v[i] * m;
                r.v[i + j] = carry % BIGINT_BASE;
                carry /= BIGINT_BASE;
            }
            r.v[j + b.v.size()] += carry;
        }
        r.trim();
        return r;
    }

    BigIntSimple &fastmul(const BigIntSimple &a, const BigIntSimple &b) {
        //小于某个阈值就直接用暴力乘法
        if (std::min(a.v.size(), b.v.size()) <= 300) {
            return *this = a.mul(b);
        }
        BigIntSimple ah, al, bh, bl, h, m;
        //计算分割点
        size_t split = std::max(
            std::min((a.v.size() + 1) / 2, b.v.size() - 1),
            std::min((b.v.size() + 1) / 2, a.v.size() - 1));
        //按分割点拆成4个数
        al.v.assign(a.v.begin(), a.v.begin() + split);
        ah.v.assign(a.v.begin() + split, a.v.end());
        bl.v.assign(b.v.begin(), b.v.begin() + split);
        bh.v.assign(b.v.begin() + split, b.v.end());
        //按公式递归计算
        fastmul(al, bl);
        h.fastmul(ah, bh);
        m.fastmul(al + ah, bl + bh);
        m.subtract(*this + h);
        v.resize(a.v.size() + b.v.size());
        offset_add(m, split);
        offset_add(h, split * 2);
        trim();
        return *this;
    }

    BigIntSimple operator*(const BigIntSimple &b) const {
        BigIntSimple r;
        r.fastmul(*this, b);
        r.sign = sign * b.sign;
        return r;
    }

这个优化一加,那个10000阶乘的题目就轻松缩短到109ms。且代码不长,加起来轻松又愉快。

类似地,以上优化算法是分成两路进行分治,如果分成三路,那么就叫做Toom-3算法,时间复杂度为$O(n^{log_35})$约$O(n^{1.465})$,如果分成四路,那么就叫做Toom-4算法,时间复杂度为$O(n^{log_47})$约$O(n^{1.404})$,Toom算法还有很多个变种,Karatsuba分治法其实就是Toom算法在n为2的情况。但很多时候,用Karatsuba已经足够了,除非你对性能有特别的追求。

任意进制读入

所谓任意进制,如果用字符串输入,通常限定为2~36进制,当然如果你直接用数组作为输入,那确实可以支持任意进制。为了方便表述,以下假设你的大整数类使用base进制,输入是b进制,$base \neq b$。

对于这个问题,最多人的想法是,按进制的定义直接加起来,假设输入是$s_ns_{n-1} \dots s_2s_1s_0$,那就求出$s_0+s_1b+s_2b^2+ \dots +s_nb^n$,再整理就得到$((\dots((s_n*b+s_{n-1}) *b+s_{n-2}) *b+\dots)+s_1) *b + s_0$,所以写一个循环,计算过程用这个大整数类即可。这样一共有n次乘法,而这里每次乘法是$O(n)$,所以整体复杂度是$O(n^2)$。

但是,不能光看时间复杂度,来回忆一下在上一篇文章里,阶乘代码是怎么写的

BigIntSimple fac(int start, int n) {
    if (n < 16) {
        BigIntSimple s = 1;
        for (int i = start; i < start + n; ++i)
            s = BigIntSimple(i) * s;
        return s;
    }
    int m = (n + 1) / 2;
    return fac(start, m) * fac(start + m, n - m);
}

这个写法丝毫没有减少BigIntSimple之间做乘法的次数,但却比暴力乘法来的快,原因是什么?你有想过原因吗?

假设有2n个小整数相乘,假设它们大小都在base附近,那么直接乘的话,那么小整数的乘法次数(即把大整数的乘法过程拆分来统计)就是

$\sum\limits_{i=1}^{2n-1}{i} = n(2n-1) = 2n^2 - n$

而我们分成两组,每组各n个先乘,再来做1次n位的乘法,那么总次数是

$2\sum\limits_{i=1}^{n-1}{i} + n^2 = n(n-1) + n^2 = 2n^2 - n$

完全没有差别!那再假设,它们的大小都在$\sqrt{base}$附近呢,那直接乘的话,是

$\sum\limits_{i=1}^{2n-1}{\lceil\dfrac{i}{2}\rceil} = \dfrac{n(n-1)+n(n+1)}{2} = n^2$

分组再乘是

$2\sum\limits_{i=1}^{n-1}{\lceil\dfrac{i}{2}\rceil + (\dfrac{n}{2})^2} = \dfrac{n^2}{2} + \dfrac{n^2}{4} = \dfrac{3n^2}{4}$

甚至直接分成n组,每组两两相乘

$n + \sum\limits_{i=1}^{n-1}{i} = n + \dfrac{n(n-1)}{2} = \dfrac{n^2 + n}{2}$

这回差别就产生了,也就是说,如果做多次小整数乘法(小整数的定义为小于base),那么通过分组便可以有效减少计算次数,而且,我们还可以进行递归分组来节省更多的时间,最终得到约$\dfrac{n^2}{2}$的计算量。而求阶乘正好满足这个条件,于是便有了这个求阶乘的优化代码,而且这个方法还可以拓展,用优先队列维护最小的两个,每次找最小的两个来相乘,用此法时间在那个阶乘题目里可以稍微减少到93ms。而且,转换成两个大整数相乘还能用上前面说的乘法优化节省更多的时间。

那阶乘和这进制转换有啥关系啊?重新看看那个进制转换的循环过程,是不是有超多的小整数乘法?所以又可以分治了。

找一个分割点n,且满足$n=2^k$,原输入分割为$a_1b^n+a_2$,其中$a_1,a_2$都是大整数,这样求出$b^n$只需要k次自乘,这样就把原输入分割为两小段,这两小段再分别做输入的进制转换,这就是一个递归。这k次自乘,相乘的两数必然是等长度的,可以非常好的利用乘法加速特性。

    //分治进制转换输入
    BigIntSimple &_from_str(const std::string &s, int base) {
        //较短长度时直接计算,36^4 < 2^31,但取5就大于了,所以长度上限是4
        if (s.size() <= 4) {
            int v = 0;
            for (size_t i = 0; i < s.size(); ++i) {
                int digit = -1;
                if (s[i] >= '0' && s[i] <= '9')
                    digit = s[i] - '0';
                else if (s[i] >= 'A' && s[i] <= 'Z')
                    digit = s[i] - 'A' + 10;
                else if (s[i] >= 'a' && s[i] <= 'z')
                    digit = s[i] - 'a' + 10;
                v = v * base + digit;
            }
            return *this = v;
        }
        BigIntSimple m(base), h;
        size_t len = 1;
        //计算分割点
        for (; len * 3 < s.size(); len *= 2) {
            m = m * m;
        }
        h._from_str(s.substr(0, s.size() - len), base);
        _from_str(s.substr(s.size() - len), base);
        *this = *this + m * h;
        return *this;
    }
    //任意进制字符串输入(2~36进制)
    BigIntSimple &from_str(const char *s, int base = 10) {
        //特殊情况直接用原来的读入函数速度快
        if (base == 10) {
            set(s);
            return *this;
        }
        int vsign = 1, i = 0;
        while (s[i] == '-') {
            ++i;
            vsign = -vsign;
        }
        _from_str(std::string(s + i), base);
        sign = vsign;
        return *this;
    }

任意进制输出

相信对于这个问题,如果那个数有n位,你不会考虑做n次除法吧,做n次除法的总时间复杂度是$O(n^3)$,一个一万位的大整数你要进制转换输出那你得计算到什么时候去。假设大整数类的基是base,要输出的进制是b,通过前一个输入的方案,你应该很容易想到做除法分割,这样下一次的长度就下降到n/2,即我们先求出分割点k,满足$log_{base}b^k \approx n/2$,原数是a的话,计算出$a_1=\lfloor\dfrac{a}{b^k}\rfloor, a_2=a\,mod\,b^k$,然后再分别对$a_1,a_2$做进制转换,最后$a_2$的结果视情况补充前导0后,与$a_1$的结果做字符串连接即可。

    //字符串输出
    std::string to_dec() const {
        std::string s;
        for (size_t i = 0; i < v.size(); ++i) {
            int d = v[i];
            //拆开压位
            for (size_t k = 0; k < BIGINT_DIGITS; ++k) {
                s += d % 10 + '0';
                d /= 10;
            }
        }
        //去除前导0
        while (s.size() > 1 && s.back() == '0')
            s.pop_back();
        //补符号
        if (sign < 0) s += '-';
        //不要忘记要逆序
        std::reverse(s.begin(), s.end());
        return s;
    }
    //递归分治进制转换输出
    std::string _to_str(int base, int pack) const {
        std::string s;
        //长度只剩下2时可以直接算
        if (v.size() <= 2) {
            int d = v[0] + (v.size() > 1 ? v[1] : 0) * BIGINT_BASE;
            do {
                int g = d % base;
                if (g < 10) {
                    s += char(g + '0');
                } else {
                    s += char(g + 'a' - 10);
                }
                d /= base;
            } while (d);
            //填充前导0
            while (s.size() < pack)
                s += '0';
            std::reverse(s.begin(), s.end());
            return s;
        }
        BigIntSimple m(base), h, l;
        size_t len = 1; //计算余数部分要补的前导0
        //计算分割点
        for (; m.v.size() * 3 < v.size(); len *= 2) {
            m = m * m;
        }
        h = div_mod(m, l); //算出分割后的高位h和低位l
        s = h._to_str(base, std::max(pack - (int)len, 0));
        return s + l._to_str(base, len);
    }
    //任意进制(2~36进制)字符串输出
    std::string to_str(int base = 10) const {
        if (base == 10) {
            return to_dec();
        }
        std::string s;
        BigIntSimple m(*this);
        m.sign = 1;
        s = m._to_str(base, 0);
        return sign >= 0 ? s : "-" + s;
    }

方法是有了,但问题是,这个方法并不够快,虽然复杂度确实下降到$O(n^2)$,但常数大,除法比乘法的常数大,而且除法不太好优化,但乘法的优化前面已经有了,有没有不需要除法的办法?有啊,考虑一下,既然输出是b进制,那么我们直接动态方式构造一个使用的base为$b^n$的大整数类,直接按前一个方法,把输出作为这个新大整数类的输入,这样输入完了再转了b进制输出不就只有乘法了。不过需要再另写一个类,这里不提供示例代码了,直接看项目代码吧,因为代码比较长。

具体请参见MiniBigInteger项目中BigIntBase的实现。

模板

最后,加上了乘法优化和任意进制输入输出的模板(用分治除法实现的输出)如下

struct BigIntSimple {
    static const int BIGINT_BASE = 10000;
    static const int BIGINT_DIGITS = 4;

    int sign; // 1表示正数,-1表示负数
    std::vector<int> v;

    //定义0也需要长度1
    BigIntSimple() {
        sign = 1;
        v.push_back(0);
    }
    BigIntSimple(int n) { *this = n; }
    //判断是否为0
    bool iszero() const { return v.size() == 1 && v.back() == 0; }
    //消除前导0并修正符号
    void trim() {
        while (v.back() == 0 && v.size() > 1)
            v.pop_back();
        if (iszero()) sign = 1;
    }
    //获取pos位置上的数值,用于防越界,简化输入处理
    int get(unsigned pos) const {
        if (pos >= v.size()) return 0;
        return v[pos];
    }
    //绝对值大小比较
    bool absless(const BigIntSimple &b) const {
        if (v.size() == b.v.size()) {
            for (size_t i = v.size() - 1; i < v.size(); --i)
                if (v[i] != b.v[i]) return v[i] < b.v[i];
            return false;
        } else {
            return v.size() < b.v.size();
        }
    }
    //字符串输入
    void set(const char *s) {
        v.clear();
        sign = 1;
        //处理负号
        while (*s == '-')
            sign = -sign, ++s;
        //先按数位直接存入数组里
        for (size_t i = 0; s[i]; ++i)
            v.push_back(s[i] - '0');
        std::reverse(v.begin(), v.end());
        //压位处理,e是压位后的长度
        size_t e = (v.size() + BIGINT_DIGITS - 1) / BIGINT_DIGITS;
        for (size_t i = 0, j = 0; i < e; ++i, j += BIGINT_DIGITS) {
            v[i] = v[j]; //设置压位的最低位
            //高位的按每一位上的数值乘以m,m是该位的权值
            for (size_t k = 1, m = 10; k < BIGINT_DIGITS; ++k, m *= 10)
                v[i] += get(j + k) * m;
        }
        //修正压位后的长度
        if (e) {
            v.resize(e);
            trim();
        } else {
            v.resize(1);
        }
    }
    //分治进制转换输入
    BigIntSimple &_from_str(const std::string &s, int base) {
        //较短长度时直接计算,36^4 < 2^31,但取5就大于了,所以长度上限是4
        if (s.size() <= 4) {
            int v = 0;
            for (size_t i = 0; i < s.size(); ++i) {
                int digit = -1;
                if (s[i] >= '0' && s[i] <= '9')
                    digit = s[i] - '0';
                else if (s[i] >= 'A' && s[i] <= 'Z')
                    digit = s[i] - 'A' + 10;
                else if (s[i] >= 'a' && s[i] <= 'z')
                    digit = s[i] - 'a' + 10;
                v = v * base + digit;
            }
            return *this = v;
        }
        BigIntSimple m(base), h;
        size_t len = 1;
        //计算分割点
        for (; len * 3 < s.size(); len *= 2) {
            m = m * m;
        }
        h._from_str(s.substr(0, s.size() - len), base);
        _from_str(s.substr(s.size() - len), base);
        *this = *this + m * h;
        return *this;
    }
    //任意进制字符串输入(2~36进制)
    BigIntSimple &from_str(const char *s, int base = 10) {
        //特殊情况直接用原来的读入函数速度快
        if (base == 10) {
            set(s);
            return *this;
        }
        int vsign = 1, i = 0;
        while (s[i] == '-') {
            ++i;
            vsign = -vsign;
        }
        _from_str(std::string(s + i), base);
        sign = vsign;
        return *this;
    }
    //字符串输出
    std::string to_dec() const {
        std::string s;
        for (size_t i = 0; i < v.size(); ++i) {
            int d = v[i];
            //拆开压位
            for (size_t k = 0; k < BIGINT_DIGITS; ++k) {
                s += d % 10 + '0';
                d /= 10;
            }
        }
        //去除前导0
        while (s.size() > 1 && s.back() == '0')
            s.pop_back();
        //补符号
        if (sign < 0) s += '-';
        //不要忘记要逆序
        std::reverse(s.begin(), s.end());
        return s;
    }
    //递归分治进制转换输出
    std::string _to_str(int base, int pack) const {
        std::string s;
        //长度只剩下2时可以直接算
        if (v.size() <= 2) {
            int d = v[0] + (v.size() > 1 ? v[1] : 0) * BIGINT_BASE;
            do {
                int g = d % base;
                if (g < 10) {
                    s += char(g + '0');
                } else {
                    s += char(g + 'a' - 10);
                }
                d /= base;
            } while (d);
            //填充前导0
            while (s.size() < pack)
                s += '0';
            std::reverse(s.begin(), s.end());
            return s;
        }
        BigIntSimple m(base), h, l;
        size_t len = 1; //计算余数部分要补的前导0
        //计算分割点
        for (; m.v.size() * 3 < v.size(); len *= 2) {
            m = m * m;
        }
        h = div_mod(m, l); //算出分割后的高位h和低位l
        s = h._to_str(base, std::max(pack - (int)len, 0));
        return s + l._to_str(base, len);
    }
    //任意进制(2~36进制)字符串输出
    std::string to_str(int base = 10) const {
        if (base == 10) {
            return to_dec();
        }
        std::string s;
        BigIntSimple m(*this);
        m.sign = 1;
        s = m._to_str(base, 0);
        return sign >= 0 ? s : "-" + s;
    }

    bool operator<(const BigIntSimple &b) const {
        if (sign == b.sign) {
            return sign > 0 ? absless(b) : b.absless(*this);
        } else {
            return sign < 0;
        }
    }

    bool operator==(const BigIntSimple &b) const {
        if (sign == b.sign) {
            return !absless(b) && !b.absless(*this);
        }
        return false;
    }

    BigIntSimple &operator=(int n) {
        v.clear();
        sign = n >= 0 ? 1 : -1;
        for (n = abs(n); n; n /= BIGINT_BASE)
            v.push_back(n % BIGINT_BASE);
        if (v.empty()) v.push_back(0);
        return *this;
    }

    BigIntSimple &operator=(const std::string &s) {
        set(s.c_str());
        return *this;
    }

    BigIntSimple operator-() const {
        BigIntSimple r = *this;
        r.sign = -r.sign;
        return r;
    }

    BigIntSimple operator+(const BigIntSimple &b) const {
        //符号不同时转换为减法
        if (sign != b.sign) return *this - -b;
        BigIntSimple r = *this;
        //填充高位
        if (r.v.size() < b.v.size()) r.v.resize(b.v.size());
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += r.v[i] + b.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size(); carry && i < r.v.size(); ++i) {
            carry += r.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) r.v.push_back(carry);
        return r;
    }

    BigIntSimple &subtract(const BigIntSimple &b) {
        int borrow = 0;
        //先处理b的长度
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i] - b.v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //减法可能会出现前导0需要消去
        trim();
        return *this;
    }

    BigIntSimple operator-(const BigIntSimple &b) const {
        //符号不同时转换为加法
        if (sign != b.sign) return (*this) + -b;
        if (absless(b)) { //保证大数减小数
            BigIntSimple r = b;
            return -r.subtract(*this);
        } else {
            BigIntSimple r = *this;
            return r.subtract(b);
        }
    }

    BigIntSimple &offset_add(const BigIntSimple &b, int offset) {
        //填充高位
        if (v.size() < b.v.size() + offset) v.resize(b.v.size() + offset);
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += v[i + offset] + b.v[i] - BIGINT_BASE;
            v[i + offset] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size() + offset; carry && i < v.size(); ++i) {
            carry += v[i] - BIGINT_BASE;
            v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) v.push_back(carry);
        return *this;
    }

    BigIntSimple mul(const BigIntSimple &b) const {
        // r记录相加结果
        BigIntSimple r;
        r.v.resize(v.size() + b.v.size()); //初始化长度
        for (size_t j = 0; j < v.size(); ++j) {
            int carry = 0, m = v[j]; // m用来缓存乘数
            // carry可能很大,只能使用求模的办法,此循环与加法部分几乎相同,就多乘了个m
            for (size_t i = 0; i < b.v.size(); ++i) {
                carry += r.v[i + j] + b.v[i] * m;
                r.v[i + j] = carry % BIGINT_BASE;
                carry /= BIGINT_BASE;
            }
            r.v[j + b.v.size()] += carry;
        }
        r.trim();
        return r;
    }

    BigIntSimple &fastmul(const BigIntSimple &a, const BigIntSimple &b) {
        //小于某个阈值就直接用暴力乘法
        if (std::min(a.v.size(), b.v.size()) <= 300) {
            return *this = a.mul(b);
        }
        BigIntSimple ah, al, bh, bl, h, m;
        //计算分割点
        size_t split = std::max(                            //
            std::min((a.v.size() + 1) / 2, b.v.size() - 1), //
            std::min((b.v.size() + 1) / 2, a.v.size() - 1));
        //按分割点拆成4个数
        al.v.assign(a.v.begin(), a.v.begin() + split);
        ah.v.assign(a.v.begin() + split, a.v.end());
        bl.v.assign(b.v.begin(), b.v.begin() + split);
        bh.v.assign(b.v.begin() + split, b.v.end());
        //按公式递归计算
        fastmul(al, bl);
        h.fastmul(ah, bh);
        m.fastmul(al + ah, bl + bh);
        m.subtract(*this + h);
        v.resize(a.v.size() + b.v.size());

        offset_add(m, split);
        offset_add(h, split * 2);
        trim();
        return *this;
    }

    BigIntSimple operator*(const BigIntSimple &b) const {
        BigIntSimple r;
        r.fastmul(*this, b);
        r.sign = sign * b.sign;
        return r;
    }

    //对b乘以mul再左移offset的结果相减,为除法服务
    BigIntSimple &sub_mul(const BigIntSimple &b, int mul, int offset) {
        if (mul == 0) return *this;
        int borrow = 0;
        //与减法不同的是,borrow可能很大,不能使用减法的写法
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i + offset] - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        return *this;
    }

    BigIntSimple div_mod(const BigIntSimple &b, BigIntSimple &r) const {
        BigIntSimple d;
        r = *this;
        if (absless(b)) return d;
        d.v.resize(v.size() - b.v.size() + 1);
        //提前算好除数的最高三位+1的倒数,若最高三位是a3,a2,a1
        //那么db是a3+a2/base+(a1+1)/base^2的倒数,最后用乘法估商的每一位
        //此法在BIGINT_BASE<=32768时可在int32范围内用
        //但即使使用int64,那么也只有BIGINT_BASE<=131072时可用(受double的精度限制)
        //能保证估计结果q'与实际结果q的关系满足q'<=q<=q'+1
        //所以每一位的试商平均只需要一次,只要后面再统一处理进位即可
        //如果要使用更大的base,那么需要更换其它试商方案
        double t = (b.get((unsigned)b.v.size() - 2) +
                   (b.get((unsigned)b.v.size() - 3) + 1.0) / BIGINT_BASE);
        double db = 1.0 / (b.v.back() + t / BIGINT_BASE);
        for (size_t i = v.size() - 1, j = d.v.size() - 1; j <= v.size();) {
            int rm = r.get(i + 1) * BIGINT_BASE + r.get(i);
            int m = std::max((int)(db * rm), r.get(i + 1));
            r.sub_mul(b, m, j);
            d.v[j] += m;
            if (!r.get(i + 1)) //检查最高位是否已为0,避免极端情况
                --i, --j;
        }
        r.trim();
        //修正结果的个位
        int carry = 0;
        while (!r.absless(b)) {
            r.subtract(b);
            ++carry;
        }
        //修正每一位的进位
        for (size_t i = 0; i < d.v.size(); ++i) {
            carry += d.v[i];
            d.v[i] = carry % BIGINT_BASE;
            carry /= BIGINT_BASE;
        }
        d.trim();
        d.sign = sign * b.sign;
        return d;
    }

    BigIntSimple operator/(const BigIntSimple &b) const {
        BigIntSimple r;
        return div_mod(b, r);
    }

    BigIntSimple operator%(const BigIntSimple &b) const { return *this - *this / b * b; }
};
Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus