BOJ 10076 휴가

링크

BOJ 10076 휴가

문제 요약

0번부터 $n-1$번 도시가 일렬로 나열되어 있다. 도시와 도시 사이를 이동하는데는 하루가 걸리며, 현재 도시에 있는 모든 관광지를 보는 데에도 하루가 소모된다. 출발 도시의 번호와 총 휴가 기간 알 때, 최적으로 이동을 할 경우에 볼 수 있는 관광지 개수의 총 합의 최댓값을 구하여라. (단, 이미 본 관광지를 또 보는 경우는 개수에 추가하지 않는다.)

관찰

가장 먼저 해볼 수 있는 생각은 전체 휴가 기간 $D$일 중 $s$일을 관광에, $m$일을 이동에 사용한다고 가정하면, 특정 구간에 대한 상위 $k$개 값의 합을 구하는 문제로 환원된다는 사실이다.
편의상 시작 도시에서 왼쪽으로 $x$일을 먼저 이동하고, $m-x$일을 오른쪽으로 이동한다고 하자. 그러면 $0 \le{} x \le{} m-2x$에 대하여, 구간 [시작 도시-x, 시작 도시+m-2x]에서 가장 큰 $s$개의 원소를 뽑아 이들의 합을 구하는 문제와 동치가 된다.
따라서 구간을 고르게 되면, 고를 수 있는 도시의 개수가 결정되고 이는 배열의 임의의 구간에서 $k$번째 수를 PST를 통해 찾는 방식과 동일하게 구현 가능하며 $O(\log{N})$에 값을 구할 수 있다. 따라서 총 $O(N^2\log{N})$짜리 풀이가 완성되었음을 알 수 있다.

풀이

이제 이를 빠르게 구하는 방법에 대하여 생각해보도록 하자. 모든 구간에 대하여 값을 조사하지 않고, 후보가 될 수 없는 구간을 제외하여 더 적은 개수의 구간을 조사하는 것이 충분하다면 시간을 줄일 수 있을 것이다. 이러한 접근은 실제로 유효하며, 엄밀히 말해 다음과 같은 Lemma가 성립한다.

Lemma)
$f(l)=$구간 $[l,r]$에서 구하고자 하는 값을 최대화하는 $r$ 중에서 가장 왼쪽에 있는 $r$
$l \le{} l’$에 대하여 $f(l) \le{} f(l’)$이 성립한다. 즉, 최적의 오른쪽 끝 값이 단조성을 띈다.

증명은 다음과 같다.
(Proof)
$f(l) \le{} f(l+1)$이 성립함을 증명하도록 하자.
귀류법을 통해 증명하기 위해서 $f(l)>f(l+1)$가 성립함을 가정한다.
$f(l)$이 $[l,r]$에서 구하고자 하는 값을 최대화하는 $r$ 중에 가장 왼쪽에 있는 $r$이라고 정의했으므로, 아래 식이 성립한다. (이때, 함수 $C$는 주어진 구간에서 문제에서 구하고자 하는 값의 최댓값을 나타낸다.)

\[C[l,R]<C[l,f(l)] \hspace{1em} (l \le{} R < f(l)) \hspace{1em} \cdots{} \hspace{0.5em} (1)\]

위의 가정으로부터 $C[l+1,f(l+1)] \ge{} C[l+1,f(l)]$이고, 구간의 시작점을 $l+1$에서 $l$로 옮기는 과정을 생각해보자. $C[l+1,f(l+1)]$이 $C[l+1,f(l)]$보다 구간의 길이가 더 짧으므로 각각에 대하여 선택한 원소의 개수를 $X$와 $X’$라고 할 때, $X>X’$임이 자명하다.
또한 시작점을 $l+1$에서 $l$로 옮기게 되면, $C[l,f(l+1)]$ 값은 구간 $[l,f(l+1)]$에서 상위 $X-2$개를 고르게 되고, $C[l,f(l)]$ 값은 구간 $[l,f(l)]$에서 상위 $X’-2$개를 고르게 된다. 즉, 현재의 선택에서 값이 빠지는 현상이 나타난다. 추가로 새로운 원소 $m(l)$이 새롭게 고려되어야 하므로 경우를 나눠서 생각해보자. (이때, $m(l)$은 $l$번 도시의 관광지 개수를 의미한다.)

(i) $m(l)$이 $C[l,f(l)]$에 포함되는 경우
이 경우는 $m(l)$이 더 넓은 구간에서 더 높은 순위(상위 $X’-2$개 이내) 안에 들어있다는 뜻이므로, $C[l,f(l+1)]$에도 포함된다. 따라서 일단 $C[l+1,f(l+1)]$과 $C[l+1,f(l)]$에 모두 $m(l)$ 값을 추가하고, 각각에 대하여 선택된 원소 중 하위 3개를 제거하면 $C[l,f(l+1)]$과 $C[l,f(l)]$이 된다. 또한 이 과정에서 제거되는 값들의 합은 구간 $[l+1,f(l)]$에서 더 넓으므로, 제거되는 합이 더 크다는 것을 알 수 있다. 따라서 왼쪽 끝을 옮겨도 부등호의 방향은 바뀌지 않고, $C[l,f(l+1)] \ge{} C[l,f(l)]$이다.

(ii) $m(l)$이 $C[l,f(l+1)]$에만 포함되는 경우
이 경우는 일단 $C[l+1,f(l+1)]$에만 $m(l)$을 추가하고, 각각에 대하여 선택된 원소 중 $C[l+1,f(l+1)]$은 하위 3개를, $C[l+1,f(l)]$은 하위 2개 원소를 제거하면 $C[l,f(l+1)]$과 $C[l,f(l)]$이 된다. 또한 이 과정에서 제거되는 값의 합은 구간 $[l+1,f(l)]$에서 더 넓으므로, 제거되는 합이 더 크다는 것을 알 수 있다. 따라서 왼쪽 끝을 옮겨도 부등호의 방향은 바뀌지 않고, $C[l,f(l+1)] \ge{} C[l,f(l)]$이다.

