大整数高精度计算3——快速傅里叶/数论变换及分治除法

在上一篇介绍了基础优化算法后,本篇介绍更复杂的内容。本篇的三大内容:FFT,NTT,分治除法。中文资料中我尚未发现有博客文章在介绍分治除法的,所以我就来写第一个介绍吧。

FFT (Schönhage–Strassen algorithm)

FFT就是快速傅里叶变换的缩写,FFT这里不重点介绍,参见oiwiki中对FFT的介绍

用比较显浅且简要的方式来介绍的话,FFT可以在$O(nlogn)$时间里(注:严格来说应该是$O(nlognloglogn)$,论文链接在这,在274页7.8处开始证明)求出两个多项式的乘积,而大整数乘法就可以视为两个整数系数的多项式的积再处理进位即可。那这样就产生一个问题,如果这个大整数采用的基数base,且长度为n,那么多项式的积的最大值可能会达到$n(base-1)^2$,因为FFT采用浮点运算,那么这个最大值必须在double的精度内,所以如果采用万进制,那么很容易导致精度不足进而结果错误,另外,还有求单位复根时的误差也会让它提早出现结果错误。实际应用的时候,一般采用10进制来增加长度的上限。所以如果你的大整数采用压位表示,那么在运用FFT之前,要先做拆位拆成10进制。总之,因为浮点精度的问题,不建议在大整数实现里使用FFT,建议采用下文要介绍的NTT来代替。所以这里对FFT的优化不做介绍,重心放在NTT上。

NTT

NTT就是快速数论变换,它和FFT很像,但它在整数域上做变换,具体数学介绍参见oiwiki中对NTT的介绍。由于没有浮点数参与,于是没有精度上的问题。

NTT里面有两个重要的数,一个是原根g,通常取3,另一个是素数p,通常取1004535809或998244353,之所以取这两个素数是因为3都是它们的原根,此外,这个素数p必须是形如$2^nm+1$,像$1004535809-1=2^{21} \times 479,\ 998244353-1=2^{23} \times 119$,而这个n限制了变换时的最大长度为$2^n$。在NTT变换的结果里,最大值可能会达到$n(base-1)^2$,为了保证这个值没有因为求解过程中被求模变小了,就要限制$n(base-1)^2 < p$,即这个大整数的最大长度n满足$n < \lfloor\dfrac{p}{(base-1)^2}\rfloor$且$n = 2^k$,那么比如说我们取p为998244353,然后用万进制,得到n的最大值是8!?只有8还用什么呢,我们换换,改用10进制,那可以算得n的最大值是$2^{23}=8388608$,这个长度就比较实用了,如果用万进制做拆位,那这个最大长度就是$2^{21}=2097152$,对付数据范围在1e6以内的题目是足够了。

但是,在实际使用时,要是真出现比这还要长的数怎么办?这个倒不难解决,先用前一篇文章的分治乘法,拆成原来的一半长度,直到满足NTT的使用上限时,再使用NTT即可。

但以上这还不是本文要讲的重点,以上方法需要做拆位,拆位的常数能优化得更小吗?或者说,如果我坚持要使用万进制的话,NTT还能用吗?重新考虑一开始的限制,那个限制的原因,是最大值可能大于等于p,导致在模运算下失去了原来的值,那么,如果我们使用两个不同的p,最后用中国剩余定理来还原实际值呢,假设这2个素数是p1和p2,分别取1004535809和998244353,这样的话上限就是$p1p2=1002772198720536577$,在这个条件下用万进制的话,长度即使上100亿也不会越界,如果用65536进制的话长度也能达2亿,已经远远超过那个素数p本身带来的限制,完全足够使用了。不过代价是需要多做一次NTT变换,但比起万进制拆位拆成4个来比较,常数还是更小的。而且,即使不拆位,直接用10进制,速度也还是比不过使用双素数万进制。

方案有了,接下来又有新问题,就是要解这个数论方程$z = xp_1 + c_1 = yp_2 + c_2 \ (x,y,z \in N)$

直接用中国剩余定理可解,先求出$p_1$在$p_2$下的逆元$i_1$以及$p_2$在$p_1$下的逆元$i_2$,那么$z = (c_1i_1p_1+c_2i_2p_2) \ mod \ p_1p_2$

注意到,我们的p是在1e9附近,于是就出问题了,$c_1i_1p_1$的结果可能会达到1e27,爆掉了int64,当然如果你选的p比$2^{21}$小,那就不至于溢出,但长度又受限制了,那这时候怎么办?

