メモ帳がわり

個人的なメモを残します。主に競プロ

【ALPC】Lazy Segment Tree

問題のリンク

本当はAtCoder Libraryを使って解くことが想定されている問題ですが、全く使わず解きました。
とりあえずなんでもいいので、遅延評価セグメント木の練習がしたかったので。

転倒数について

まず、これがわからないと問題が解けません。
一応問題文でも説明されていますが、念のため。

数列を左から見て行ったときに、自分より左に自分より大きい数が何個あるか?をそれぞれの要素について求めた和が転倒数です。
$ A = (0, 1, 1, 0, 0) $という数列があったときに、$ 1 ≦ i < j ≦ 5, A_i > A_j $を満たすi, jの組の数が転倒数であり、$(i, j) = (2, 4), (2, 5), (3, 4), (3, 5) $がこれを満たすi, jの組になります。

何を2分木に保持するか

まず、どうすれば特定の区間の転倒数を求めることができるかを考えます。
求めたい区間が、木の上で2つの区間に跨っているとき、2つの区間それぞれの転倒数と2つの区間に跨がる転倒数を分けて考えることにします。
木の区間の転倒数は、木に保持しておくとして、跨がる場合を考えるとこれは、「左の区間にある1の数」× 「右の区間にある0の数」になります。
これは、区間が2つより多い場合でも同じように考えることができます。
よって、2分木には区間に含まれる「0の数」、「1の数」、「転倒数」を保持することになります。

更新クエリ

今回の問題では、数列の要素は0か1しか取りません。
また、範囲が与えられてその範囲内にある0と1が入れ替える更新クエリが与えられます。
同じ範囲の数列が、二回更新クエリの対象になったとき、元の数列に戻ることに注意します。
よって、遅延配列にはその区間を反転すべきかどうかのフラグを保持することにしてみます。こうすれば、更新があったときはこのフラグを反転することだけ考えれば良くなります。

evalの実装

0と1の数はswapすればいいだけです。
問題は転倒数ですが、これは「0の数」× 「1の数」- 「現在の転倒数」になります。
まず転倒数になりかねないのは、0と1の組み合わせのみです。これの数は、もちろん「0の数」× 「1の数」ですが、この区間の0と1を反転したときに転倒数になるのは現在転倒数になっていない0と1の組み合わせになります。これは、反転したときにそれぞれの0と1の組み合わせの位置が逆転する、と考えることができるからです。
よって、上の式になります。
また、子に伝播させるときは子の更新フラグを反転させます。

// 0: 0の数, 1: 1の数, 2: 転倒数
vector<tuple<ll, ll, ll>> node; // 完全2分木

// 中略

void eval(int k) {
    if (lazy[k]) { // lazy[k]が1の場合、伝播&更新の反映
        swap(get<0>(node[k]), get<1>(node[k])); // 0と1を入れ替え
        // 転倒数を計算
        get<2>(node[k]) = get<0>(node[k]) * get<1>(node[k]) - get<2>(node[k]); 
        if (k < n-1) { // 子に伝搬
            lazy[k*2+1] = 1-lazy[k*2+1]; // フラグを反転
            lazy[k*2+2] = 1-lazy[k*2+2];
        }
        lazy[k] = 0;
    }
}

最初は、下のように実装していたのですが、これだとまだ更新されていないnodeにアクセスしかねないのでまずいです。
eval関数の中では、nodeに直接アクセスしないようにしましょう。

if (k < n-1) { // 子に伝搬
    get<0>(lazy[k*2+1]) = get<1>(node[k*2+1]);
    get<1>(lazy[k*2+1]) = get<0>(node[k*2+1]);
    get<2>(lazy[k*2+1]) = get<0>(node[k*2+1]) * get<1>(node[k*2+1]) - get<2>(node[k*2+1]);


    get<0>(lazy[k*2+2]) = get<1>(node[k*2+2]);
    get<1>(lazy[k*2+2]) = get<0>(node[k*2+2]);
    get<2>(lazy[k*2+2]) = get<0>(node[k*2+2]) * get<1>(node[k*2+2]) - get<2>(node[k*2+2]);
}

検索クエリ

セグメント木を検索するときは、左の区間から順に見ていくことになります。(実装を変えれば逆にもできる)
この前提を元にどうすれば特定の区間の転倒数を求めることができるかを考えます。
木に保持されている区間については、それぞれの転倒数が既に求めてあるので、これは答えに直接足せばよいです。
後は複数の区間に跨っている転倒数についてですが、これは上で既に考えたように、「左の区間にある1の数」× 「現在の区間にある0の数」になります。
なので、今までに見た区間に含まれる0と1の数をそれぞれ保持して検索することにします。
↓実装です、変な変数名を使っていますがイメージです。(あまり良くないですね)

// white = 0の数
// black = 1の数
void find_query(int s, int t, int l, int r, int n, long long &ans, long long &black, long long &white) {
    eval(n);
    if (r <= s || t <= l) return; // 範囲外なら終了
    // [s, t)が[l, r)を内包しているとき
    else if (s <= l && t >= r) {
        ans += black * get<0>(node[n]);
        white += get<0>(node[n]);
        black += get<1>(node[n]);
        ans += get<2>(node[n]);
    } else {
        // (r+l)/2は区間の中心, 区間の中心を左端にするか、右端にするかで分岐する
        find_query(s, t, l, (r+l)/2, n*2+1, ans, black, white); // 左下の子を探索
        find_query(s, t, (r+l)/2, r, n*2+2, ans, black, white);  // 右下の子を探索
    }
}

