Modulo Square Root

正整数 $ m $ と整数 $ a \in [0, m) $ について $ x^2 = a \mod m $ の解を求めます.

方針

$ m = p_1^{e_1} \cdots p_k^{e_k} $ と素因数分解し,各素数べきについて $ x^2 = a \mod p_i^{e_i} $ を解いて,中国剰余定理を用いて解を構成します.

素因数分解と中国剰余定理部分は解説しません.

前処理

  • $ x^2 = 0 \mod p^e $ の場合,$ x^2 = \alpha p^e$ と書いて両辺の $ p $ 因子の数を比較すると $x = b p^{\lceil e / 2 \rceil}$ と書けることがわかります.この形の数はすべて方程式を満たすのでこれらが解となります.解の個数は $ \sqrt{p^e} $ です.

  • $ x^2 = a \mod p^e $ ($ a \neq 0 $) の場合,$ a = b p^k $ と $ p $ 因子を括り出して $ x^2 = b p^k + \alpha p^e$ と書き,両辺の $ p $ 因子の数を比較すると $ k $ が奇数のとき解が存在しないこと・$ k $ が偶数のとき $ x = y p^{k / 2} $ とおけることがわかります.したがって以下 $y^2 = b \mod p^e$ ただし $ \mathrm{gcd}(b, p) = 1 $ について考えれば十分です.

Hensel Lift

一般に $ y^2 = b \mod p^e $ を満たす $y $ は任意の $ 0 \le k \le e $ について $ y^2 = b \mod p^k $ を満たします.このことを手がかりに,まずは低い次数の素数べきでの解を求め,それを高い次数の素数べきに「持ち上げる」ことを考えます.

$l \le k$ とします.$ y_k^2 = b \mod p^k $ を満たす $ y_k $ をとり,$ y_{k+l}^2 = b \mod p^{k+l} $ を満たす $ y_{k+l} $ のうち $ y_{k+l} = y_k + t p^k $ と書けるものを探します.直接計算により

$$ y_{k+l}^2 = y_k^2 + 2 y_k t p^k + t^2 p^{2 k} = b + \left( \frac{y_k^2 - b}{p^k} + 2 y_k t \right) p^k \mod p^{k+l} $$

となるので $ (y_k^2 - b)/p^k + 2 y_k t = 0 \mod p^{l} $ なる $ t $ を求めればよいことがわかります;式中に現れる分数は割り切れることに注意してください.この構成(を一般化したもの)は Hensel Lift と呼ばれています.$l \approx k$ にとることで $O(\log e)$ 回の Hensel Lift で解を構成できます.

以下 $ p $ が奇素数の場合と $ 2 $ の場合にわけて具体的な手続きを与えます.

素数べき

$ y_1^2 = b \mod p $ を満たす $y_1$ は Tonelli-Shanks や Cipolla のアルゴリズムによって見つけられます.この解を初期値として Hensel Lift を行います.

$ y_k^2 = b \mod p^k $ なる $y_k$ をもっているとします.もし $y_k$ が $ p $ と互いに素であれば $2 y_k$ は法 $p^l$ のもとでモジュロ逆元 $(2 y_k)^{-1}$ をもち,これを用いて表せる $ t = ((y_k^2 - b)/p^k) \times (2 y_k)^{-1} $ が Hensel Lift の条件を満たす $ t $ となります.したがって $ y_{k+l} = y_k + t p^k = y_k + (y_k^2 - b) (2 y_k)^{-1} \mod p^{k+l} $ に解を持ち上げることができました.この式から $ y_{k+l} $ もまた $ p $ と互いに素であるのでこの手続きを繰り返すことができ,任意次数まで解を持ち上げることができます.

2 べき

$ k \le 2 $ のときは全探索で解が求まります.具体的には以下がわかります:

  • $ y^2 = b \mod 2 $ は前処理より $b = 1$,解は $x = 1$.
  • $ y^2 = b \mod 4 $ は前処理より $b = 1$ または $b = 3$.全通り試すことで $b = 1$ の場合にのみ解をもち,解は $y = 1$ と $y = 3$.

$ k \ge 3 $ を考えます.$ s = (y_k^2 - b)/2^k $ とおいて Hensel Lift の式を眺めると $ s + 2 y_k t = 0 \mod 2^l $ であればよいとわかります.この式から $ s $ は 2 の倍数でなければならないので $ s' = s / 2 $ とおけて $ 2 (s' + y_k t) = 0 \mod 2^l $ となり,$l = 1$ では $t = 0$, $ = 1$ がともに条件を満たし,$l \ge 2$ では $ s' + y_k t = 0 \mod 2^l $ および $ s' + y_k t = 2^{k - 1} \mod 2^l $ に帰着します.いずれも $ y_k $ は $ 2 $ と互いに素なのでモジュロ逆元がとれて $ t $ が計算できます.

なお,この構成から $k \ge 3$ のとき解が 0 個または 4 個であることがわかります.

実装例

ソースコードの全体はテンプレート周りが長いので一部のみを貼ります.

