문제 요약
병합 정렬(merge sort)로 배열 A를 오름차순 정렬할 때, 배열 A에 값이 저장되는 순간(= A[i] <- tmp[t])을 카운트해서
K번째로 저장되는 수를 출력한다. 저장 횟수가 K보다 작으면 -1.
N ≤ 500,000K ≤ 10^8
병합 정렬(Merge Sort) 핵심 개념 정리 (GPT 도움으로 이해)
이번 문제는 병합 정렬을 “그냥 정렬”로만 아는 수준이면 바로 구현하기가 좀 까다로웠다.
특히 merge()에서 tmp를 만들고, 다시 A[p..r]에 덮어쓰는 과정이 “저장 횟수”로 카운트되기 때문이다.
1) 큰 그림: 분할 → 정복 → 병합
merge_sort(A[p..r])
1) 가운데q = (p + r) // 2로 구간을 반으로 쪼갠다
2) 왼쪽A[p..q]정렬
3) 오른쪽A[q+1..r]정렬
4) 두 정렬된 구간을merge(A, p, q, r)로 합친다
2) merge에서 쓰는 포인터(변수) 역할
merge(A, p, q, r)는 이미 정렬된 두 구간
- 왼쪽:
A[p..q] - 오른쪽:
A[q+1..r]
을 합쳐서 A[p..r]을 정렬 상태로 만든다.
i: 왼쪽 구간 포인터 (처음p)j: 오른쪽 구간 포인터 (처음q+1)t: 임시 배열tmp에 값을 채워 넣는 포인터tmp: 병합 결과를 잠깐 저장하는 임시 배열
동작 방식:
1) A[i]와 A[j]를 비교해서 작은 값을 tmp[t]에 넣고 포인터를 증가
2) 한쪽이 끝나면 남은 쪽을 tmp에 그대로 복사
3) 마지막에 tmp를 A[p..r]에 덮어쓰며 정렬 구간 완성
이 문제의 저장 횟수는 3)에서 A[idx] = tmp[idx]가 실행될 때마다 1씩 증가한다.
내가 작성한 코드 (정답)
import sys
input = sys.stdin.readline
def merge_sort(a, p, r, tmp, out):
if p >= r:
return
q = (p + r) // 2
merge_sort(a, p, q, tmp, out)
merge_sort(a, q + 1, r, tmp, out)
merge(a, p, q, r, tmp, out)
def merge(a, p, q, r, tmp, out):
i, j = p, q + 1
t = p
while i <= q and j <= r:
if (a[i] <= a[j]):
tmp[t] = a[i]
i += 1
else:
tmp[t] = a[j]
j += 1
t += 1
while (i <= q):
tmp[t] = a[i]
i += 1
t += 1
while (j <= r):
tmp[t] = a[j]
j += 1
t += 1
for idx in range(p, r + 1):
a[idx] = tmp[idx]
out.append(tmp[idx])
n, k = map(int, input().split())
a = list(map(int, input().split()))
tmp = [0] * n
out = []
merge_sort(a, 0, n - 1, tmp, out)
if len(out) < k:
print(-1, end='')
else:
print(out[k - 1], end='')
개선 포인트
위 코드의 아이디어는 명확하지만, out에 “저장되는 모든 값”을 다 쌓는 방식은
- 최악에 가까운 입력에서 저장 횟수(≈
N log N)가 수백만~천만 단위가 될 수 있어 - 파이썬에서는
out리스트가 커지면서 메모리 사용이 크게 늘 수 있다.
이 문제는 사실 K번째 값만 필요하므로:
- 저장 횟수를 세면서
- K번째가 되는 순간에만 답을 기록
- 이후에는 더 이상 계산/저장을 진행하지 않도록(가능하면) 하는 방식이 더 안전하다.
개선된 코드 (추천): “K번째만” 카운트해서 찾기
import sys
input = sys.stdin.readline
n, k = map(int, input().split())
a = list(map(int, input().split()))
tmp = [0] * n
count = 0
ans = -1
def merge_sort(p, r):
global ans
if p >= r or ans != -1:
return
q = (p + r) // 2
merge_sort(p, q)
merge_sort(q + 1, r)
merge(p, q, r)
def merge(p, q, r):
global count, ans
i, j, t = p, q + 1, p
while i <= q and j <= r:
if a[i] <= a[j]:
tmp[t] = a[i]
i += 1
else:
tmp[t] = a[j]
j += 1
t += 1
while i <= q:
tmp[t] = a[i]
i += 1
t += 1
while j <= r:
tmp[t] = a[j]
j += 1
t += 1
for idx in range(p, r + 1):
a[idx] = tmp[idx]
count += 1
if count == k:
ans = a[idx]
return
merge_sort(0, n - 1)
print(ans)
이 코드에서 배울 수 있는 점
out.append(...)로 전부 저장하지 않고도, 카운터만으로 K번째 값을 찾을 수 있다.- 문제에서 말하는 “저장”은
tmp에 넣는 순간이 아니라A[p..r]에 덮어쓰는 순간이다. merge_sort()/merge()에ans != -1체크를 두면, 답을 찾은 뒤 불필요한 작업을 줄일 수 있다.tmp는 보통tmp = [0] * n처럼 한 번만 만들어서 재사용한다. (병합 때마다 새로 만들 필요 없음)- 병합 정렬의 시간 복잡도는
O(N log N), 추가 메모리는O(N)이다.