알고리즘

[알고리즘] Segment Tree

curious_cat 2023. 4. 30. 00:45
728x90
728x90

Segment tree를 이해하기 위해서 구간 합을 구하는 문제를 생각해 보자:

  • list = [1,2,3,4,5,6,7,8,9,10] (0번째 값 = 1, 1번째 값 = 2,...)
  • l번째 index에서 r번째 index까지 구간 합을 구하는 문제

이 문제는 여러 가지 방법으로 풀 수 있는데 우선 다음과 같은 방법들을 생각해 보자:

  1. 무식하게 list[l] + ... + list[r] 를 더해서 구하는 방법
  2. 예전에 배웠던 prefix sum을 사용하는 방법 ([알고리즘] 구간 합 (+LeetCode 1314))

2번 같은 경우 효율적이라서 이 문제를 푸는데 적합하다. 

 

이제 주어진 list에 대해서 다양한 구간의 합을 구하는 도중, list의 특정 값을 업데이트를 하고 싶다고 하자. 예를 들어서 2번째 값인 3을 11로 바꾸고 싶다면 어떻게 할까? (list -> [1,2,11,4,5,6,7,8,9,10])

 

만약 prefix sum을 사용해서 구간 합을 구하려고 하면 2번째 sum부터 (list[0] + ... + list[2]) 9번째 sum (list[0] + ... + list[9])까지 모두 바꿔줘야 해서 비효율적이다. 이 경우 구간 합을 구하는 시간 복잡도는 O(1)이지만 업데이트를 하는 시간 복잡도는 O(n)이다.

 

이런 문제는 segment tree를 사용하면 더욱 효율적으로 풀 수 있다.  segment tree를 사용하면 구간 합의 시간 복잡도는 O(log(n) )이라서 prefix sum 대비 비효율적이지만 업데이트하는 시간 복잡는 O(log(n))이기 때문에 업데이트가 잦은 경우 segment tree를 사용하면 이득을 볼 수 있다.

 

일반적으로 segment tree를 만드는 방법을 소개하기 전에 위에서 소개한 예시에서 segment tree가 어떻게 생겼는지 살펴보자:

  • 이진 트리 (binary tree) 구조를 갖는다 (각 parent node는 2개의 child noder가 있다)
  • leaf node (가장 말단에 있는 노드)들은 회색으로 칠했는데, 이들은 list에 있는 값들과 같다. 
  • 자식 노드 (child node)의 값들을 더하면 부모 노드 (parent node)의 값이 된다
  • 노드 옆 하늘색 상자에 있는 숫자는 노드의 번호다 (index)

segment tree가 있으면 구간 합은 다음과 같이 구할 수 있다. 예를 들어 list[4]+...+list[7]을 구하려면 segment tree의 10번 노드와 5번 노드의 값을 더하면 된다 (5 + 21 = 5 + 6 + 7 + 8 = 26). 

 

이제 체계적으로 segment tree를 살펴보자. 우선 어떻게 주어진 list에서 segment tree를 만들 수 있는지 살펴보자.

  • 예시에서 segment tree의 0번 노드 같은 경우 list[0] + ... + list[9]
  • 1번 노드는 list[0] + ... + list[4]
  • 2번 노드는 list[5] + ... + list[9]
  • 3번 노드는 list[0] + list[1] + list[2]
  • 4번 노드는 list[3] + list[4]
  • 등등

으로 정의된다. 보다시피 다음과 같은 규칙이 있다: parent node의 값을 구하기 위한 list를 반으로 나눠서 각각 왼쪽, 오른쪽 child node에 할당한다. 할당된 list의 값들을 더하면 child node의 값을 구할 수 있게 된다.  예를 들어서 0번째 노드는 list의 [1,2,...,10]를 더해서 얻을 수 있다 (=55). 이 리스트를 반으로 쪼개서 [1,2,3,4,5]를 왼쪽 child node에 (값들을 더하면 15), [6,7,8,9,10]를 오른쪽 child node에 (값들을 더하면 40) 할당한다.

 

