33#include " cp-algo/util/checkpoint.hpp"
44#include " cp-algo/util/bump_alloc.hpp"
55#include " cp-algo/util/simd.hpp"
6- #include " cp-algo/math/common .hpp"
6+ #include " cp-algo/math/combinatorics .hpp"
77#include " cp-algo/number_theory/modint.hpp"
88#include < ranges>
99
@@ -37,38 +37,37 @@ namespace cp_algo::math {
3737 t = mod - t - 1 ;
3838 y = t % 2 ? 1 : mod-1 ;
3939 }
40- int pw = 0 ;
40+ auto pw = 32ull * (t + 1 ) ;
4141 while (t > limit_reg) {
4242 limit_odd = std::max (limit_odd, (t - 1 ) / 2 );
4343 odd_args_per_block[(t - 1 ) / 2 / subblock].push_back ({int (i), (t - 1 ) / 2 });
4444 t /= 2 ;
4545 pw += t;
4646 }
4747 reg_args_per_block[t / subblock].push_back ({int (i), t});
48- y *= bpow ( base ( 2 ), pw );
48+ y *= pow_fixed< base, 2 >( int (pw % (mod - 1 )) );
4949 }
5050 checkpoint (" init" );
51- uint32_t b2x32 = ( 1ULL << 32 ) % mod ;
51+ base bi2x32 = pow_fixed<base, 2 >( 32 ). inv () ;
5252 auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
5353 base fact = 1 ;
5454 for (int b = 0 ; b <= limit; b += accum * block) {
5555 u32x8 cur[accum];
5656 static std::array<u32x8, subblock> prods[accum];
5757 for (int z = 0 ; z < accum; z++) {
5858 for (int j = 0 ; j < simd_size; j++) {
59+ #pragma GCC diagnostic push
60+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
5961 cur[z][j] = uint32_t (b + z * block + j * subblock);
6062 cur[z][j] = proj (cur[z][j]);
6163 prods[z][0 ][j] = cur[z][j] + !cur[z][j];
62- #pragma GCC diagnostic push
63- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
64- cur[z][j] = uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
64+ prods[z][0 ][j] = uint32_t (uint64_t (prods[z][0 ][j]) * bi2x32.getr () % mod);
6565#pragma GCC diagnostic pop
6666 }
6767 }
6868 for (int i = 1 ; i < block / simd_size; i++) {
6969 for (int z = 0 ; z < accum; z++) {
7070 cur[z] += step;
71- cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
7271 prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod);
7372 }
7473 }
@@ -85,12 +84,12 @@ namespace cp_algo::math {
8584 checkpoint (" mul ans" );
8685 }
8786 };
88- uint32_t b2x33 = ( 1ULL << 33 ) % mod ;
89- process (limit_reg, reg_args_per_block, b2x32, std::identity{ });
90- process (limit_odd, odd_args_per_block, b2x33, []( uint32_t x) { return 2 * x + 1 ;} );
87+ process (limit_reg, reg_args_per_block, 1 , std::identity{}) ;
88+ process (limit_odd, odd_args_per_block, 2 , []( uint32_t x) { return 2 * x + 1 ; });
89+ auto invs = bulk_invs<base>(res );
9190 for (auto [i, x]: res | std::views::enumerate) {
9291 if (args[i] >= mod / 2 ) {
93- x = x. inv () ;
92+ x = invs[i] ;
9493 }
9594 }
9695 checkpoint (" inv ans" );
0 commit comments