크루스칼 알고리즘은 (connected, undirected, edge 에 weight가 있는) 그래프에서 최단 신장 트리 (minimal spanning tree)를 구하는 알고리즘이다.
- 신장 트리란 그래프의 모든 노드들을 (사이클이 없이) 연결하는 트리다
- 노드가 N개 있으면 N-1 간선(edge)으로 이뤄진다
- 하나의 그래프에 여러개의 신장 트리가 있을 수 있다
- 최단 신장 트리는 트리의 edge weight들을 모두 더했을 때 최소가 되는 신장 트리다
크루스칼 알고리즘은 이전 글에서 설명한 유니온 파인드 알고리즘을 기반으로 구성된다. 우선 간선들을 오름차순으로 정렬을 한다. 그리고 가장 edge weight가 작은 간선들을 차례차례 트리에 추가하고, 신장 트리가 완성되면 끝난다. 알고리즘의 시간 복잡도는 \( O(E \log E)\)가 된다 (E는 간선의 개수). 조금 더 구체적인 알고리즘은 예를 들면서 살펴보자.
다음과 같은 그래프에서 최단 신장 트리를 구해보자.
1. 시작할 때 각 노드가 (5개의) 트리의 root가 된다.
2. edge들을 edge weight로 정렬을 한다: (1,0,3), (2,2,3), (2,0,2), (3,3,4), (4,4,0), (6,1,2).
여기서 다음과 같은 notation을 사용했다: (edge weight, node1, node2)
3. 가장 edge weight가 작은 node들을 연결해서 트리를 만든다. 여기서는 (1,0,3).
이 부분을 구현할 때 0의 root와 3의 root가 파인드 (find) 알고리즘으로 같은지 확인을 하고 (같으면 사이클이 생기기 때문에) 유니온 알고리즘으로 합해준다. 결과적으로 0의 root는 0, 3의 root는 0이 된다.
4. 그 다음으로 edge weight가 작은 간선은 (2,2,3):
5. 그 다음은 (2,0,2)인데 이 간선을 추가하면 사이클이 생긴다 (0의 root와 2의 root가 0이다). 이런 경우 스킵한다.
6. 그 다음은 (3,3,4)
7. 그 다음은 (4,4,0)인데 사이클이 발생하므로 스킵
8. 마지막으로 (6,1,2)를 추가하면 완성 (최단 거리는 3 + 2 + 1 + 6 = 12)
정리하면 다음과 같다
def kruskall(distances):
# distances = [(distance, node1, node2)]
# 우선 edge 사이 distance 값으로 sort해준다
distances.sort(key = lambda x : x[0])
# 최소 거리
min_distance = 0
# 간선 거리가 작은 순서로 진행
for distance, node1, node2 in distances:
# node1, node2의 root node를 먼저 찾는다
root1 = find(node1)
root2 = find(node2)
# root node가 같으면 사이클이 발생하기 때문에 스킵
if root1==root2:
continue
# 사이클이 발생하지 않으면 최소 거리에 추가하고
min_distance += distance
# 유니온을 통해서 트리에 간선 추가
union(node1,node2)
return min_distance
이제 문제를 풀어보자
https://leetcode.com/problems/min-cost-to-connect-all-points/description/
2차원 평면에서 N개의 점이 있을 때, 점들의 사이는 manhattan distance로 정의된다: \( distance(point_1,point_2) = |x_1-x_2| + |y_1-y_2|\), \( point1 = [x_1,y_1]\). 이 점들을 최단 신장 트리로 연결했을 때 최단 거리를 구하는 문제다.
크루스칼 알고리즘을 거의 그대로 적용하면 된다.
class Solution:
def minCostConnectPoints(self, points: List[List[int]]) -> int:
N = len(points) # 점 개수
# 우선 union, find 구현.
def init_parents(N):
return [i for i in range(N)]
parents = init_parents(N)
def find(i):
if parents[i] == i:
return i
parents[i] = find(parents[i])
return parents[i]
def union(i,j):
i = find(i)
j = find(j)
if i<j:
parents[j] = i
else:
parents[i] = j
# manhattan distance 구현
def distance(i,j):
a = points[i]
b = points[j]
return abs(b[0]-a[0]) + abs(b[1]-a[1])
# 간선 거리와 node 정보를 리스트에 저장
distances = [(distance(i,j),i,j) for i in range(N) for j in range(i)]
# 간선 거리를 기준으로 오름차순 정렬
distances.sort(key = lambda x : x[0])
# 최단 거리
min_distance = 0
#크루스칼 알고리즘
for distance, node1, node2 in distances:
root1 = find(node1)
root2 = find(node2)
if root1==root2:
continue
min_distance += distance
union(node1,node2)
return min_distance
참고자료:
https://en.wikipedia.org/wiki/Kruskal%27s_algorithm,
https://www.geeksforgeeks.org/kruskals-minimum-spanning-tree-algorithm-greedy-algo-2/
'알고리즘' 카테고리의 다른 글
[알고리즘] Segment Tree (0) | 2023.04.30 |
---|---|
[알고리즘] 플로이드 워셜 알고리즘 (+LeetCode 1335) (0) | 2023.04.16 |
[알고리즘] 유니온 파인드 (0) | 2023.04.07 |
[알고리즘] 다이나믹 프로그래밍 (문제편 2: 백준 2098 외판원 순회) (0) | 2023.04.02 |
[알고리즘] 다이나믹 프로그래밍 (문제편 1: 배낭 문제) (0) | 2023.04.01 |