全体の実装

struct LazySegTree {
    int n; // 葉の数
    // 0: 0の数, 1: 1の数, 2: 転倒数
    vector<tuple<ll, ll, ll>> node; // 完全2分木
    vector<bool> lazy; // 遅延配列

    // 初期化
    LazySegTree(vector<long long> v) {
        int sz = v.size();
        n = 1;
        while (n < sz) n *= 2; // 与えられた数列の項数以上の2^n個、葉を作る
        node.resize(n*2-1,  {0, 0, 0});
        lazy.resize(n*2-1, 0);

        for (int i = 0; i < sz; i++) {
            if (v[i]) get<1>(node[n-1+i]) = 1;
            else get<0>(node[n-1+i]) = 1;
        }
        // 下から順に葉以外のnodeを初期化
        for (int i = n-2; i >= 0; i--) {
            get<0>(node[i]) = get<0>(node[i*2+1]) + get<0>(node[i*2+2]);
            get<1>(node[i]) = get<1>(node[i*2+1]) + get<1>(node[i*2+2]);
            get<2>(node[i]) = get<2>(node[i*2+1]) + get<2>(node[i*2+2]) + get<1>(node[i*2+1]) * get<0>(node[i*2+2]);
        }
    }

    void eval(int k) {
        if (lazy[k]) { // lazy[k]が1の場合、伝播&更新の反映
            swap(get<0>(node[k]), get<1>(node[k]));
            get<2>(node[k]) = get<0>(node[k]) * get<1>(node[k]) - get<2>(node[k]);
            if (k < n-1) { // 子に伝搬
                lazy[k*2+1] = 1-lazy[k*2+1];
                lazy[k*2+2] = 1-lazy[k*2+2];
            }
            lazy[k] = 0;
        }
    }

    // クエリ処理
    // [s, t)を探す
    // 再帰的に探索するために呼び出す側を別の関数に
    void find(int s, int t, long long &ans) {
        long long black = 0;
        long long white = 0;
        find_query(s, t, 0, n, 0, ans, black, white);
    }

    void find_query(int s, int t, int l, int r, int n, long long &ans, long long &black, long long &white) {
        eval(n); // 普通のセグメント木と違うのはここだけ
        if (r <= s || t <= l) return; // 範囲外なら終了
        // [s, t)が[l, r)を内包しているとき
        else if (s <= l && t >= r) {
            ans += black * get<0>(node[n]);
            white += get<0>(node[n]);
            black += get<1>(node[n]);
            ans += get<2>(node[n]);
        } else {
            // (r+l)/2は区間の中心, 区間の中心を左端にするか、右端にするかで分岐する
            find_query(s, t, l, (r+l)/2, n*2+1, ans, black, white); // 左下の子を探索
            find_query(s, t, (r+l)/2, r, n*2+2, ans, black, white);  // 右下の子を探索
        }
    }


    void update(int s, int t) { update_query(s, t, 0, n, 0); }

    void update_query(int s, int t, int l, int r, int n) {
        eval(n);
        if (r <= s || t <= l) return; // 範囲外なら終了
        // [s, t)が[l, r)を内包しているとき
        else if (s <= l && t >= r) {
            lazy[n] = 1-lazy[n];
            eval(n);
        } else {
            // (r+l)/2は区間の中心, 区間の中心を左端にするか、右端にするかで分岐する
            update_query(s, t, l, (r+l)/2, n*2+1); // 左下の子を更新
            update_query(s, t, (r+l)/2, r, n*2+2);  // 右下の子を更新

            get<0>(node[n]) = get<0>(node[n*2+1]) + get<0>(node[n*2+2]);
            get<1>(node[n]) = get<1>(node[n*2+1]) + get<1>(node[n*2+2]);
            get<2>(node[n]) = get<2>(node[n*2+1]) + get<2>(node[n*2+2]) + get<1>(node[n*2+1]) * get<0>(node[n*2+2]);
        }
    }

    // デバック用
    void output() {
        for (int i = 0; i < n*2-1; i++) cout << get<0>(node[i]) << " " << get<1>(node[i]) << " " << get<2>(node[i]) << " | ";
        cout << endl;
    }

};

void solve(long long N, long long Q, std::vector<long long> A, std::vector<long long> T, std::vector<long long> L, std::vector<long long> R) {
    auto segtree = LazySegTree(A);
    REP (i, Q) {
        if (T[i] == 1) {
            segtree.update(L[i]-1, R[i]);
        } else {
            long long ans = 0;
            segtree.find(L[i]-1, R[i], ans);
            cout << ans << endl;
        }
    }
}

int main(){
    long long N;
    scanf("%lld",&N);
    long long Q;
    scanf("%lld",&Q);
    std::vector<long long> A(N);
    for(int i = 0 ; i < N ; i++){
        scanf("%lld",&A[i]);
    }
    std::vector<long long> T(Q);
    std::vector<long long> L(Q);
    std::vector<long long> R(Q);
    for(int i = 0 ; i < Q ; i++){
        scanf("%lld",&T[i]);
        scanf("%lld",&L[i]);
        scanf("%lld",&R[i]);
    }
    solve(N, Q, std::move(A), std::move(T), std::move(L), std::move(R));
    return 0;
}

感想

AOJの問題は基本的なものだったので、この問題には結構苦戦しました。
ACL使わずにACしましたが、いい練習になったと思います。