이 규칙에 따라서 tree의 깊이는 다음과 같이 얻을 수 있다 (참고로 노드가 1개인 경우 깊이는 0으로 정의한다): tree의 높이 h = \( \textrm{ceil} \log_2 n\)

높이가 h인 완전 이진 트리를 (완전 이진 트리: leaf node가 모두 같은 깊이에 있는 이진 트리) 만들 때 필요한 공간을 확보해 두고 작업하는 것이 편하기 때문에, 필요한 공간 복잡도는 \( 2^0 + ... + 2^h = 2\times 2^h-1 \)가 된다. 따라서 필요한 공간은 \( 2 \times 2^h-1 \)는 2n -1 혹은 4n-1이다 (이 값을 계산하기 싫으면 비효율적이지만 넉넉하게 4n으로 잡아도 된다)

 

이제 segment tree를 만드는 코드를 살펴보자:

from math import ceil, log2

# example list
nums = [1,2,3,4,5,6,7,8,9,10]

# number of elements
n = len(nums)-1

# height of segment tree
h = int((ceil(log2(n))))

# size of tree
size_tree = 2*int(2**h)-1

# segment tree
segment_tree = [0] * size_tree

def build_segment_tree(l,r,idx):
	""" 
	l: leftmost index of (sub) list
	r: rightmost index of (sub) list
	idx: index of current node
	"""
	if l==r:
		segment_tree[idx] = nums[l]
		return segment_tree[idx]
	mid = l+(r-l)//2
	segment_tree[idx] = build_segment_tree(l,mid,idx*2+1) + build_segment_tree(mid+1,r,idx*2+2)
	return segment_tree[idx]

build_segment_tree(0,9,0)

print(segment_tree)
# 출력값: [55, 15, 40, 6, 9, 21, 19, 3, 3, 4, 5, 13, 8, 9, 10, 1, 2, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0, 0, 0]
  • build_segment_tree의 코드는 재귀함수로 구현했다.
  • build_segment_tree 함수는 가장 위에 있는 노드부터 (0번째 노드) 차례차례 값을 구하는 방식으로 생각하면 편하다
  • build_segment_tree 함수를 부를 때 nums의 가장 왼쪽 index (=0), 가장 오른쪽 index (=9), tree index (=0)를 argument로 받는다.
    • 이것의 의미는 nums의 0번째 index부터 9번째 index까지 합이 tree의 0번째 node의 값이 된다는 것이다 (i.e.[1,2,...,10]의 합).
    • 이유: segment_tree[0] = build_segment_tree[0,4,1] + build_segment_tree[5,9,2]
    • 여기서 build_segment_tree[0,4,1]는 nums의 0~4번째 index 값들의 합이 된다 (nums 왼쪽 반쪽의 합 = [1,2,3,4,5]의 합)
    • build_segment_tree[5,9,2]는 nums의 5~9번째 값들의 합이 된다 (nums 오른쪽 반쪽의 합 = [6,7,8,9,10]의 합)
  • build_segment_tree[0,4,1]같은 경우 위에서 설명한 것과 비슷하게 nums의 0번째부터 4번째까지 값들로 만들어진 리스트 (i.e. [1,2,3,45])의 왼쪽 반쪽과 (i.e. [1,2,3]) 오른쪽 반쪽(i.e. [4,5])의 합으로 나눠서 계산을 하게 되고, 이 값을 segment_tree[1]에 저장한다
    • 참고로 i번째 노드는 2개의 child node가 있고, 2개의 child node의 index는 idx*2+1, idx*2+2가 된다.
    • 따라서 0번째 노드의 child node들의 index는 각각 1, 2가 된다; 1번째 노드의 child node들의 index는 각각 3,4가 된다
  • 다른 값들도 비슷하게 계산된다. 재귀적으로 리스트를 반으로 쪼개다가, 리스트에 1개의 값만 남게 되면 leaf node에 도달했다는 뜻이기 때문에, 이 값을 segment tree에 넣는다 (if l==r 관련 코드)

이제 특정 interval에 있는 수들의 합을 구하는 코드

