상세 컨텐츠

본문 제목

[Algorithm] 백트래킹(Backtracking)

Computer Science/Algorithm

by G_Batman 2022. 9. 21. 21:42

본문

728x90

안녕하세요 배트맨🦇 입니다 !

오늘은 백트래킹이 무엇이고 어떻게 구현해서 사용해야 하는지에 대해 정리해보려고 합니다.

또 백트래킹 문제인 "BOJ 15649번 - N과M" 을 풀어보려고 합니다.

시작해보겠습니다!


백트래킹이란?

모든 경우의 수를 전부 고려하는 알고리즘. 상태 공간을 트리로 나타낼 수 있을 때 적합한 방식이다. 일종의 트리 탐색 알고리즘.
- 나무 위키 -

나무 위키의 정의만 보고는 이해가 쉽지 않습니다. 

 

즉, 

현재 상태에서 가능한 모든 후보군을 따라 들어가며 탐색하는 알고리즘입니다.

DFS를 사용하여 조건에 맞지 않으면 중단하고 이전으로 돌아가여 다시 확인하는 것을 반복하며 원하는 조건을 찾아갑니다.


BOJ 15649  - N과 M (1)

http://www.acmicpc.net/problem/15649

 

15649번: N과 M (1)

한 줄에 하나씩 문제의 조건을 만족하는 수열을 출력한다. 중복되는 수열을 여러 번 출력하면 안되며, 각 수열은 공백으로 구분해서 출력해야 한다. 수열은 사전 순으로 증가하는 순서로 출력해

www.acmicpc.net

 

사실 이 문제를 처음 봤을 때 파이썬 itertools의 permutation을 이용하여 쉽게 풀었습니다.

하지만 이 방법은 파이썬에서만 적용할 수 있는 방법이고,

백트래킹 문제로 분류된 문제인 만큼 백트래킹 알고리즘을 이용해 풀어보겠습니다.

 

( permutations 를 이용한 풀이 )

더보기
import itertools

nPr = itertools.permutations

n,m = map(int,input().split())

res = list(nPr(range(1,n+1),m))

for bind in res:
    print(*bind, sep=" ")

 

Solution

DFS를 사용하여 문제를 해결할 수 있는데, 백트래킹은 일반적인 DFS와 달리 가지치기를 합니다.

n과 m이 각각 5,2 일 때 (1, 2), (1, 3), (1, 4), (1, 5), (2, 1), (2, 3) ,,, 과 같이 중복되지 않은  순서쌍을 찾아야 합니다.

따라서 가지치기의 조건은 이미 고른 숫자라면 제외시키는 것입니다.

그럼 지금부터 코드를 부분으로 나눠가며 보겠습니다.

n,m = map(int,input().split())

s = []

n과 m을 입력받고, 출력을 위한 스택인 리스트 s를 선언합니다.

그 다음은 DFS 함수를 구현할 차례입니다. DFS 함수는 재귀적인 방법을 사용하므로 탈출 조건을 정해주어야 합니다.

스택 s에 정수가 m개만큼 들어왔을 때 숫자들을 출력하고 탈출할 수 있도록 합니다. 코드는 다음과 같습니다.

def dfs():
    if len(s) == m:
        print(' '.join(map(str,s)))
        return

그 다음 골라지지 않은 숫자에 대해 탐색하는 부분을 구현합니다. DFS의 전체 코드는 다음과 같습니다.

def dfs():
    if len(s) == m:
        print(' '.join(map(str,s)))
        return
    
    for i in range(1,n+1):
        if i not in s:
            s.append(i)
            dfs()
            s.pop()

골라지지 않은 숫자를 스택 s에 추가하고, 그 다음 자리를 탐색하기 위해 DFS를 재귀적으로 호출합니다. DFS함수에서 리턴되어 돌아온다면 pop을 이용해 스택 s를 비워줍니다.

 

( 전체 코드 )

n,m = map(int,input().split())

s = []
 
def dfs():
    if len(s) == m:
        print(' '.join(map(str,s)))
        return
    
    for i in range(1,n+1):
        if i not in s:
            s.append(i)
            dfs()
            s.pop()
 
dfs()

 

728x90

관련글 더보기

댓글 영역