Algorithm/acmicpc.net

파이썬 시간초과는 논리도 필요해요. (#14888 연산자 끼워넣기)

winney916 2024. 3. 20. 22:52
728x90

#14888 연산자 끼워넣기 https://www.acmicpc.net/problem/14888

 

14888번: 연산자 끼워넣기

첫째 줄에 수의 개수 N(2 ≤ N ≤ 11)가 주어진다. 둘째 줄에는 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 100) 셋째 줄에는 합이 N-1인 4개의 정수가 주어지는데, 차례대로 덧셈(+)의 개수, 뺄셈(-)의 개수, 곱

www.acmicpc.net

 

쉽다고 생각했다. 브루트포스 알고리즘을 사용해야하는 만큼 시간 초과가 마음에 걸렸지만 그래도 일단 해봤다.

 

 

시간초과 코드

import sys

input = sys.stdin.readline

num = int(input())

nums = [int(x) for x in input().split()]

# +, -, *, /
operator = ["+", "-", "*", "//"]
times = [int(x) for x in input().split()]
operators = []
for i in range(4):
    for j in range(times[i]):
        operators.append(operator[i])

max_value, min_value = -1000000000, 1000000000

# 4개 기준으로 테스트, 4! = 4*3*2*1 = 24
visited = [False for _ in range(len(operators))]


def dfs(operators, visited, result):
    global max_value, min_value
    if sum(visited) == len(visited):
        # print(result)
        try:
            # print(nums)
            # print(result)
            calculated = str(nums[0])
            for i in range(len(result)):
                if int(calculated) < 0 and result[i] == "//":  # 음수 나눗셈 예외처리
                    tmp = -1 * int(calculated)
                    tmp = str(tmp) + str(result[i]) + str(nums[i + 1])
                    calculated = -1 * eval(tmp)
                    calculated = str(calculated)
                else:
                    calculated += str(result[i])
                    calculated += str(nums[i + 1])
                    calculated = str(eval(calculated))
            calculated = int(calculated)
            # print(calculated)
        except ZeroDivisionError:
            return
        max_value = max(max_value, calculated)
        min_value = min(min_value, calculated)

        return result

    for i in range(len(visited)):
        if not visited[i]:
            visited[i] = True
            result.append(operators[i])
            dfs(operators, visited, result)
            visited[i] = False
            result.pop()


dfs(operators, visited, [])

print(max_value)
print(min_value)

 

 

모든 케이스를 dfs를 통해 탐색하면서 식이 완성될 때마다 계산하고, 최댓값 최솟값을 그 때에 갱신하는 방식이다.

까다로웠던 부분은 예외상황 처리였다.

0을 나누는 경우를 예외처리로 진행했고

 

음수를 나누는 경우 문제에 나와있는 것 처럼, 양수로 변환한 후 나눈셈을 진행하고 다시 음수로 바꾸는 로직을 추가해야했다.

(이걸 추가하지 않으면 다른 값이 나온다. 컴퓨팅 논리 상의 문제인걸로 보인다.)

 

하지만 시간초과

 

sys.stdin.readline도 적용해 보았지만 어림도 없었다.

 

고민도 좀 해보고 질문창도 좀 뒤져본 결과

모든 경우의 수를 탐색하다 보면 겹치는 경우가 생길 수 밖에 없었다.

이를 제외해주어야 시간을 줄일 수 있을 것이다.

 

따라서, dfs로직에서 만든 식을 set에 저장했다.

 

set은 특성상 중복이 사라진다.

대신에 list를 저장할 수 없다. 중복되지 않은 데이터를 저장하는 특성상 immutable한 데이터만 저장가능한 것으로 보인다. (뇌피셜임)

파이썬은 mutable데이터로 얼마든지 수정이 가능하다. 

따라서 리스트에 저장하던 연산자 정보를 tuple로 변경한다. (tuple은 수정이 불가능하다. immutable.)

tuple로 변경한 후에는 set에 저장할 수 있고, 연산자 리스트의 중복을 방지할 수 있다.

 

그리고 계산 및 최대/최솟값 갱신 로직을 실행한다.

 

정답 코드

## 개선내용 ##
# calculate와 경우의 수 생성하는 기능을 분리한다.
# 경우의 수가 중복되는 경우를 방지하기 위해 set에 저장한다.
import sys

input = sys.stdin.readline
num = int(input())
nums = [int(x) for x in input().split()]

# +, -, *, /
operator = ["+", "-", "*", "//"]
times = [int(x) for x in input().split()]
operators = []
for i in range(4):
    for j in range(times[i]):
        operators.append(operator[i])

max_value, min_value = -1000000000, 1000000000


calculates_set = set()
# 4개 기준으로 테스트, 4! = 4*3*2*1 = 24
visited = [False for _ in range(len(operators))]


def dfs(operators, visited, result):
    global max_value, min_value
    if sum(visited) == len(visited):
        calculates_set.add(tuple(result))
        return result

    for i in range(len(visited)):
        if not visited[i]:
            visited[i] = True
            result.append(operators[i])
            dfs(operators, visited, result)
            visited[i] = False
            result.pop()


dfs(operators, visited, [])

# print(result)

for result in calculates_set:
    try:
        # print(nums)
        # print(result)
        calculated = str(nums[0])
        for i in range(len(result)):
            if int(calculated) < 0 and result[i] == "//":  # 음수 나눗셈 예외처리
                tmp = -1 * int(calculated)
                tmp = str(tmp) + str(result[i]) + str(nums[i + 1])
                calculated = -1 * eval(tmp)
                calculated = str(calculated)
            else:
                calculated += str(result[i])
                calculated += str(nums[i + 1])
                calculated = str(eval(calculated))
        calculated = int(calculated)
        # print(calculated)
    except ZeroDivisionError:
        continue
    max_value = max(max_value, calculated)
    min_value = min(min_value, calculated)


print(max_value)
print(min_value)