본문 바로가기
ALGORITHM/개념 정리

[Algorithm] 최소 신장 트리 (MST)

by 안녕나는현서 2021. 10. 21.
728x90

신장 트리

  • n개의 정점으로 이루어진 무방향 그래프에서 n개의 정점과 n-1개의 간선으로 이루어진 트리

 

최소 신장 트리 (Minimum Spanning Tree)

  • 무방향 가중치 그래프에서 신장 트리를 구성하는 간선들의 가중치의 합이 최소인 신장 트리

 

Prim Algorithm

  • 하나의 정점에서 연결된 간선들 중에 하나씩 선택하면서 MST를 만들어 가는 방식
    1. 임의 정점 하나 선택해서 시작
    2. 선택한 정점과 인접하는 정점들 중 최소 비용의 간선이 존재하는 정점 선택
    3. 모든 정점이 선택될 때 까지 반복
'''
6 11
0 1 32
0 2 31
0 5 60
0 6 51
1 2 21
2 4 46
2 6 25
3 4 34
3 5 18
4 5 40
4 6 51
'''

def find_MST(s):
    key[s] = 0
    
    # 정점의 개수만큼 반복
    for _ in range(V):
        # 방문하지 않은 정점 중, 최소 가중치를 가진 정점 찾기
        min_idx = -1
        min_val = float('inf')
        for i in range(V+1):
            if not visited[i] and key[i] < min_val:
                min_idx = i
                min_val = key[i]

        # 현재 최소 가중치를 가진 정점 (== min_idx)
        # 방문 처리
        visited[min_idx] = 1

        # 인접 노드들을 확인하면서
        for i in range(V+1):
            # 이웃 노드와 연결되어있고 방문한 적이 없다면
            if adj_mat[min_idx][i] and not visited[i]:
                # 인접 정점들의 key 값을 필요하면 갱신
                weight = adj_mat[min_idx][i]
                if weight < key[i]:
                    key[i] = weight       # 더 작은 키 값으로 갱신
                    parents[i] = min_idx  # 이웃 노드의 부모 노드 변경


V, E = map(int, input().split())
edges = [list(map(int, input().split())) for _ in range(E)]

# 인접 행렬
adj_mat = [[0 for _ in range(V+1)] for _ in range(V+1)]
for n1, n2, w in edges:
    adj_mat[n1][n2] = w
    adj_mat[n2][n1] = w

# 초기화
key = [float('inf')] * (V + 1)  # 초기 각 노드의 가중치는 무한대로 초기화
parents = [None] * (V + 1)      # 부모 정점 리스트 초기화
visited = [0] * (V + 1)         # 방문 배열 초기화

s = 0
find_MST(s)

# parents가 이루는 트리가 결국 MST가 됨
# ex) MST의 가중치의 합을 구하시오.
print(sum(key))

 

Kruskal Algorithm

  • 간선을 하나씩 선택해서 MST를 찾는 알고리즘
    1. 모든 간선을 가중치에 따라 오름차순 정렬
    2. 가중치가 가장 낮은 간선부터 사이클을 이루지 않으면, 선택하면서 트리를 증가
    3. n-1개의 간선이 선택될 때 까지 반복
'''
6 11
0 1 32
0 2 31
0 5 60
0 6 51
1 2 21
2 4 46
2 6 25
3 4 34
3 5 18
4 5 40
4 6 51
'''

def make_set(x):
    parents[x] = x
    ranks[x] = 0


def find_set(x):
    if x != parents[x]:
        parents[x] = find_set(parents[x])

    return parents[x]


def union(x, y):
    root_x = find_set(x)
    root_y = find_set(y)

    # root_x의 트리의 높이(rank)가 더 클 경우
    if ranks[root_x] > ranks[root_y]:
        parents[root_y] = root_x
    # root_y의 트리의 높이가 더 크거나 같을 경우
    else:
        parents[root_x] = root_y
        if ranks[root_x] == ranks[root_y]:
            ranks[root_y] += 1


def find_MST(edges):
    MST = []

    for edge in edges:
        x, y, w = edge
        if find_set(x) != find_set(y):  # Cycle Detection
            # 현재 간선이 MST로 선택 가능하다는 뜻
            MST.append((x, y, w))
            union(x, y)  # MST를 구성하는 부분

    return MST


V, E = map(int, input().split())
edges = [list(map(int, input().split())) for _ in range(E)]

parents = [0] * (V + 1)
ranks = [0] * (V + 1)

for v in range(1, 1+V):
    make_set(v)

# 가중치 오름차순으로 정렬
edges.sort(key=lambda x: x[2])
print(find_MST(edges))
728x90

댓글