(iii) $m(l)$이 2가지 경우에 모두 포함되지 않는 경우
이 경우는 각각에 대하여 선택된 원소 중 하위 2개를 제거하면 $C[l,f(l+1)]$과 $C[l,f(l)]$이 된다. 이 과정도 마찬가지로 제거되는 값의 합은 구간 $[l+1,f(l)]$에서 더 넓으므로, 제거되는 합이 더 크다는 것을 알 수 있다. 따라서 왼쪽 끝을 옮겨도 부등호의 방향은 바뀌지 않고, $C[l,f(l+1)] \ge{} C[l,f(l)]$이다.

(i), (ii), (iii)에서 모두 $C[l,f(l+1)] \ge{} C[l,f(l)]$가 성립함을 알 수 있고, 이는 $(1)$의 부등식에 모순이므로 귀류법에 의해 $f(l) \le{} f(l+1)$임이 증명된다.
(증명끝)

Lemma로부터 구간의 단조성이 존재하므로 DnC Optimization을 적용하여 조사할 구간의 개수를 $O(N^2)$에서 $O(N\log{N})$으로 최적화할 수 있고, 본 풀이는 좌측을 먼저 이동하는 경우에 대하여 해결하였으므로 배열을 뒤집어 한번 더 똑같은 과정을 수행하면 모든 경우에 대하여 해결이 가능함을 알 수 있다. 따라서 전체 시간 복잡도 $O(N\log^2{N})$에 문제가 해결된다. 아래는 구현에 사용된 코드이다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include<bits/stdc++.h>
#define rs resize

using namespace std;
typedef long long ll;

struct A{
    int x, y, v;
};

bool cmp1(const A &lhs, const A &rhs){
    return lhs.x<rhs.x;
}

bool cmp2(const A &lhs, const A &rhs){
    return lhs.v>rhs.v;
}

struct PST{
    vector<int> cnt, left, right, root;
    vector<ll> sum;
    int base, t;
    PST(int w, int rt, int up){
        int h=1; base=t=1;
        while(base<w) base<<=1, h++;
        int x=base*2+up*h+5;
        cnt.rs(x); sum.rs(x);
        left.rs(x); right.rs(x);
        root.rs(rt+5);
        root[0]=setUp(1,base);
    }
    int setUp(int l, int r){
        int cur=t++;
        if(l<r){
            int mid=(l+r)>>1;
            left[cur]=setUp(l,mid);
            right[cur]=setUp(mid+1,r);
        }
        return cur;
    }
    void updateKthTree(int k, int idx, int val){
        if(!root[k]) root[k]=root[k-1];
        root[k]=addNode(root[k],idx,val,1,base);
    }
    int addNode(int prev, int idx, int val, int l, int r){
        if(r<idx || idx<l) return prev;
        int cur=t++;
        if(l==r) cnt[cur]=1, sum[cur]=val;
        else{
            int mid=(l+r)>>1;
            left[cur]=addNode(left[prev],idx,val,l,mid);
            right[cur]=addNode(right[prev],idx,val,mid+1,r);
            l=left[cur], r=right[cur];
            sum[cur]=sum[l]+sum[r];
            cnt[cur]=cnt[l]+cnt[r];
        }
        return cur;
    }
    ll qry(int i1, int i2, int k){
        return qry(root[i1],root[i2],k,1,base);
    }
    ll qry(int i1, int i2, int k, int l, int r){
        if(k<=0) return 0LL;
        if(cnt[i1]-cnt[i2]<=k) return sum[i1]-sum[i2];
        int mid=(l+r)>>1;
        int x=cnt[left[i1]]-cnt[left[i2]];
        return qry(left[i1],left[i2],k,l,mid)+qry(right[i1],right[i2],k-x,mid+1,r);
    }
};

int N, S1, S2, D, S;
A M1[100005], M2[100005];
A *M; PST *pst;
ll ans;

void solve(int l, int r, int optl, int optr)
{
    if(l>r) return;
    int j=(l+r)>>1, opt=optl; ll v=0;
    for(int i=max(optl,S) ; i<=optr ; i++){
        ll val=pst->qry(i,j-1,D-(i-j)-(S-j));
        if(v<val) v=val, opt=i;
    }
    ans=max(ans,v);
    if(v) solve(l,j-1,optl,opt);
    solve(j+1,r,opt,optr);
}

int main()
{
    scanf("%d %d %d",&N,&S1,&D);
    S1++; S2=N-S1+1;
    for(int i=1 ; i<=N ; i++){
        scanf("%d",&M1[i].v);
        M1[i].x=i;
        M2[N-i+1].v=M1[i].v;
        M2[N-i+1].x=N-i+1;
    }
    sort(M1+1,M1+1+N,cmp2);
    sort(M2+1,M2+1+N,cmp2);
    for(int i=1 ; i<=N ; i++) M1[i].y=M2[i].y=i;
    sort(M1+1,M1+1+N,cmp1);
    sort(M2+1,M2+1+N,cmp1);
    
    M=M1, S=S1; pst=new PST(N,N,N);
    for(int i=1 ; i<=N ; i++) pst->updateKthTree(i,M[i].y,M[i].v);
    solve(1,S,S,N);
    delete pst;
    
    M=M2, S=S2; pst=new PST(N,N,N);
    for(int i=1 ; i<=N ; i++) pst->updateKthTree(i,M[i].y,M[i].v);
    solve(1,S,S,N);
    delete pst;
    printf("%lld",ans);
    return 0;
}