pub trait SqrtMod: Int {
    fn sqrt_mod_all(self, modulo: Self) -> Vec<Self> {
        let solve = |p: Self, e: u32| -> Vec<Self> {
            let p_e = p.clone().pow(e);
            let a = self.clone() % p_e.clone();
            return if a.is_zero() {
                let mut solutions = Vec::new();
                let mut curr = Self::zero();
                let step = p.pow((e + 1) / 2);
                while curr < p_e {
                    solutions.push(curr.clone());
                    curr = curr + step.clone();
                }
                solutions
            } else {
                let (b, l) = a.trim_factor(p.clone());
                if l.is_odd() {
                    Vec::new()
                } else {
                    let mut solutions = Vec::new();
                    for x in b.sqrt_mod_prime_power(p.clone(), e - l) {
                        let begin = x * p.clone().pow(l / 2);
                        let step = p.clone().pow(e - l / 2);
                        let mut curr = begin.clone();
                        loop {
                            solutions.push(curr.clone());
                            curr = curr.add_mod(step.clone(), p_e.clone());
                            if curr == begin {
                                break;
                            }
                        }
                    }
                    solutions.sort_unstable();
                    solutions.dedup();
                    solutions
                }
            };
        };
        let mut modulos = Vec::new();
        let mut solutions = Vec::new();
        for (p, e) in modulo.prime_factors() {
            solutions.push(solve(p.clone(), e));
            modulos.push(p.pow(e));
        }
        solutions
            .into_iter()
            .map(|vec| vec.into_iter())
            .multi_cartesian_product()
            .map(move |data| {
                let (a, _) =
                    chinese_remainder_theorem(data.into_iter().zip(modulos.clone())).unwrap();
                a
            })
            .collect()
    }
    fn sqrt_mod_two_power(self, e: u32) -> Vec<Self> {
        let two = Self::one() + Self::one();
        let four = two.clone() + two.clone();
        let eight = four.clone() + four.clone();

        if e == 1 {
            return vec![Self::one()];
        }
        if e == 2 {
            if !(self % four).is_one() {
                return Vec::new();
            } else {
                return vec![Self::one(), Self::one() + Self::two()];
            }
        }
        if !(self.clone() % eight).is_one() {
            return Vec::new();
        }
        let mut ys = vec![Self::one(), Self::one() + Self::two()];
        let mut p_k = four;
        let mut k = 2;
        while k < e {
            let l = k.min(e - k);
            let p_l = if k == l {
                p_k.clone()
            } else {
                two.clone().pow(l)
            };
            let p_kl = p_k.clone() * p_l.clone();

            let mut next = Vec::new();
            for y in ys {
                let s = (y.clone() * y.clone() - self.clone()) / p_k.clone();
                if s.is_odd() {
                    continue;
                }
                if l == 1 {
                    next.push(y.clone());
                    next.push(y.clone() + p_k.clone());
                } else {
                    let s = (s / two.clone()) % p_l.clone();
                    let inv_y = (y.clone() % p_l.clone()).inv_mod(p_l.clone()).unwrap();

                    let t = (-s.clone() * inv_y.clone()) % p_l.clone();
                    let new_y = (y.clone() + t * p_k.clone()) % p_kl.clone();
                    next.push(if new_y.is_negative() {
                        p_kl.clone() + new_y
                    } else {
                        new_y
                    });
                    let t = ((p_l.clone() / two.clone() - s) * inv_y.clone()) % p_l.clone();
                    let new_y = (y.clone() + t * p_k.clone()) % p_kl.clone();
                    next.push(if new_y.is_negative() {
                        p_kl.clone() + new_y
                    } else {
                        new_y
                    });
                }
            }
            ys = next;
            k += l;
            p_k = p_kl;
        }
        ys
    }
    fn sqrt_mod_odd_prime_power(self, p: Self, e: u32) -> Vec<Self> {
        if legendre_symbol(self.clone(), p.clone()) != 1 {
            return Vec::new();
        }
        let mut y = (self.clone() % p.clone())
            .sqrt_mod_tonelli_shanks(p.clone())
            .unwrap();
        let mut p_k = p.clone();
        let mut k = 1;
        while k < e {
            let l = k.min(e - k);
            let p_l = if k == l {
                p_k.clone()
            } else {
                p.clone().pow(l)
            };
            let p_kl = p_k.clone() * p_l.clone();

            let two_y = (y.clone() + y.clone()) % p_l.clone();
            let inv_df = two_y.inv_mod(p_l.clone()).unwrap();
            let f = (y.clone() * y.clone() - self.clone()) % p_kl.clone();
            y = (y - f * inv_df) % p_kl.clone();

            if y.is_negative() {
                y = y + p_kl.clone();
            }
            k += l;
            p_k = p_kl;
        }
        vec![y.clone(), p_k - y]
    }
    fn sqrt_mod_prime_power(self, p: Self, e: u32) -> Vec<Self> {
        if p.is_even() {
            self.sqrt_mod_two_power(e)
        } else {
            self.sqrt_mod_odd_prime_power(p, e)
        }
    }
    fn sqrt_mod(self, modulo: Self) -> Option<Self> {
        self.sqrt_mod_all(modulo).pop()
    }
}

impl<T: Int + Signed> SqrtMod for T {}