有一个办法是模仿快速幂的方式来求乘法的模,但这个方法常数很大,我们要找更好的方案。

在《算法竞赛进阶指南》里面(page 6)就介绍了这么一段代码

int64_t mul_mod(int64_t a, int64_t b, int64_t m) {
    a %= m;
    b %= m;
    int64_t c = (int64_t)((long double)a * b / m);
    int64_t ans = a * b - c * m;
    if (ans < 0)
        ans += m;
    else if (ans >= m)
        ans -= m;
    return ans;
}

具体原理就不介绍了,但要注意的一点问题是,比如在VS编译器上,long double等于double,那么a和b就需要同时小于$2^{53}$,即$n(base-1)^2 < 2^{53}$,否则产生的误差就不是后面一个if能解决得了的,那么比如你用65536进制,那么n的最大值只有$2^{21}=2097152$,这是要注意的点(嗯,没错,这就是我踩过的坑)。

那再来,有没有办法避免这个问题?经过我一番资料搜索,还真有,那个同余方程还有另一种解法。先回顾一下原方程 $z = xp_1 + c_1 = yp_2 + c_2 \ (x,y,z \in N)$

我们可以不直接求z,改成求y,那么 $y = (c_1 - c_2)i_2 \ mod \ p_1$

于是不但完美避开int64的乘法求模,而且计算量更少,现在MiniBigInteger项目中的实现就是使用这个方法。

基本原理介绍完了,以下讲讲细节,在选取NTT的模p的时候,其实网上最常用的两个数1004535809和998244353并不是两个模的时候的最好的选择,通过暴力求解,最佳选择是$469762049=2^{26} \times 7 + 1,\ 167772161=2^{25} \times 5 + 1$,这个组合可以让最大长度达到$2^{25}$。

不过在这里就不给出用NTT的写法模板了,具体实现细节及常数优化参见项目源代码。以下列举部分关键代码

inline ntt_base_t mul_mod(int64_t a, int64_t b) {
    return a * b % NTT_MOD;
}
ntt_base_t pow_mod(int64_t a, int64_t b) {
    int64_t ans = 1;
    a %= NTT_MOD;
    while (b) {
        if (b & 1) ans = ans * a % NTT_MOD;
        b >>= 1;
        a = a * a % NTT_MOD;
    }
    return (ntt_base_t)ans;
}
void transform(ntt_base_t a[], size_t len, int on) {
    for (size_t i = 0; i < len; i++) {
        if (i < ntt_r[i]) std::swap(a[i], a[ntt_r[i]]);
    }
    size_t id = 0;
    for (size_t h = 1; h < len; h <<= 1) {
        ntt_base_t wn = ntt_wn[on][++id];
        for (size_t j = 0; j < len; j += h << 1) {
            ntt_base_t w = 1;
            size_t e = j + h;
            for (size_t k = j; k < e; k++, w = mul_mod(w, wn)) {
                ntt_base_t t = mul_mod(w, a[k + h]);
                a[k + h] = (a[k] - t + NTT_MOD) % NTT_MOD;
                a[k] = (a[k] + t) % NTT_MOD;
            }
        }
    }
    if (on == 0) {
        ntt_base_t inv = pow_mod(len, NTT_MOD - 2);
        for (size_t i = 0; i < len; i++)
            a[i] = mul_mod(a[i], inv);
    }
}
void mul_conv(size_t n) {
    transform(&ntt_a.front(), n, 1);
    transform(&ntt_b.front(), n, 1);
    for (size_t i = 0; i < n; i++)
        ntt_a[i] = mul_mod(ntt_a[i], ntt_b[i]);
    transform(&ntt_a.front(), n, 0);
}

分治除法

没错,大整数除法也能分治,这很多人还不知道这点。并不只有牛顿迭代这一路线。(也许本文是中文博客中第一篇介绍大整数的分治除法)

