Memo

Twitterに書くには長すぎることを書きます。Opinions are my own.

AtCoder Beginner Contest 129 F - Takahashi's Basics in Education and Learning

問題: https://atcoder.jp/contests/abc129/tasks/abc129_f

3x3行列かー、と思ったけどそこから先の実装も難しい

考察パート

  • 数列の各項の桁数の種類は高々18種類しかないので、桁数ごとに求めたい -> コンテスト中にわかった
  • 高速に求めるには行列累乗が良さそう -> コンテスト中にわかった
  • 3x3行列の行列累乗になる -> コンテスト中にわからなかった
ll mod;

struct mint {
    ll x;
    mint():x(0){}
    mint(ll x):x((x%mod+mod)%mod){}
    mint& fix() { x = (x%mod+mod)%mod; return *this;}
    mint operator-() const { return mint(0) - *this;}
    mint& operator+=(const mint& a){ if((x+=a.x)>=mod) x-=mod; return *this;}
    mint& operator-=(const mint& a){ if((x+=mod-a.x)>=mod) x-=mod; return *this;}
    mint& operator*=(const mint& a){ (x*=a.x)%=mod; return *this;}
    mint operator+(const mint& a)const{ return mint(*this) += a;}
    mint operator-(const mint& a)const{ return mint(*this) -= a;}
    mint operator*(const mint& a)const{ return mint(*this) *= a;}
    bool operator<(const mint& a)const{ return x < a.x;}
    bool operator==(const mint& a)const{ return x == a.x;}
};
istream& operator>>(istream&i,mint&a){i>>a.x;return i;}
ostream& operator<<(ostream&o,const mint&a){o<<a.x;return o;}

mint mod_pow(mint a, __uint128_t x) {
  mint ret = 1;
  while(x > 0) {
    if (x & (__uint128_t)1) ret *= a;
    a *= a; x >>= 1;
  }
  return ret;
}

// return a * b where a and b are n * n matrix
vector<mint> mat_mul(const vector<mint>& a, const vector<mint>& b, int n) {
  vector<mint> ret(n*n);
  rep(i, n) rep(j, n) rep(k, n) ret[i*n+j] += a[i*n+k]*b[k*n+j];
  return ret;
}

// return identity matrix of size n * n
vector<mint> id_mat(int n) {
  vector<mint> ret(n*n);
  rep(i, n) ret[i*n+i] = 1;
  return ret;
}

// return a^x where a is n * n matrix
// a is changed, so do not use &a
vector<mint> mat_pow(vector<mint> a, ll x, int n) {
  auto ret = id_mat(n);
  while(x>0) {
    if (x&1) ret = mat_mul(ret, a, n);
    a = mat_mul(a, a, n); x>>=1;
  }
  return ret;
}

signed main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  ll l, a, b; cin >> l >> a >> b;
  cin >> mod;

  ll s = 10;
  vector<ll> pos(19);
  pos[0] = -1;
  for(int i=1; i<=18; i++) {
    if (a >= s) {
      pos[i] = 0;
    }
    ll ok = -1;
    ll ng = l;
    while(ok+1<ng) {
      auto med = (ok+ng)/2;
      auto k = a+b*med;
      if (k >= s) {
        ng = med;
      } else {
        ok = med;
      }
    }
    pos[i] = ok;
    if (i != 18) s *= 10;
  }

  mint ans = 0;
  __uint128_t digit = 0;
  for(int i=18; i>=1; i--) {
    auto end = pos[i];
    auto start = pos[i-1]+1;
    if (start > end) continue;
    auto cnt = end-start+1;
    vector<mint> mat = {mod_pow(10, i), 1, 0, 0, 1, b, 0, 0, 1};
    auto mati = mat_pow(mat, cnt-1, 3);
    auto now = mati[0]*(a+b*start)+mati[1]*(a+b*(start+1))+mati[2];
    ans += mod_pow(10, digit)*now;
    digit += ((__uint128_t)cnt)*i;
  }
  cout << ans << endl;
}
  • 桁数を分けるところは O(1)で求められるが、雑に二分探索した。尺取だと間に合わない。
  • 合計の桁数は128bit整数に収まるので __uint128_t を使った (初めて使った)
  • mat_pow(a, x, n)xint になっていてバグらせ続けた...

提出コード: https://atcoder.jp/contests/abc129/submissions/5872762

(追記) 桁数そのものではなく10のべき乗を持つように工夫したら __uint128_t は不要になった

  mint ans = 0;
  mint k = 1;
  for(int i=18; i>=1; i--) {
    auto end = pos[i];
    auto start = pos[i-1]+1;
    if (start > end) continue;
    auto cnt = end-start+1;
    vector<mint> mat = {mod_pow(10, i), 1, 0, 0, 1, b, 0, 0, 1};
    auto mati = mat_pow(mat, cnt-1, 3);
    auto now = mati[0]*(a+b*start)+mati[1]*(a+b*(start+1))+mati[2];
    ans += k*now;
    k *= mod_pow(mod_pow(10, cnt), i);
  }
  cout << ans << endl;

提出コード: https://atcoder.jp/contests/abc129/submissions/5873585