데이터 엔지니어

백준 - [Gold 4] 1922번 네트워크 연결 본문

프로그래밍(Programming)/알고리즘(Algorithm)

백준 - [Gold 4] 1922번 네트워크 연결

kingsmo 2020. 9. 16. 22:34

문제링크: https://www.acmicpc.net/problem/1922

 

1922번: 네트워크 연결

이 경우에 1-3, 2-3, 3-4, 4-5, 4-6을 연결하면 주어진 output이 나오게 된다.

www.acmicpc.net


문제 설명

- N: 컴퓨터의 개수

- M: 연결하는 선의 개수

- (a, b, c) M개 주어짐 (c=비용)

 

위와 같은 입력이 주어졌을 때, 모든 컴퓨터를 연결할 수 있는 최소 비용을 찾는 문제입니다.


풀이

MST를 찾는 문제입니다.

MST는 Minimum Spanning Tree 최소 신장트리를 뜻하며, 모든 정점을 잇는 최소 비용의 트리입니다.

MST를 찾기 위해서 크루스칼(kruskal) 알고리즘을 사용합니다.

크루스칼 알고리즘
1. 비용순으로 edge들을 정렬한다.
2. 두 정점의 최상위 정점을 확인하고 서로 다를 경우 정점을 연결한다.
  - 여기서 주의할 점은 사이클이 생기지 않아야 한다.
3. 모든 정점이 포함되면 종료한다.

크루스칼 알고리즘에는 Union-Find 알고리즘이 사용됩니다.

모든 원소를 개별 집합으로 두고 두 집합을 하나로 합추는 방식을 뜻합니다.
Union = 두 집합을 하나의 집합으로 합침
Find = 각 집합의 루트노드를 확인
Union 하는 순서에따라 최악의 경우 아래와 같이 root까지 찾아가야 하는 경우가 생김.
그래서 path compression과 union-by-rank 방식을 사용합니다.
path-compression은 a-b-c-d를 루트에 다이렉트로 연결하여 a - b, c, d처럼 만드는 것
union-by-rank은 Union시 두 트리의 높이(rank)가 다르면, 높이가 작은 트리를 높이가 큰 트리에 붙임

자세한 그림과 설명은 이 블로그를 참조해 주세요.

그래서 코드는 MST를 찾기위해 크루스칼 알고리즘을 사용하면 되는 문제입니다!

 

코드

from sys import stdin

stdin = open("input.txt", "r")

def find(node):
    # path compression 기법
    if parent[node] != node:
        parent[node] = find(parent[node])
    return parent[node]

def union(node_v, node_u):
    root1 = find(node_v)
    root2 = find(node_u)
    
    # union-by-rank 기법
    if rank[root1] > rank[root2]:
        parent[root2] = root1
    else:
        parent[root1] = root2
        if rank[root1] == rank[root2]:
            rank[root2] += 1    
    
def make_set(node):
    parent[node] = node
    rank[node] = 0

def kruskal(vertices, edges):
    mst = list()
    
    # 1. 초기화
    for node in vertices:
        make_set(node)
    
    # 2. 간선 weight 기반 sorting
    edges.sort(key=lambda x: x[2]) # 비용순으로 sort
    
    # 3. 간선 연결 (사이클 없는)
    for edge in edges:
        node_v, node_u, weight = edge
        # 루트가 다를 경우 합쳐준다.
        if find(node_v) != find(node_u):
            union(node_v, node_u)
            mst.append(weight) # 합쳐줄 때 추가
    
    return sum(mst)
    
# N 컴퓨터의 개수
# M 선의 수
N = int(stdin.readline()) 
M = int(stdin.readline()) 

vertices = list(range(1, N+1))
parent = {}
rank = {}
    
# (a, b, c) M개
edges = [tuple(map(int, stdin.readline().rstrip().split())) for _ in range(M)]

print(kruskal(vertices, edges))    
Comments