我们假设,相除的两个数表示为$\dfrac{a_1a_2a_3a_4}{b_1b_2}$,其中$a_1,a_2,a_3,a_4,b_1,b_2$看成是一个以$base^n$为基的“1位数字”,即看成是“4位数”除以“2位数”,并且,我们假设$b_1 \geq \dfrac{base^n}2$,接下来,我们就看成是模拟手算的除法,用前3位除以那个2位数,求出的余数便再来一次3位除以2位。而在3位除以2位时,我们需要的便是先试商,试商的方法是取被除数的最高2位除以除数的最高1位。因为我们在前面假设了$b_1 \geq \dfrac{base^n}2$ (这又叫做规则化),Knuth在TAOCP书中证明了在除数规则化后,可以证明这样试商的结果q’与实际商q的关系满足$q’-2 \leq q \leq q’$,所以最多只要对试商结果做2次减法修正便可。再看试商这一步,是“2位数”除以“1位数”,那我们就把原本以$base^n$为基,改为以$base^{n/2}$为基,于是就变成“4位数”除以“2位数”,便可以递归进行了。

对于实际应用的时候,在使用分治法之前,要做两件事情,首先找出系数m,令除数B乘以m后,最高位的$b_1$尽可能接近且小于base,然后让被除数和除数,均乘以m,再开始做分治。计算完毕后,得到的余数只需要再除以m即可。另外,如果分治过程中出现B的位数是奇数,不能正好除以2时,那也没有关系,直接向下取整,那么多出来的1位就作为精度补充使用。

再简单复述一篇,对于2n/n,降低base的指数,转为4n/2n,然后就是模拟手算,前3位除以2位,即3n/2n,需要这样除两个回合。而在3n/2n的时候,需要试商,试商时就是2n/n的除法,此步通过递归完成,除得的商修正个位数即可。

但当然,这个分治法仍然有改进的余地,我的实现可能比较糟糕,是同长度乘法的约2到4倍时间,这个倍数取决于长度,即它的时间复杂度为$O(mlogn)$,其中n表示除数长度,m表示长度n的乘法的时间,所以在n较大的时候,用牛顿迭代会更好。以下是我的实现部分代码

//返回右移后的结果
BigIntSimple shr_to(size_t n) const {
    BigIntSimple r;
    if (n >= v.size()) return r;
    r.v.assign(v.begin() + n, v.end());
    return r;
}

//对自身左移
BigIntSimple &shl(size_t n) {
    if (n == 0) return *this;
    v.insert(v.begin(), n, 0);
    return *this;
}

//分治除法递归部分
BigIntSimple &dividediv_recursion(const BigIntSimple &a, const BigIntSimple &b, BigIntSimple &r) {
    if (a < b) {
        r = a;
        return *this = BigIntSimple(0);
    } else if (b.v.size() <= 300) {
        return *this = a.div_mod(b, r);
    }
    size_t base = (b.v.size() + 1) / 2;
    //符合3/2时,进行试商
    if (a.v.size() <= base * 3) {
        base = b.v.size() / 2;
        BigIntSimple ma = a, mb = b, e;
        BigIntSimple ha = ma.shr_to(base);
        BigIntSimple hb = mb.shr_to(base);
        dividediv_recursion(ha, hb, r);
        ha = *this * b;
        while (a < ha) {
            ha.subtract(b);
            subtract(BigIntSimple(1));
        }
        r = a - ha;
        return *this;
    }
    //选择合适的base长度做分割
    if (a.v.size() > base * 4) base = a.v.size() / 2;
    BigIntSimple ha = a.shr_to(base);
    BigIntSimple c, d, m;
    dividediv_recursion(ha, b, d);
    shl(base);
    m.v.resize(base + d.v.size());
    for (size_t i = 0; i < base; ++i)
        m.v[i] = a.v[i];
    for (size_t i = 0; i < d.v.size(); ++i)
        m.v[base + i] = d.v[i];
    *this = *this + c.dividediv_recursion(m, b, r);
    return *this;
}

//分治除法规则化准备
BigIntSimple &dividediv(const BigIntSimple &a, const BigIntSimple &b, BigIntSimple &r) {
    if (b.v.size() <= 300) {
        return *this = a.div_mod(b, r);
    }
    //被除数不及两倍除数长度减2时,可以忽略一部分最低位且不影响结果
    if (b.v.size() * 2 - 2 > a.v.size()) {
        BigIntSimple ta = a, tb = b;
        size_t ans_len = a.v.size() - b.v.size() + 2;
        size_t shr = b.v.size() - ans_len;
        ta = ta.shr_to(shr);
        tb = tb.shr_to(shr);
        return dividediv(ta, tb, r);
    }
    //规则化
    int mul = (int)(((uint64_t)BIGINT_BASE * BIGINT_BASE - 1) /    //
        (*(b.v.begin() + b.v.size() - 1) * (uint64_t)BIGINT_BASE + //
            *(b.v.begin() + b.v.size() - 2) + 1));
    BigIntSimple ma = a * BigIntSimple(mul);
    BigIntSimple mb = b * BigIntSimple(mul);
    BigIntSimple d;
    ma.sign = mb.sign = 1;
    dividediv_recursion(ma, mb, d);
    r = d.div_mod(BigIntSimple(mul), ma);
    return *this;
}

