Segment tree를 이해하기 위해서 구간 합을 구하는 문제를 생각해 보자:
- list = [1,2,3,4,5,6,7,8,9,10] (0번째 값 = 1, 1번째 값 = 2,...)
- l번째 index에서 r번째 index까지 구간 합을 구하는 문제
이 문제는 여러 가지 방법으로 풀 수 있는데 우선 다음과 같은 방법들을 생각해 보자:
- 무식하게 list[l] + ... + list[r] 를 더해서 구하는 방법
- 예전에 배웠던 prefix sum을 사용하는 방법 ([알고리즘] 구간 합 (+LeetCode 1314))
2번 같은 경우 효율적이라서 이 문제를 푸는데 적합하다.
이제 주어진 list에 대해서 다양한 구간의 합을 구하는 도중, list의 특정 값을 업데이트를 하고 싶다고 하자. 예를 들어서 2번째 값인 3을 11로 바꾸고 싶다면 어떻게 할까? (list -> [1,2,11,4,5,6,7,8,9,10])
만약 prefix sum을 사용해서 구간 합을 구하려고 하면 2번째 sum부터 (list[0] + ... + list[2]) 9번째 sum (list[0] + ... + list[9])까지 모두 바꿔줘야 해서 비효율적이다. 이 경우 구간 합을 구하는 시간 복잡도는 O(1)이지만 업데이트를 하는 시간 복잡도는 O(n)이다.
이런 문제는 segment tree를 사용하면 더욱 효율적으로 풀 수 있다. segment tree를 사용하면 구간 합의 시간 복잡도는 O(log(n) )이라서 prefix sum 대비 비효율적이지만 업데이트하는 시간 복잡는 O(log(n))이기 때문에 업데이트가 잦은 경우 segment tree를 사용하면 이득을 볼 수 있다.
일반적으로 segment tree를 만드는 방법을 소개하기 전에 위에서 소개한 예시에서 segment tree가 어떻게 생겼는지 살펴보자:
- 이진 트리 (binary tree) 구조를 갖는다 (각 parent node는 2개의 child noder가 있다)
- leaf node (가장 말단에 있는 노드)들은 회색으로 칠했는데, 이들은 list에 있는 값들과 같다.
- 자식 노드 (child node)의 값들을 더하면 부모 노드 (parent node)의 값이 된다
- 노드 옆 하늘색 상자에 있는 숫자는 노드의 번호다 (index)
segment tree가 있으면 구간 합은 다음과 같이 구할 수 있다. 예를 들어 list[4]+...+list[7]을 구하려면 segment tree의 10번 노드와 5번 노드의 값을 더하면 된다 (5 + 21 = 5 + 6 + 7 + 8 = 26).
이제 체계적으로 segment tree를 살펴보자. 우선 어떻게 주어진 list에서 segment tree를 만들 수 있는지 살펴보자.
- 예시에서 segment tree의 0번 노드 같은 경우 list[0] + ... + list[9]
- 1번 노드는 list[0] + ... + list[4]
- 2번 노드는 list[5] + ... + list[9]
- 3번 노드는 list[0] + list[1] + list[2]
- 4번 노드는 list[3] + list[4]
- 등등
으로 정의된다. 보다시피 다음과 같은 규칙이 있다: parent node의 값을 구하기 위한 list를 반으로 나눠서 각각 왼쪽, 오른쪽 child node에 할당한다. 할당된 list의 값들을 더하면 child node의 값을 구할 수 있게 된다. 예를 들어서 0번째 노드는 list의 [1,2,...,10]를 더해서 얻을 수 있다 (=55). 이 리스트를 반으로 쪼개서 [1,2,3,4,5]를 왼쪽 child node에 (값들을 더하면 15), [6,7,8,9,10]를 오른쪽 child node에 (값들을 더하면 40) 할당한다.
이 규칙에 따라서 tree의 깊이는 다음과 같이 얻을 수 있다 (참고로 노드가 1개인 경우 깊이는 0으로 정의한다): tree의 높이 h = \( \textrm{ceil} \log_2 n\)
높이가 h인 완전 이진 트리를 (완전 이진 트리: leaf node가 모두 같은 깊이에 있는 이진 트리) 만들 때 필요한 공간을 확보해 두고 작업하는 것이 편하기 때문에, 필요한 공간 복잡도는 \( 2^0 + ... + 2^h = 2\times 2^h-1 \)가 된다. 따라서 필요한 공간은 \( 2 \times 2^h-1 \)는 2n -1 혹은 4n-1이다 (이 값을 계산하기 싫으면 비효율적이지만 넉넉하게 4n으로 잡아도 된다)
이제 segment tree를 만드는 코드를 살펴보자:
from math import ceil, log2
# example list
nums = [1,2,3,4,5,6,7,8,9,10]
# number of elements
n = len(nums)-1
# height of segment tree
h = int((ceil(log2(n))))
# size of tree
size_tree = 2*int(2**h)-1
# segment tree
segment_tree = [0] * size_tree
def build_segment_tree(l,r,idx):
"""
l: leftmost index of (sub) list
r: rightmost index of (sub) list
idx: index of current node
"""
if l==r:
segment_tree[idx] = nums[l]
return segment_tree[idx]
mid = l+(r-l)//2
segment_tree[idx] = build_segment_tree(l,mid,idx*2+1) + build_segment_tree(mid+1,r,idx*2+2)
return segment_tree[idx]
build_segment_tree(0,9,0)
print(segment_tree)
# 출력값: [55, 15, 40, 6, 9, 21, 19, 3, 3, 4, 5, 13, 8, 9, 10, 1, 2, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0, 0, 0]
- build_segment_tree의 코드는 재귀함수로 구현했다.
- build_segment_tree 함수는 가장 위에 있는 노드부터 (0번째 노드) 차례차례 값을 구하는 방식으로 생각하면 편하다
- build_segment_tree 함수를 부를 때 nums의 가장 왼쪽 index (=0), 가장 오른쪽 index (=9), tree index (=0)를 argument로 받는다.
- 이것의 의미는 nums의 0번째 index부터 9번째 index까지 합이 tree의 0번째 node의 값이 된다는 것이다 (i.e.[1,2,...,10]의 합).
- 이유: segment_tree[0] = build_segment_tree[0,4,1] + build_segment_tree[5,9,2]
- 여기서 build_segment_tree[0,4,1]는 nums의 0~4번째 index 값들의 합이 된다 (nums 왼쪽 반쪽의 합 = [1,2,3,4,5]의 합)
- build_segment_tree[5,9,2]는 nums의 5~9번째 값들의 합이 된다 (nums 오른쪽 반쪽의 합 = [6,7,8,9,10]의 합)
- build_segment_tree[0,4,1]같은 경우 위에서 설명한 것과 비슷하게 nums의 0번째부터 4번째까지 값들로 만들어진 리스트 (i.e. [1,2,3,45])의 왼쪽 반쪽과 (i.e. [1,2,3]) 오른쪽 반쪽(i.e. [4,5])의 합으로 나눠서 계산을 하게 되고, 이 값을 segment_tree[1]에 저장한다
- 참고로 i번째 노드는 2개의 child node가 있고, 2개의 child node의 index는 idx*2+1, idx*2+2가 된다.
- 따라서 0번째 노드의 child node들의 index는 각각 1, 2가 된다; 1번째 노드의 child node들의 index는 각각 3,4가 된다
- 다른 값들도 비슷하게 계산된다. 재귀적으로 리스트를 반으로 쪼개다가, 리스트에 1개의 값만 남게 되면 leaf node에 도달했다는 뜻이기 때문에, 이 값을 segment tree에 넣는다 (if l==r 관련 코드)
이제 특정 interval에 있는 수들의 합을 구하는 코드
def sum_interval(s,e,l,r,idx):
"""
for summing values in an interval
s: starting index of interval
e: ending index of interval
"""
if s > r or e < l:
return 0
if s <= l and e >= r:
return segment_tree[idx]
m = l + (r-l)//2
return sum_interval(s,e,l,m,idx*2+1)+sum_interval(s,e,m+1,r,idx*2+2)
- 아이디어는 build_segment_tree와 비슷하다. 이전과 비슷하게 idx는 현제 segment tree의 노드 index를 나타낸다. idx번째 노드의 값은 nums의 l에서 r (포함) 사이에 있는 값들을 더해서 얻어진다. 이해를 돕기 위해서 몇 가지 예를 살펴보자
- nums의 0번째부터 9번째(포함) 사이에 있는 수를 모두 더하려면 sum_interval(0,9,0,9,0))을 호출하면 된다. 두 번째 if문의 조건을 만족하기 때문에 segment_tree[0]의 값을 리턴하게 되는데, 이 값은 55다.
- nums의 0번째부터 4번째 (포함) 사이에 있는 수를 모두 더하려면 sum_interval(0,4,0,9,0)을 호출하면 된다. m=4가 되고, sum_interval(0,4,0,4,1) & sum_interval(0,4,5,9,2)를 호출하게 된다. 전자는 segment_tree[1] = 15를 리턴하고 후자는 0을 리턴한다. 최종적으로 15+0=15를 리턴한다.
- 보다시피 segment tree의 0번째 노드에서 출발한다. 현제 노드가 나타내는 interval이 구하고자 하는 interval [s,e]와 겹치지 않으면 (첫번째 if 문) 0을 리턴한다. 현제 노드가 나타내는 interval이 [s,e]에 포함되면 현제 노드의 값을 리턴한다. 둘 다 아니면 (다시 말해서 현제 노드가 나타내는 interval이 [s,e]와 겹치는 부분이 있으면) 현제 노드의 child node에 대해서 위에 설명한 과정을 반복한다
- segment tree의 깊이가 O(log(n))이기 때문에 구간 합의 시간 복잡도도 O(log(n))이 된다
이제 특정 값을 업데이트하는 코드를 보자.
def update(update_idx,delta_val,l,r,idx):
"""
update_idx: index in nums that is to be updated
delta_val: by how much to change the value num[idx]
l, r, idx: as before
"""
if update_idx < l or update_idx > r:
return
segment_tree[idx] += delta_val
if l==r:
return
m = l + (r-l)//2
update(update_idx,delta_val,l,m,idx*2+1)
update(update_idx,delta_val,m+1,r,idx*2+2)
- update_idx는 nums에서 바꾸고 싶은 index다. 그리고 이 값을 delta_val만큼 바꾸게 된다. 예를 들어 2번째 값을 3에서 11로 바꾸려면 (nums -> [1,2,11,4,5,6,7,8,9,10]) update(2,8,0,9,0)을 호출하면 된다 (8 = 11 - nums[2] = 11 - 3 )
- segment_tree의 각 노드는 nums에서 일정 범위의 값들을 더하기 때문에 update_idx가 이 범위에 들어가면 delta_val만큼 바꿔준다고 이해하면 된다.
- 0번째 노드는 모든 값을 더하기 때문에 55 -> 63
- 1번째 노드는 0~4번째 값들을 더하기 때문에 (index 2를 포함) 15 -> 23
- 2번째 노드는 5~9번째 값들을 더하기 때문에 (index 2를 미포함) 40
- 3번째 노드는 1~3번째 값들을 더하기 때문에 (index 2를 포함) 6->14
- 등등
참고로 segment tree가 사용되는 알고리즘에서 대표적으로 lazy propagation이 있는데, 같이 보면 좋다. 예전에 코딩테스트에서 나온 적이 있지만 자주 등장하지 않는 것으로 알고 있다.
참고 자료:
'알고리즘' 카테고리의 다른 글
[알고리즘] 플로이드 워셜 알고리즘 (+LeetCode 1335) (0) | 2023.04.16 |
---|---|
[알고리즘] 크루스칼 알고리즘 (+LeetCode 1584) (0) | 2023.04.07 |
[알고리즘] 유니온 파인드 (0) | 2023.04.07 |
[알고리즘] 다이나믹 프로그래밍 (문제편 2: 백준 2098 외판원 순회) (0) | 2023.04.02 |
[알고리즘] 다이나믹 프로그래밍 (문제편 1: 배낭 문제) (0) | 2023.04.01 |