Minimum Spanning Tree(์ต์ ์ ์ฅ ํธ๋ฆฌ)
Kruskal MST Algorithm
Prim MST Algorithm
๊ทธ๋ํ G์ spanning tree ์ค edge weight์ ํฉ์ด ์ต์์ธ spanning tree
๋ฅผ ์๋ฏธํ๋ค.
- ๊ฐ ๊ฐ์ ์ ๊ฐ์ค์น๊ฐ ๋์ผํ์ง ์์ ๋ ๋จ์ํ ๊ฐ์ฅ ์ ์ ๊ฐ์ ์ ์ฌ์ฉํ๋ค๊ณ ํด์ ์ต์ ๋น์ฉ์ด ์ป์ด์ง๋ ๊ฒ์ ์๋๋ค.
- MST๋ ๊ฐ์ ์ ๊ฐ์ค์น๋ฅผ ๊ณ ๋ คํ์ฌ ์ต์ ๋น์ฉ์ spanning tree๋ฅผ ์ ํํ๋ ๊ฒ์ ๋งํ๋ค.
- ์ฆ ๋คํธ์ํฌ(๊ฐ์ค์น๋ฅผ ๊ฐ์ ์ ํ ๋นํ ๊ทธ๋ํ)์ ์๋ ๋ชจ๋ ์ ์ ๋ค์ ๊ฐ์ฅ ์ ์ ์์ ๊ฐ์ ๊ณผ ๋น์ฉ์ผ๋ก ์ฐ๊ฒฐํ๋ ๊ฒ์ด๋ค.
- ๊ฐ์ ์ ๊ฐ์ค์น์ ํฉ์ด ์ต์์ฌ์ผ ํ๋ค.
- n๊ฐ์ ์ ์ ์ ๊ฐ์ง๋ ๊ทธ๋ํ์ ๋ํด ๋ฐ๋์ (n-1)๊ฐ์ ๊ฐ์ ๋ง์ ์ฌ์ฉํด์ผ ํ๋ค.
- ์ฌ์ดํด์ด ํฌํจ๋์ด์๋ ์๋๋ค.
spanning tree
: ๊ทธ๋ํ ๋ด์ ์๋ ๋ชจ๋ ์ ์ ์ ์ฐ๊ฒฐํ๊ณ ์ฌ์ดํด์ด ์๋ ๊ทธ๋ํ๋ฅผ ์๋ฏธํ๋ค.
ํ์์ ์ธ ๋ฐฉ๋ฒ(greedy method)๋ฅผ ์ด์ฉํ์ฌ ๋คํธ์ํฌ(๊ฐ์ค์น๋ฅผ ๊ฐ์ ์ ํ ๋นํ ๊ทธ๋ํ)์ ๋ชจ๋ ์ ์ ์ ์ต์ ๋น์ฉ์ผ๋ก ์ฐ๊ฒฐํ๋ ์ต์ ํด๋ต์ ๊ตฌํ๋ ๊ฒ
- MST๊ฐ 1)์ต์ ๋น์ฉ์ ๊ฐ์ ์ผ๋ก ๊ตฌ์ฑ๋จ 2)์ฌ์ดํด์ ํฌํจํ์ง ์์ ์ ์กฐ๊ฑด์ ๊ทผ๊ฑฐํ์ฌ ๊ฐ ๋จ๊ณ์์ ์ฌ์ดํด์ ์ด๋ฃจ์ง ์๋ ์ต์ ๋น์ฉ ๊ฐ์ ์ ์ ํํ๋ค.
- ๊ฐ์ ์ ํ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค.
- ์ด์ ๋จ๊ณ์์ ๋ง๋ค์ด์ง ์ ์ฅํธ๋ฆฌ์๋ ์๊ด์์ด ๋ฌด์กฐ๊ฑด ์ต์ ๊ฐ์ ๋ง์ ์ ํํ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ๊ทธ๋ํ์ ๊ฐ์ ๋ค์ ๊ฐ์ค์น์ ์ค๋ฆ์ฐจ์์ผ๋ก ์ ๋ ฌํ๋ค.
- ์ ๋ ฌ๋ ๊ฐ์ ๋ฆฌ์คํธ์์ ์์๋๋ก ์ฌ์ดํด์ ํ์ฑํ์ง ์๋ ๊ฐ์ ์ ์ ํํ๋ค.
- ์ฆ, ๊ฐ์ฅ ๋ฎ์ ๊ฐ์ค์น๋ฅผ ๋จผ์ ์ ํํ๋ค.
- ์ฌ์ดํด์ ํ์ฑํ๋ ๊ฐ์ ์ ์ ์ธํ๋ค.
- ํด๋น ๊ฐ์ ์ ํ์ฌ์ MST์ ์งํฉ์ ์ถ๊ฐํ๋ค.
- ๋ค์ ๊ฐ์ ์ ์ด๋ฏธ ์ ํ๋ ๊ฐ์ ๋ค์ ์งํฉ์ ์ถ๊ฐํ ๋ ์ฌ์ดํด์ ์์ฑํ๋์ง ํ์ธํด์ผ ํ๋ค.
- ์๋ก์ด ๊ฐ์ ์ด ์ด๋ฏธ ๋ค๋ฅธ ๊ฒฝ๋ก์ ์ํด ์ฐ๊ฒฐ๋์ด ์๋ ์ ์ ๋ค์ ์ฐ๊ฒฐํ ๋ ์ฌ์ดํด์ด ํ์ฑ๋๋ค.
- ์ฆ ์ถ๊ฐํ ์๋ก์ด ๊ฐ์ ์ ์ ๋ ์ ์ ์ด ๊ฐ์ ์งํฉ์ ์ํด ์์ผ๋ฉด ์ฌ์ดํด์ด ํ์ฑ๋๋ค.
- ์ฌ์ดํด ์์ฑ ์ฌ๋ถ๋ฅผ ํ์ธํ๋ ๋ฐฉ๋ฒ
- ์ถ๊ฐํ๊ณ ์ ํ๋ ๊ฐ์ ์ ์ ๋ ์ ์ ์ด ๊ฐ์ ์งํฉ์ ์ํด ์๋์ง๋ฅผ ๋จผ์ ๊ฒ์ฌํด์ผ ํ๋ค.
union-find ์๊ณ ๋ฆฌ์ฆ
์ด์ฉ
union-find
์๊ณ ๋ฆฌ์ฆ์ ์ด์ฉํ๋ฉด Kruskal ์๊ณ ๋ฆฌ์ฆ์ ์๊ฐ ๋ณต์ก๋๋ ๊ฐ์ ๋ค์ ์ ๋ ฌํ๋ ์๊ฐ์ ์ข์ฐ๋๋ค.- ์ฆ ๊ฐ์ e๊ฐ๋ฅผ ํต ์ ๋ ฌ๊ณผ ๊ฐ์ ํจ์จ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ์ ๋ ฌํ๋ค๋ฉด ์๊ฐ ๋ณต์ก๋๋
O(ElogE)
๊ฐ ๋๋ค. - ๊ทธ๋ฆฌ๊ณ ์ ๋ ฌ๋ ๊ฐ์ ์ ์ํํ๋ฉฐ
union-find
์ฐ์ฐ์ ํ๋ฒ์ฉ ์ํํ๋ค.O(1)*E
- ๊ฒฐ๊ณผ์ ์ผ๋ก
O(ElogE + E)
import sys
v, e = map(int, input().split())
# ๋ถ๋ชจ ํ
์ด๋ธ ์ด๊ธฐํ
parent = [0] * (v+1)
for i in range(1, v+1):
parent[i] = i
# find ์ฐ์ฐ
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
# union ์ฐ์ฐ
def union_parent(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
else:
parent[a] = b
# ๊ฐ์ ์ ๋ณด ๋ด์ ๋ฆฌ์คํธ์ ์ต์ ์ ์ฅ ํธ๋ฆฌ ๊ณ์ฐ ๋ณ์ ์ ์
edges = []
total_cost = 0
# ๊ฐ์ ์ ๋ณด ์ฃผ์ด์ง๊ณ ๋น์ฉ์ ๊ธฐ์ค์ผ๋ก ์ ๋ ฌ
for _ in range(e):
a, b, cost = map(int, input().split())
edges.append((cost, a, b))
# ๊ฐ์ ์ ๋ณด ๋น์ฉ ๊ธฐ์ค์ผ๋ก ์ค๋ฆ์ฐจ์ ์ ๋ ฌ
edges.sort()
# ๊ฐ์ ์ ๋ณด ํ๋์ฉ ํ์ธํ๋ฉด์ ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ ์ํ
for i in range(e):
cost, a, b = edges[i]
# find ์ฐ์ฐ ํ, ๋ถ๋ชจ๋
ธ๋ ๋ค๋ฅด๋ฉด ์ฌ์ดํด ๋ฐ์ X์ผ๋ฏ๋ก union ์ฐ์ฐ ์ํ -> ์ต์ ์ ์ฅ ํธ๋ฆฌ์ ํฌํจ!
if find_parent(parent, a) != find_parent(parent, b):
union_parent(parent, a, b)
total_cost += cost
print(total_cost)
์์ ์ ์
์์๋ถํฐ ์ถ๋ฐํ์ฌ ์ ์ฅํธ๋ฆฌ ์งํฉ์ ๋จ๊ณ์ ์ผ๋ก ํ์ฅํด๋๊ฐ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์ ์ ์ ํ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค.
- ์ด์ ๋จ๊ณ์์ ๋ง๋ค์ด์ง ์ ์ฅ ํธ๋ฆฌ๋ฅผ ํ์ฅํ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์์ ๋จ๊ณ์์๋ ์์ ์ ์ ๋ง์ด MST ์งํฉ์ ํฌํจ๋๋ค.
- ์ ๋จ๊ณ์์ ๋ง๋ค์ด์ง MST ์งํฉ์ ์ธ์ ํ ์ ์ ๋ค ์ค์์ ์ต์ ๊ฐ์ ์ผ๋ก ์ฐ๊ฒฐ๋ ์ ์ ์ ์ ํํ์ฌ ํธ๋ฆฌ๋ฅผ ํ์ฅํ๋ค.
- ์ฆ, ๊ฐ์ฅ ๋ฎ์ ๊ฐ์ค์น๋ฅผ ๋จผ์ ์ ํํ๋ค.
- ์์ ๊ณผ์ ์ ํธ๋ฆฌ๊ฐ (N-1)๊ฐ์ ๊ฐ์ ์ ๊ฐ์ง ๋๊น์ง ๋ฐ๋ณตํ๋ค.
- ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ์์ ์๊ฐ ๋ณต์ก๋์ ๊ฐ์ฅ ํฐ ์ํฅ์ ๋ฏธ์น๋ ๊ฒ์ ๊ฐ์ค์น๊ฐ ๊ฐ์ฅ ์์ ์ ์ ์ ์ฐพ์๋ด๋ ๊ฒ๊ณผ ์ธ์ ํ ์ ์ ์ ํ์์ด๋ค.
- ๋ชจ๋ ๋
ธ๋์ ๋ํด ํ์์ ์งํํ๋ฏ๋ก
O(V)
์ด๋ค. ๊ทธ๋ฆฌ๊ณ์ฐ์ ์์ ํ
๋ฅผ ์ฌ์ฉํ์ฌ ๋งค๋ ธ๋๋ง๋ค ์ต์ ๊ฐ์ ์ ์ฐพ๋ ์๊ฐ์O(logV)
์ด๋ค. ๋ฐ๋ผ์ ํ์๊ณผ์ ์๋O(VlogV)
๊ฐ ์์๋๋ค. ๊ทธ๋ฆฌ๊ณ ๊ฐ ๋ ธ๋์ ์ธ์ ๊ฐ์ ์ ์ฐพ๋ ์๊ฐ์ ๋ชจ๋ ๋ ธ๋์ ์ฐจ์์ ๊ฐ์ผ๋ฏ๋กO(E)
๋ค. ๊ทธ๋ฆฌ๊ณ ๊ฐ ๊ฐ์ ์ ๋ํด ํ์ ๋ฃ๋ ๊ณผ์ ์ดO(logV)
๊ฐ ๋์ด ์ฐ์ ์์ ํ ๊ตฌ์ฑ์O(ElogV)
๊ฐ ์์๋๋ค. ๋ฐ๋ผ์O(VlogV+ElogV)
๋กO(ElogV)
๊ฐ ๋๋ค. (โตE๊ฐ ์ผ๋ฐ์ ์ผ๋ก V๋ณด๋ค ํฌ๊ธฐ ๋๋ฌธ) - ๋ง์ฝ ์ฐ์ ์์ ํ๊ฐ ์๋๋ผ
๋ฐฐ์ด
๋ก ๊ตฌํํ๋ค๋ฉด ๊ฐ ์ ์ ์ ์ต์ ๊ฐ์ ์ ๊ฐ๋ ์ ์ ํ์์ ๋งค๋ฒ ์ ์ ๋ง๋ค ์ํํ๋ฏ๋กO(V^2)
๊ฐ๋๊ณ ํ์ ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ ์ ์ ์ ์ต์ ๋น์ฉ ์ฐ๊ฒฐ ์ ์ ํ์์๋O(1)
์ด ์์๋๋ค. ๋ฐ๋ผ์ ์๊ฐ๋ณต์ก๋๋O(V^2)
์ด๋ค.
import heapq
import collections
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
n, m = map(int,input().split()) # ๋
ธ๋ ์, ๊ฐ์ ์
graph = collections.defaultdict(list) # ๋น ๊ทธ๋ํ ์์ฑ
visited = [0] * (n+1) # ๋
ธ๋์ ๋ฐฉ๋ฌธ ์ ๋ณด ์ด๊ธฐํ
# ๋ฌด๋ฐฉํฅ ๊ทธ๋ํ ์์ฑ
for i in range(m): # ๊ฐ์ฑ ์ ๋ณด ์
๋ ฅ ๋ฐ๊ธฐ
u, v, weight = map(int,input().split())
graph[u].append([weight, u, v])
graph[v].append([weight, v, u])
# ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ
def prim(graph, start_node):
visited[start_node] = 1 # ๋ฐฉ๋ฌธ ๊ฐฑ์
candidate = graph[start_node] # ์ธ์ ๊ฐ์ ์ถ์ถ
heapq.heapify(candidate) # ์ฐ์ ์์ ํ ์์ฑ
mst = [] # mst
total_weight = 0 # ์ ์ฒด ๊ฐ์ค์น
while candidate:
weight, u, v = heapq.heappop(candidate) # ๊ฐ์ค์น๊ฐ ๊ฐ์ฅ ์ ์ ๊ฐ์ ์ถ์ถ
if visited[v] == 0: # ๋ฐฉ๋ฌธํ์ง ์์๋ค๋ฉด
visited[v] = 1 # ๋ฐฉ๋ฌธ ๊ฐฑ์
mst.append((u,v)) # mst ์ฝ์
total_weight += weight # ์ ์ฒด ๊ฐ์ค์น ๊ฐฑ์
for edge in graph[v]: # ๋ค์ ์ธ์ ๊ฐ์ ํ์
if visited[edge[2]] == 0: # ๋ฐฉ๋ฌธํ ๋
ธ๋๊ฐ ์๋๋ผ๋ฉด, (์ํ ๋ฐฉ์ง)
heapq.heappush(candidate, edge) # ์ฐ์ ์์ ํ์ edge ์ฝ์
return total_weight
print(prim(graph,1))
- ํ๋ฆผ์ ์ ์ ์์ฃผ์ ์๊ณ ๋ฆฌ์ฆ, ํฌ๋ฃจ์ค์นผ์ ๊ฐ์ ์์ฃผ์ ์๊ณ ๋ฆฌ์ฆ
- ํ๋ฆผ์ ์์์ ์ ์ ํ๊ณ , ์์์ ์์ ๊ฐ๊น์ด ์ ์ ์ ์ ํํ๋ฉด์ ํธ๋ฆฌ๋ฅด ๊ตฌ์ฑ ํ๋ฏ๋ก ๊ทธ ๊ณผ์ ์์ ์ฌ์ดํด์ ์ด๋ฃจ์ง ์์ง๋ง ํฌ๋ฃจ์ค์นผ์ ์์์ ์ ๋ฐ๋ก ์ ํ์ง ์๊ณ ์ต์ ๋น์ฉ์ ๊ฐ์ ์ ์ฐจ๋ก๋ก ๋์ ํ๋ฉด์ ํธ๋ฆฌ๋ฅผ ๊ตฌ์ฑํ๊ธฐ ๋๋ฌธ์ ์ฌ์ดํด์ด ์ด๋ฃจ์ด์ง๋ ํญ์ ํ์ธ ํด์ผํ๋ค.
- ํ๋ฆผ์ ๊ฒฝ์ฐ ์ต์ ๊ฑฐ๋ฆฌ์ ์ ์ ์ ์ฐพ๋ ๋ถ๋ถ์์ ์๋ฃ๊ตฌ์กฐ์ ์ฑ๋ฅ์ ์ํฅ์ ๋ฐ๋๋ค.
- ํฌ๋ฃจ์ค์นผ์ ๊ฐ์ ์ ๊ธฐ์ค์ผ๋ก ์ ๋ ฌํ๋ ๊ณผ์ ์ด ์ค๋ ๊ฑธ๋ฆฐ๋ค.
- ๊ฐ์ ์ ๊ฐ์๊ฐ ์์ ๊ฒฝ์ฐ์๋ ํฌ๋ฃจ์ค์นผ, ๊ฐ์ ์ ๊ฐ์๊ฐ ๋ง์ ๊ฒฝ์ฐ์๋ ํ๋ฆผ.
- ์ต์ ์คํจ๋ ํธ๋ฆฌ์ ๋ํด์ ์ค๋ช ํด์ฃผ์ธ์.
- ํฌ๋ฃจ์ค์นผ๊ณผ ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ์ ์๊ฐ๋ณต์ก๋์ ๋ํด ์ค๋ช ํด์ฃผ์ธ์.
- ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ๊ณผ ํ๋ฆผ ์๊ณ ๋ฆฌ์ฆ์ ์ฐจ์ด์ ์ ๋ํด ์ค๋ช ํด์ฃผ์ธ์.