BigIntSimple operator/(const BigIntSimple &b) const {
    BigIntSimple r, t;
    // t = div_mod(b, r);
    t.dividediv(*this, b, r);
    t.sign = sign * b.sign;
    return t;
}

BigIntSimple类的完整代码

事实上,代码写到这里,就已经和我在项目里实现的bigint_mini没有什么区别了,直接看这个好了,主要区别就是少了进制转换输入输出,性能上也稍微差一丁点。完整代码有点长,近500行。

点击展开

struct BigIntSimple {
    static const int BIGINT_BASE = 10000;
    static const int BIGINT_DIGITS = 4;

    int sign; // 1表示正数,-1表示负数
    std::vector<int> v;

    //定义0也需要长度1
    BigIntSimple() {
        sign = 1;
        v.push_back(0);
    }
    BigIntSimple(int n) { *this = n; }
    //判断是否为0
    bool iszero() const { return v.size() == 1 && v.back() == 0; }
    //消除前导0并修正符号
    void trim() {
        while (v.back() == 0 && v.size() > 1)
            v.pop_back();
        if (iszero()) sign = 1;
    }
    //获取pos位置上的数值,用于防越界,简化输入处理
    int get(unsigned pos) const {
        if (pos >= v.size()) return 0;
        return v[pos];
    }
    //绝对值大小比较
    bool absless(const BigIntSimple &b) const {
        if (v.size() == b.v.size()) {
            for (size_t i = v.size() - 1; i < v.size(); --i)
                if (v[i] != b.v[i]) return v[i] < b.v[i];
            return false;
        } else {
            return v.size() < b.v.size();
        }
    }
    //字符串输入
    void set(const char *s) {
        v.clear();
        sign = 1;
        //处理负号
        while (*s == '-')
            sign = -sign, ++s;
        //先按数位直接存入数组里
        for (size_t i = 0; s[i]; ++i)
            v.push_back(s[i] - '0');
        std::reverse(v.begin(), v.end());
        //压位处理,e是压位后的长度
        size_t e = (v.size() + BIGINT_DIGITS - 1) / BIGINT_DIGITS;
        for (size_t i = 0, j = 0; i < e; ++i, j += BIGINT_DIGITS) {
            v[i] = v[j]; //设置压位的最低位
            //高位的按每一位上的数值乘以m,m是该位的权值
            for (size_t k = 1, m = 10; k < BIGINT_DIGITS; ++k, m *= 10)
                v[i] += get(j + k) * m;
        }
        //修正压位后的长度
        if (e) {
            v.resize(e);
            trim();
        } else {
            v.resize(1);
        }
    }
    //分治进制转换输入
    BigIntSimple &_from_str(const std::string &s, int base) {
        //较短长度时直接计算,36^4 < 2^31,但取5就大于了,所以长度上限是4
        if (s.size() <= 4) {
            int v = 0;
            for (size_t i = 0; i < s.size(); ++i) {
                int digit = -1;
                if (s[i] >= '0' && s[i] <= '9')
                    digit = s[i] - '0';
                else if (s[i] >= 'A' && s[i] <= 'Z')
                    digit = s[i] - 'A' + 10;
                else if (s[i] >= 'a' && s[i] <= 'z')
                    digit = s[i] - 'a' + 10;
                v = v * base + digit;
            }
            return *this = v;
        }
        BigIntSimple m(base), h;
        size_t len = 1;
        //计算分割点
        for (; len * 3 < s.size(); len *= 2) {
            m = m * m;
        }
        h._from_str(s.substr(0, s.size() - len), base);
        _from_str(s.substr(s.size() - len), base);
        *this = *this + m * h;
        return *this;
    }
    //任意进制字符串输入(2~36进制)
    BigIntSimple &from_str(const char *s, int base = 10) {
        //特殊情况直接用原来的读入函数速度快
        if (base == 10) {
            set(s);
            return *this;
        }
        int vsign = 1, i = 0;
        while (s[i] == '-') {
            ++i;
            vsign = -vsign;
        }
        _from_str(std::string(s + i), base);
        sign = vsign;
        return *this;
    }
    //字符串输出
    std::string to_dec() const {
        std::string s;
        for (size_t i = 0; i < v.size(); ++i) {
            int d = v[i];
            //拆开压位
            for (size_t k = 0; k < BIGINT_DIGITS; ++k) {
                s += d % 10 + '0';
                d /= 10;
            }
        }
        //去除前导0
        while (s.size() > 1 && s.back() == '0')
            s.pop_back();
        //补符号
        if (sign < 0) s += '-';
        //不要忘记要逆序
        std::reverse(s.begin(), s.end());
        return s;
    }
    //递归分治进制转换输出
    std::string _to_str(int base, int pack) const {
        std::string s;
        //长度只剩下2时可以直接算
        if (v.size() <= 2) {
            int d = v[0] + (v.size() > 1 ? v[1] : 0) * BIGINT_BASE;
            do {
                int g = d % base;
                if (g < 10) {
                    s += char(g + '0');
                } else {
                    s += char(g + 'a' - 10);
                }
                d /= base;
            } while (d);
            //填充前导0
            while (s.size() < pack)
                s += '0';
            std::reverse(s.begin(), s.end());
            return s;
        }
        BigIntSimple m(base), h, l;
        size_t len = 1; //计算余数部分要补的前导0
        //计算分割点
        for (; m.v.size() * 3 < v.size(); len *= 2) {
            m = m * m;
        }
        h = div_mod(m, l); //算出分割后的高位h和低位l
        s = h._to_str(base, std::max(pack - (int)len, 0));
        return s + l._to_str(base, len);
    }
    //任意进制(2~36进制)字符串输出
    std::string to_str(int base = 10) const {
        if (base == 10) {
            return to_dec();
        }
        std::string s;
        BigIntSimple m(*this);
        m.sign = 1;
        s = m._to_str(base, 0);
        return sign >= 0 ? s : "-" + s;
    }

