77#include < vector>
88#include < map>
99namespace cp_algo ::math {
10- using gaussint = complex <int64_t >;
11- gaussint two_squares_prime_any (int64_t p) {
10+ template <typename T>
11+ using gaussint = complex <T>;
12+ template <typename _Int>
13+ auto two_squares_prime_any (_Int p) {
1214 if (p == 2 ) {
13- return gaussint ( 1 , 1 ) ;
15+ return gaussint<_Int>{ 1 , 1 } ;
1416 }
1517 assert (p % 4 == 1 );
16- using base = dynamic_modint<>;
18+ using Int = std::make_signed_t <_Int>;
19+ using base = dynamic_modint<Int>;
1720 return base::with_mod (p, [&](){
1821 base g = primitive_root (p);
1922 int64_t i = bpow (g, (p - 1 ) / 4 ).getr ();
@@ -25,49 +28,50 @@ namespace cp_algo::math {
2528 q0 = std::exchange (q1, q0 + d * q1);
2629 r = std::exchange (m, r % m);
2730 } while (q1 < p / q1);
28- return gaussint ( q0, (base (i) * base (q0)).rem ()) ;
31+ return gaussint<_Int>{ q0, (base (i) * base (q0)).rem ()} ;
2932 });
3033 }
3134
32- std::vector<gaussint> two_squares_all (int64_t n) {
35+ template <typename Int>
36+ std::vector<gaussint<Int>> two_squares_all (Int n) {
3337 if (n == 0 ) {
3438 return {0 };
3539 }
3640 auto primes = factorize (n);
37- std::map<int64_t , int > cnt;
41+ std::map<Int , int > cnt;
3842 for (auto p: primes) {
3943 cnt[p]++;
4044 }
41- std::vector<gaussint> res = {1 };
45+ std::vector<gaussint<Int> > res = {1 };
4246 for (auto [p, c]: cnt) {
43- std::vector<gaussint> nres;
47+ std::vector<gaussint<Int> > nres;
4448 if (p % 4 == 3 ) {
4549 if (c % 2 == 0 ) {
46- auto mul = bpow (gaussint (p), c / 2 );
50+ auto mul = bpow (gaussint<Int> (p), c / 2 );
4751 for (auto p: res) {
4852 nres.push_back (p * mul);
4953 }
5054 }
5155 } else if (p % 4 == 1 ) {
52- gaussint base = two_squares_prime_any (p);
56+ auto base = two_squares_prime_any (p);
5357 for (int i = 0 ; i <= c; i++) {
5458 auto mul = bpow (base, i) * bpow (conj (base), c - i);
5559 for (auto p: res) {
5660 nres.push_back (p * mul);
5761 }
5862 }
5963 } else if (p % 4 == 2 ) {
60- auto mul = bpow (gaussint (1 , 1 ), c);
64+ auto mul = bpow (gaussint<Int> (1 , 1 ), c);
6165 for (auto p: res) {
6266 nres.push_back (p * mul);
6367 }
6468 }
6569 res = nres;
6670 }
67- std::vector<gaussint> nres;
71+ std::vector<gaussint<Int> > nres;
6872 for (auto p: res) {
6973 while (p.real () < 0 || p.imag () < 0 ) {
70- p *= gaussint (0 , 1 );
74+ p *= gaussint<Int> (0 , 1 );
7175 }
7276 nres.push_back (p);
7377 if (!p.real () || !p.imag ()) {
0 commit comments