• Home
  • About
    • Eyedicamp 개발 이야기 photo

      Eyedicamp 개발 이야기

      Big Data, Machine Learning, AI 등의 다양한 이야기를 하는 곳.

    • Learn More
    • Email
    • Instagram
    • Github
  • Posts
    • All Posts
    • All Tags
  • Projects

[BOJ 24060] 알고리즘 수업 - 병합 정렬 1

09 Feb 2026

Reading time ~6 minutes

문제 요약

병합 정렬(merge sort)로 배열 A를 오름차순 정렬할 때, 배열 A에 값이 저장되는 순간(= A[i] <- tmp[t])을 카운트해서
K번째로 저장되는 수를 출력한다. 저장 횟수가 K보다 작으면 -1.

  • N ≤ 500,000
  • K ≤ 10^8

병합 정렬(Merge Sort) 핵심 개념 정리 (GPT 도움으로 이해)

이번 문제는 병합 정렬을 “그냥 정렬”로만 아는 수준이면 바로 구현하기가 좀 까다로웠다.
특히 merge()에서 tmp를 만들고, 다시 A[p..r]에 덮어쓰는 과정이 “저장 횟수”로 카운트되기 때문이다.

1) 큰 그림: 분할 → 정복 → 병합

  • merge_sort(A[p..r])
    1) 가운데 q = (p + r) // 2로 구간을 반으로 쪼갠다
    2) 왼쪽 A[p..q] 정렬
    3) 오른쪽 A[q+1..r] 정렬
    4) 두 정렬된 구간을 merge(A, p, q, r)로 합친다

2) merge에서 쓰는 포인터(변수) 역할

merge(A, p, q, r)는 이미 정렬된 두 구간

  • 왼쪽: A[p..q]
  • 오른쪽: A[q+1..r]

을 합쳐서 A[p..r]을 정렬 상태로 만든다.

  • i: 왼쪽 구간 포인터 (처음 p)
  • j: 오른쪽 구간 포인터 (처음 q+1)
  • t: 임시 배열 tmp에 값을 채워 넣는 포인터
  • tmp: 병합 결과를 잠깐 저장하는 임시 배열

동작 방식:

1) A[i]와 A[j]를 비교해서 작은 값을 tmp[t]에 넣고 포인터를 증가
2) 한쪽이 끝나면 남은 쪽을 tmp에 그대로 복사
3) 마지막에 tmp를 A[p..r]에 덮어쓰며 정렬 구간 완성

이 문제의 저장 횟수는 3)에서 A[idx] = tmp[idx]가 실행될 때마다 1씩 증가한다.


내가 작성한 코드 (정답)

import sys

input = sys.stdin.readline

def merge_sort(a, p, r, tmp, out):
    if p >= r:
        return
    q = (p + r) // 2

    merge_sort(a, p, q, tmp, out)
    merge_sort(a, q + 1, r, tmp, out)
    merge(a, p, q, r, tmp, out)

def merge(a, p, q, r, tmp, out):
    i, j = p, q + 1
    t = p

    while i <= q and j <= r:
        if (a[i] <= a[j]):
            tmp[t] = a[i]
            i += 1
        else:
            tmp[t] = a[j]
            j += 1
        t += 1

    while (i <= q):
        tmp[t] = a[i]
        i += 1
        t += 1

    while (j <= r):
        tmp[t] = a[j]
        j += 1
        t += 1

    for idx in range(p, r + 1):
        a[idx] = tmp[idx]
        out.append(tmp[idx])

n, k = map(int, input().split())
a = list(map(int, input().split()))

tmp = [0] * n
out = []

merge_sort(a, 0, n - 1, tmp, out)

if len(out) < k:
    print(-1, end='')
else:
    print(out[k - 1], end='')

개선 포인트

위 코드의 아이디어는 명확하지만, out에 “저장되는 모든 값”을 다 쌓는 방식은

  • 최악에 가까운 입력에서 저장 횟수(≈ N log N)가 수백만~천만 단위가 될 수 있어
  • 파이썬에서는 out 리스트가 커지면서 메모리 사용이 크게 늘 수 있다.

이 문제는 사실 K번째 값만 필요하므로:

  • 저장 횟수를 세면서
  • K번째가 되는 순간에만 답을 기록
  • 이후에는 더 이상 계산/저장을 진행하지 않도록(가능하면) 하는 방식이 더 안전하다.

개선된 코드 (추천): “K번째만” 카운트해서 찾기

import sys
input = sys.stdin.readline

n, k = map(int, input().split())
a = list(map(int, input().split()))
tmp = [0] * n

count = 0
ans = -1

def merge_sort(p, r):
    global ans
    if p >= r or ans != -1:
        return
    q = (p + r) // 2
    merge_sort(p, q)
    merge_sort(q + 1, r)
    merge(p, q, r)

def merge(p, q, r):
    global count, ans
    i, j, t = p, q + 1, p

    while i <= q and j <= r:
        if a[i] <= a[j]:
            tmp[t] = a[i]
            i += 1
        else:
            tmp[t] = a[j]
            j += 1
        t += 1

    while i <= q:
        tmp[t] = a[i]
        i += 1
        t += 1

    while j <= r:
        tmp[t] = a[j]
        j += 1
        t += 1

    for idx in range(p, r + 1):
        a[idx] = tmp[idx]
        count += 1
        if count == k:
            ans = a[idx]
            return

merge_sort(0, n - 1)
print(ans)

이 코드에서 배울 수 있는 점

  • out.append(...)로 전부 저장하지 않고도, 카운터만으로 K번째 값을 찾을 수 있다.
  • 문제에서 말하는 “저장”은 tmp에 넣는 순간이 아니라 A[p..r]에 덮어쓰는 순간이다.
  • merge_sort()/merge()에 ans != -1 체크를 두면, 답을 찾은 뒤 불필요한 작업을 줄일 수 있다.
  • tmp는 보통 tmp = [0] * n처럼 한 번만 만들어서 재사용한다. (병합 때마다 새로 만들 필요 없음)
  • 병합 정렬의 시간 복잡도는 O(N log N), 추가 메모리는 O(N)이다.



pythonmerge-sortrecursiondivide-and-conquersysstdin Share Tweet +1