    bool operator<(const BigIntSimple &b) const {
        if (sign == b.sign) {
            return sign > 0 ? absless(b) : b.absless(*this);
        } else {
            return sign < 0;
        }
    }

    bool operator==(const BigIntSimple &b) const {
        if (sign == b.sign) {
            return !absless(b) && !b.absless(*this);
        }
        return false;
    }

    BigIntSimple &operator=(int n) {
        v.clear();
        sign = n >= 0 ? 1 : -1;
        for (n = abs(n); n; n /= BIGINT_BASE)
            v.push_back(n % BIGINT_BASE);
        if (v.empty()) v.push_back(0);
        return *this;
    }

    BigIntSimple &operator=(const std::string &s) {
        set(s.c_str());
        return *this;
    }

    BigIntSimple operator-() const {
        BigIntSimple r = *this;
        r.sign = -r.sign;
        return r;
    }

    BigIntSimple operator+(const BigIntSimple &b) const {
        //符号不同时转换为减法
        if (sign != b.sign) return *this - -b;
        BigIntSimple r = *this;
        //填充高位
        if (r.v.size() < b.v.size()) r.v.resize(b.v.size());
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += r.v[i] + b.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size(); carry && i < r.v.size(); ++i) {
            carry += r.v[i] - BIGINT_BASE;
            r.v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) r.v.push_back(carry);
        return r;
    }

    BigIntSimple &subtract(const BigIntSimple &b) {
        int borrow = 0;
        //先处理b的长度
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i] - b.v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i];
            v[i] = borrow;
            v[i] -= BIGINT_BASE * (borrow >>= 31);
        }
        //减法可能会出现前导0需要消去
        trim();
        return *this;
    }

    BigIntSimple operator-(const BigIntSimple &b) const {
        //符号不同时转换为加法
        if (sign != b.sign) return (*this) + -b;
        if (absless(b)) { //保证大数减小数
            BigIntSimple r = b;
            return -r.subtract(*this);
        } else {
            BigIntSimple r = *this;
            return r.subtract(b);
        }
    }

    BigIntSimple &offset_add(const BigIntSimple &b, int offset) {
        //填充高位
        if (v.size() < b.v.size() + offset) v.resize(b.v.size() + offset);
        int carry = 0;
        //逐位相加
        for (size_t i = 0; i < b.v.size(); ++i) {
            carry += v[i + offset] + b.v[i] - BIGINT_BASE;
            v[i + offset] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理进位,拆两个循环来写是避免做 i < b.v.size() 的判断
        for (size_t i = b.v.size() + offset; carry && i < v.size(); ++i) {
            carry += v[i] - BIGINT_BASE;
            v[i] = carry - BIGINT_BASE * (carry >> 31);
            carry = (carry >> 31) + 1;
        }
        //处理升位进位
        if (carry) v.push_back(carry);
        return *this;
    }

    BigIntSimple mul(const BigIntSimple &b) const {
        // r记录相加结果
        BigIntSimple r;
        r.v.resize(v.size() + b.v.size()); //初始化长度
        for (size_t j = 0; j < v.size(); ++j) {
            int carry = 0, m = v[j]; // m用来缓存乘数
            // carry可能很大,只能使用求模的办法,此循环与加法部分几乎相同,就多乘了个m
            for (size_t i = 0; i < b.v.size(); ++i) {
                carry += r.v[i + j] + b.v[i] * m;
                r.v[i + j] = carry % BIGINT_BASE;
                carry /= BIGINT_BASE;
            }
            r.v[j + b.v.size()] += carry;
        }
        r.trim();
        return r;
    }

    BigIntSimple &fastmul(const BigIntSimple &a, const BigIntSimple &b) {
        //小于某个阈值就直接用暴力乘法
        if (std::min(a.v.size(), b.v.size()) <= 300) {
            return *this = a.mul(b);
        }
        BigIntSimple ah, al, bh, bl, h, m;
        //计算分割点
        size_t split = std::max(                            //
            std::min((a.v.size() + 1) / 2, b.v.size() - 1), //
            std::min((b.v.size() + 1) / 2, a.v.size() - 1));
        //按分割点拆成4个数
        al.v.assign(a.v.begin(), a.v.begin() + split);
        ah.v.assign(a.v.begin() + split, a.v.end());
        bl.v.assign(b.v.begin(), b.v.begin() + split);
        bh.v.assign(b.v.begin() + split, b.v.end());
        //按公式递归计算
        fastmul(al, bl);
        h.fastmul(ah, bh);
        m.fastmul(al + ah, bl + bh);
        m.subtract(*this + h);
        v.resize(a.v.size() + b.v.size());

        offset_add(m, split);
        offset_add(h, split * 2);
        trim();
        return *this;
    }

    BigIntSimple operator*(const BigIntSimple &b) const {
        BigIntSimple r;
        r.fastmul(*this, b);
        // r = mul(b);
        r.sign = sign * b.sign;
        return r;
    }

    //对b乘以mul再左移offset的结果相减,为除法服务
    BigIntSimple &sub_mul(const BigIntSimple &b, int mul, int offset) {
        if (mul == 0) return *this;
        int borrow = 0;
        //与减法不同的是,borrow可能很大,不能使用减法的写法
        for (size_t i = 0; i < b.v.size(); ++i) {
            borrow += v[i + offset] - b.v[i] * mul - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        //如果还有借位就继续处理
        for (size_t i = b.v.size(); borrow; ++i) {
            borrow += v[i + offset] - BIGINT_BASE + 1;
            v[i + offset] = borrow % BIGINT_BASE + BIGINT_BASE - 1;
            borrow /= BIGINT_BASE;
        }
        return *this;
    }

    BigIntSimple div_mod(const BigIntSimple &b, BigIntSimple &r) const {
        BigIntSimple d;
        r = *this;
        if (absless(b)) return d;
        d.v.resize(v.size() - b.v.size() + 1);
        //提前算好除数的最高三位+1的倒数,若最高三位是a3,a2,a1
        //那么db是a3+a2/base+(a1+1)/base^2的倒数,最后用乘法估商的每一位
        //此法在BIGINT_BASE<=32768时可在int32范围内用
        //但即使使用int64,那么也只有BIGINT_BASE<=131072时可用(受double的精度限制)
        //能保证估计结果q'与实际结果q的关系满足q'<=q<=q'+1
        //所以每一位的试商平均只需要一次,只要后面再统一处理进位即可
        //如果要使用更大的base,那么需要更换其它试商方案
        double t = (b.get((unsigned)b.v.size() - 2) + (b.get((unsigned)b.v.size() - 3) + 1.0) / BIGINT_BASE);
        double db = 1.0 / (b.v.back() + t / BIGINT_BASE);
        for (size_t i = v.size() - 1, j = d.v.size() - 1; j <= v.size();) {
            int rm = r.get(i + 1) * BIGINT_BASE + r.get(i);
            int m = std::max((int)(db * rm), r.get(i + 1));
            r.sub_mul(b, m, j);
            d.v[j] += m;
            if (!r.get(i + 1)) //检查最高位是否已为0,避免极端情况
                --i, --j;
        }
        r.trim();
        //修正结果的个位
        int carry = 0;
        while (!r.absless(b)) {
            r.subtract(b);
            ++carry;
        }
        //修正每一位的进位
        for (size_t i = 0; i < d.v.size(); ++i) {
            carry += d.v[i];
            d.v[i] = carry % BIGINT_BASE;
            carry /= BIGINT_BASE;
        }
        d.trim();
        return d;
    }

    //返回右移后的结果
    BigIntSimple shr_to(size_t n) const {
        BigIntSimple r;
        if (n >= v.size()) return r;
        r.v.assign(v.begin() + n, v.end());
        return r;
    }

    //对自身左移
    BigIntSimple &shl(size_t n) {
        if (n == 0) return *this;
        v.insert(v.begin(), n, 0);
        return *this;
    }

    //分治除法递归部分
    BigIntSimple &dividediv_recursion(const BigIntSimple &a, const BigIntSimple &b, BigIntSimple &r) {
        if (a < b) {
            r = a;
            return *this = BigIntSimple(0);
        } else if (b.v.size() <= 300) {
            return *this = a.div_mod(b, r);
        }
        size_t base = (b.v.size() + 1) / 2;
        //符合3/2时,进行试商
        if (a.v.size() <= base * 3) {
            base = b.v.size() / 2;
            BigIntSimple ma = a, mb = b, e;
            BigIntSimple ha = ma.shr_to(base);
            BigIntSimple hb = mb.shr_to(base);
            dividediv_recursion(ha, hb, r);
            ha = *this * b;
            while (a < ha) {
                ha.subtract(b);
                subtract(BigIntSimple(1));
            }
            r = a - ha;
            return *this;
        }
        //选择合适的base长度做分割
        if (a.v.size() > base * 4) base = a.v.size() / 2;
        BigIntSimple ha = a.shr_to(base);
        BigIntSimple c, d, m;
        dividediv_recursion(ha, b, d);
        shl(base);
        m.v.resize(base + d.v.size());
        for (size_t i = 0; i < base; ++i)
            m.v[i] = a.v[i];
        for (size_t i = 0; i < d.v.size(); ++i)
            m.v[base + i] = d.v[i];
        *this = *this + c.dividediv_recursion(m, b, r);
        return *this;
    }

    //分治除法规则化准备
    BigIntSimple &dividediv(const BigIntSimple &a, const BigIntSimple &b, BigIntSimple &r) {
        if (b.v.size() <= 300) {
            return *this = a.div_mod(b, r);
        }
        //被除数不及两倍除数长度减2时,可以忽略一部分最低位且不影响结果
        if (b.v.size() * 2 - 2 > a.v.size()) {
            BigIntSimple ta = a, tb = b;
            size_t ans_len = a.v.size() - b.v.size() + 2;
            size_t shr = b.v.size() - ans_len;
            ta = ta.shr_to(shr);
            tb = tb.shr_to(shr);
            return dividediv(ta, tb, r);
        }
        //规则化
        int mul = (int)(((uint64_t)BIGINT_BASE * BIGINT_BASE - 1) /                //
                        (*(b.v.begin() + b.v.size() - 1) * (uint64_t)BIGINT_BASE + //
                         *(b.v.begin() + b.v.size() - 2) + 1));
        BigIntSimple ma = BigIntSimple(mul) * a;
        BigIntSimple mb = BigIntSimple(mul) * b;
        BigIntSimple d;
        ma.sign = mb.sign = 1;
        dividediv_recursion(ma, mb, d);
        r = d.div_mod(BigIntSimple(mul), ma);
        return *this;
    }

    BigIntSimple operator/(const BigIntSimple &b) const {
        BigIntSimple r, t;
        // t = div_mod(b, r);
        t.dividediv(*this, b, r);
        t.sign = sign * b.sign;
        return t;
    }

    BigIntSimple operator%(const BigIntSimple &b) const { return *this - *this / b * b; }
};

Avatar
抱抱熊

一个喜欢折腾和研究算法的大学生

Related

comments powered by Disqus