Memo

Twitterに書くには長すぎることを書きます。Opinions are my own.

ABC161-F: Division or Subtraction

問題: https://atcoder.jp/contests/abc161/tasks/abc161_f

自分は  1 \le N \le 10^{12} という制約を見ると、前半パート  (2 \le K \le \sqrt{N}) と後半パート  (\sqrt{N} \lt K \le N) に分けて考えることが多くて、今回もそれで解けました。

前半パートについては、 K で割り切れる限り割って、割り切れなくなったら  K で割った余りを判定するだけです。かなり定数倍が小さい  O( \sqrt{N} \log N) で解けます。

後半パートは少し工夫が必要です。まず、 N Kで割り切れる場合は  N=K を除いて条件を満たさないことがわかるので除外して考えます( 1 \lt N\div{K} \lt K なので)。次がポイントなのですが、 N Kで割った商を  p で余りを r とすると、 N K+1 で割った商が  p であるならば余りは  r-p になります。例えば、 1333 = 100\times 13 + 33 = 101\times 13 + 20 = 102\times 13 + 7となり、余りがちょうど 13 (商の値)ずつ減っていくことがわかるでしょう。この着目より、 N K で割った商が  p のときに余りが  1 となる  K は高々1つしかなく、 p \gt 1 のときは  r\equiv 1 (\mod p) の場合にのみ存在し、 p=1 のときは必ず存在する (K=N-1) ことがわかります。つまり、後半パートは  K を増やしつつ、  p が同じ場合はまとめて計算することで解けます。  p の範囲は  (1 \le p \lt \sqrt{N-1}) となるので、計算量は  O( \sqrt{N} ) です。

全体の計算量は  O( \sqrt{N} \log N) となります。以下がコンテスト中の実装ほぼそのままなのですが、簡単のため境界は決め打ちしています。

    void solve(istream& cin, ostream& cout) {
      ll n; cin >> n;
      ll ans = 0;
      ll s = 1e+7;
      for(ll i=2; i<min(s, n); i++) {
        ll tmp = n;
        while(tmp%i==0) {
          tmp/=i;
        }
        if (tmp%i == 1) {
          ans++;
        }
      }

      for(ll i=s; i<n; i++) {
        ll p = n/i;
        ll r = n%i;
        if (r%p == 1) {
          ans++;
        }
        if (p == 2) {
          // If p=1, K=N-1 satisfies the condition. Skipping.
          ans++;
          break;
        }
        i = n/p;
      }
      // If K=N, the condition is satisified.
      ans++;
      cout << ans << endl;
    }

コンテスト中の提出: https://atcoder.jp/contests/abc161/submissions/11541243

公式解説は頭がいい・・・

上記の実装では後半パートでも  K を増やしながら進めていましたが、さらに考察を進めると後半パートの実装はさらに簡単になって、 p を列挙しながら  N p で割った余りが  1 かどうか判定するだけでよくなります (ただし、 K がvalidかの判定は必要)

    void solve(istream& cin, ostream& cout) {
      ll n; cin >> n;
      ll ans = 0;
      ll medk = 0;
      for(ll k=2; k*k<=n; k++) {
        ll tmp = n;
        while(tmp%k==0) {
          tmp/=k;
        }
        if (tmp%k == 1) {
          ans++;
        }
        medk = k;
      }

      for(ll p=1; p*p<n; p++) {
        if (p==1 || n%p == 1) {
          ll k = (n-1)/p;
          if (2 <= k && medk < k) {
            ans++;
          }
        }
      }
      // K=N
      ans += 1;
      cout << ans << endl;
    }

ここまで考察すると  n-1ソースコード上に見えてきて、公式解説に近づいてきている気がします

実際の提出: https://atcoder.jp/contests/abc161/submissions/11558100