def sum_interval(s,e,l,r,idx):
	"""
    for summing values in an interval
    s: starting index of interval
    e: ending index of interval
    """
	if s > r or e < l:
		return 0
	if s <= l and e >= r:
		return segment_tree[idx]
	m = l + (r-l)//2
	return sum_interval(s,e,l,m,idx*2+1)+sum_interval(s,e,m+1,r,idx*2+2)
  • 아이디어는  build_segment_tree와 비슷하다. 이전과 비슷하게 idx는 현제 segment tree의 노드 index를 나타낸다. idx번째 노드의 값은 nums의 l에서 r (포함) 사이에 있는 값들을 더해서 얻어진다. 이해를 돕기 위해서 몇 가지 예를 살펴보자
  • nums의 0번째부터 9번째(포함) 사이에 있는 수를 모두 더하려면 sum_interval(0,9,0,9,0))을 호출하면 된다. 두 번째 if문의 조건을 만족하기 때문에 segment_tree[0]의 값을 리턴하게 되는데, 이 값은 55다. 
  • nums의 0번째부터 4번째 (포함) 사이에 있는 수를 모두 더하려면 sum_interval(0,4,0,9,0)을 호출하면 된다. m=4가 되고, sum_interval(0,4,0,4,1) & sum_interval(0,4,5,9,2)를 호출하게 된다. 전자는 segment_tree[1] = 15를 리턴하고 후자는 0을 리턴한다. 최종적으로 15+0=15를 리턴한다.
  • 보다시피 segment tree의 0번째 노드에서 출발한다. 현제 노드가 나타내는 interval이 구하고자 하는 interval [s,e]와 겹치지 않으면 (첫번째 if 문) 0을 리턴한다. 현제 노드가 나타내는 interval이 [s,e]에 포함되면 현제 노드의 값을 리턴한다. 둘 다 아니면 (다시 말해서 현제 노드가 나타내는 interval이 [s,e]와 겹치는 부분이 있으면) 현제 노드의 child node에 대해서 위에 설명한 과정을 반복한다
  • segment tree의 깊이가 O(log(n))이기 때문에 구간 합의 시간 복잡도도 O(log(n))이 된다

이제 특정 값을 업데이트하는 코드를 보자.

def update(update_idx,delta_val,l,r,idx):
    """
    update_idx: index in nums that is to be updated
    delta_val: by how much to change the value num[idx]
    l, r, idx: as before
    """
    if update_idx < l or update_idx > r:
        return
    segment_tree[idx] += delta_val
    if l==r:
        return
    m = l + (r-l)//2
    update(update_idx,delta_val,l,m,idx*2+1)
    update(update_idx,delta_val,m+1,r,idx*2+2)
  • update_idx는 nums에서 바꾸고 싶은 index다. 그리고 이 값을 delta_val만큼 바꾸게 된다. 예를 들어 2번째 값을 3에서 11로 바꾸려면 (nums -> [1,2,11,4,5,6,7,8,9,10]) update(2,8,0,9,0)을 호출하면 된다 (8 = 11 - nums[2] = 11 - 3 )
  • segment_tree의 각 노드는 nums에서 일정 범위의 값들을 더하기 때문에 update_idx가 이 범위에 들어가면 delta_val만큼 바꿔준다고 이해하면 된다.
    • 0번째 노드는 모든 값을 더하기 때문에 55 -> 63
    • 1번째 노드는 0~4번째 값들을 더하기 때문에 (index 2를 포함) 15 -> 23
    • 2번째 노드는 5~9번째 값들을 더하기 때문에 (index 2를 미포함) 40
    • 3번째 노드는 1~3번째 값들을  더하기 때문에 (index 2를 포함) 6->14
    • 등등

참고로 segment tree가 사용되는 알고리즘에서 대표적으로 lazy propagation이 있는데, 같이 보면 좋다. 예전에 코딩테스트에서 나온 적이 있지만 자주 등장하지 않는 것으로 알고 있다.

 

참고 자료:

https://www.geeksforgeeks.org/segment-tree-data-structure/

728x90
728x90