개요
강화학습 공부를 하다가 간단한 예제 말고 내가 실제로 하던 게임에 적용해보면 재밌을것 같아서
그나마 제일 단순한 Chrome Dino Game에 강화학습을 적용하여 자동으로 플레이하는 모델을 만들어보고 싶었다
구현에 도움을 얻고자 검색을 해보니 많은 사람들이 공룡 앞의 부분을 crop하여 CNN으로 어떤 동작을 해야할 지 classification하는 모델을 만들었으나, 이는 내가 만들고자 하는 모델과 차이가 있었다
위와 같은 방식 말고 전체 이미지를 캡쳐해서 어떠한 명시적 특징 추출 없이 범용적으로 적용될 수 있는 AI를 제작하고 싶었다
내가 저렇게 앞의 이미지를 잘라주는건 특징추출이나 마찬가지. 전체 이미지를 캡쳐해서 real time 훈련을 통해 agent가 저 부분의 이미지가 중요하다는걸 학습하는 것이 목표
또한 gym에 의존하지 않고 실시간 화면만 보고 적은 학습을 통해 높은 점수를 내고자 하였다. gym env에 의존하고싶지 않았기 때문
알고리즘은 강화학습 중 가장 기본적인 DQN을 사용한다
코드 및 설명
라이브러리 import
import numpy as np
from PIL import Image
import cv2
import io
import time
import random
import pickle
import os
from io import BytesIO
import base64
import json
import pandas as pd
from collections import deque
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.keys import Keys
from webdriver_manager.chrome import ChromeDriverManager
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
게임 실행 준비
GAME_URL = "chrome://dino"
CHROME_DRIVER_PATH = ChromeDriverManager().install()
DATA_DIR = "./data"
MODEL_DIR = "./model"
SAVE_INTERVAL = 1000
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
PARAMS_FILE = os.path.join(DATA_DIR, "params.pkl")
INIT_SCRIPT = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'"
GET_BASE64_SCRIPT = "canvasRunner = document.getElementById('runner-canvas'); return canvasRunner.toDataURL().substring(22)"
크롬창에 chrome://dino 치면 바로 dinogame으로 연결되는데, 이를 selenium에게 전달하여 코드가 엑세스 할 수 있도록 한다.
데이터와 모델을 저장할 디렉토리를 생성하고
parameter를 저장할 파일을 만든다 (중간에 훈련 중단되도 나중에 다시 불러올 수 있도록)
document로 html에 접근하여 게임 플레이 스크립트 부분을 가져온다
모델 관리 및 환경제어를 위한 함수 생성
def save_params(params):
with open(PARAMS_FILE, 'wb') as f:
pickle.dump(params, f, pickle.HIGHEST_PROTOCOL)
def load_params():
if os.path.isfile(PARAMS_FILE):
with open(PARAMS_FILE, 'rb') as f:
return pickle.load(f)
return {
"D": deque(maxlen=50000),
"time": 0,
"epsilon": 0.01
}
def load_model(model):
if os.path.isfile('./latest.pth'):
model.load_state_dict(torch.load('./latest.pth'))
return model
def grab_screen(driver):
image_b64 = driver.execute_script(GET_BASE64_SCRIPT)
screen = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))
return process_img(screen)
def process_img(image):
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = image[:300, :500]
image = cv2.resize(image, (80, 80))
return image
def show_img(graphs=False):
while True:
screen = (yield)
window_title = "logs" if graphs else "game_play"
cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)
imS = cv2.resize(screen, (800, 400))
cv2.imshow(window_title, screen)
if cv2.waitKey(1) & 0xFF == ord('q'):
cv2.destroyAllWindows()
break
함수명대로다
Game 클래스 생성
class Game:
def __init__(self):
chrome_options = Options()
chrome_options.add_argument("disable-infobars")
chrome_options.add_argument("--mute-audio")
service = Service(CHROME_DRIVER_PATH)
self._driver = webdriver.Chrome(service=service, options=chrome_options)
self._driver.set_window_position(x=300, y=300)
self._driver.set_window_size(900, 600)
try :
self._driver.get(GAME_URL)
except:
pass
self._driver.execute_script("Runner.config.ACCELERATION=0")
self._driver.execute_script(INIT_SCRIPT)
def get_crashed(self):
return self._driver.execute_script("return Runner.instance_.crashed")
def get_playing(self):
return self._driver.execute_script("return Runner.instance_.playing")
def restart(self):
self._driver.execute_script("Runner.instance_.restart()")
def press_up(self):
self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_UP)
def press_down(self):
self._driver.find_element("tag name", "body").send_keys(Keys.ARROW_DOWN)
def get_score(self):
score_array = self._driver.execute_script("return Runner.instance_.distanceMeter.digits")
return int(''.join(score_array))
def pause(self):
self._driver.execute_script("return Runner.instance_.stop()")
def resume(self):
self._driver.execute_script("return Runner.instance_.play()")
def end(self):
self._driver.close()
게임의 현재 status를 파악하거나 게임 상황을 제어하기 위한 메소드들이 담겨있는 class이다
이 클래스의 초기화 단계에 selenium을 실행하여 게임에 접근하기 시작한다
DinoAgent 생성
class DinoAgent:
def __init__(self, game):
self._game = game
self.jump()
def is_running(self):
return self._game.get_playing()
def is_crashed(self):
return self._game.get_crashed()
def jump(self):
self._game.press_up()
def duck(self):
self._game.press_down()
DinoAgent는 게임을 플레이하는 에이전트이다. 사람과 같이 게임에 접근하여 동작을 할 수 있지만, 동작이나 status 파악에 제한이 있다
GameState
class GameState:
def __init__(self, agent, game):
self._agent = agent
self._game = game
self._display = show_img()
self._display.__next__()
def get_state(self, actions):
score = self._game.get_score()
reward = 0.1
is_over = False
if actions[1] == 1:
self._agent.jump()
reward = -0.01
image = grab_screen(self._game._driver)
self._display.send(image)
if self._agent.is_crashed():
self._game.restart()
reward = -10
is_over = True
return image, reward, is_over
GameState 클래스는 게임의 상태를 관리하고, 에이전트 동작에 따른 게임 진행상황을 추적한다.
이곳에서 agent의 행동에 대해 reward(보상)의 양을 제어할 수 있다
HyperParameters
ACTIONS = 2
GAMMA = 0.99
OBSERVATION = 1000
EXPLORE = 500000
FINAL_EPSILON = 0.0001
INITIAL_EPSILON = 0.1
REPLAY_MEMORY = 100000
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
IMG_CHANNELS = 4
이따가 쓸 하이퍼 파라미터들. 딥러닝에 익숙하다면 변수명만으로 어떤 역할을 하는지 유추할 수 있다
Model
class DinoNet(nn.Module):
def __init__(self):
super(DinoNet, self).__init__()
self.conv1 = nn.Conv2d(4, 32, (8, 8), stride=4, padding=1)
self.conv2 = nn.Conv2d(32, 64, (4, 4), stride=2, padding=1)
self.conv3 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d((2, 2))
self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(32, 2)
def forward(self, x):
x = x.permute(0, 3, 1, 2)
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.max_pool2d(self.relu(self.conv3(x)))
x = x.reshape(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
매우 간단한 CNN모델, 마지막 출력은 2차원이다. [가만히있기, 점프]
Training
제일 긴 부분이자 핵심 부분이므로, 나눠 설명합니다
def train_network(model, game_state, observe=False):
params = load_params()
D = params["D"]
t = params["time"]
epsilon = params["epsilon"]
optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
do_nothing = np.zeros(2)
do_nothing[0] = 1
x_t, r_0, terminal = game_state.get_state(do_nothing)
s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)
s_t = s_t.reshape(1, s_t.shape[0], s_t.shape[1], s_t.shape[2])
OBSERVE = 999999999 if observe else 100
저장된 파라미터를 불러오고, 훈련을 위한 초기화를 진행한다
Adam optimizer / MSE Loss 사용
do nothing은 [1,0]. 아무것도 안하는 동작을 정의한다
x_t : 현재 게임의 이미지
r_0 : 현재 보상
terminal : 게임 종료 여부
x_t 를 네번 쌓아올려 s_t를 만든 후, 배치차원을 추가한다 (1, 높이, 너비, 4)
while True:
loss_sum = 0
a_t = np.zeros([2])
if random.random() <= epsilon:
action_index = random.randrange(2)
a_t[action_index] = 1
else:
q = model(torch.tensor(s_t).float())
_, action_index = torch.max(q, 1)
action_index = action_index.item()
a_t[action_index] = 1
if epsilon > FINAL_EPSILON and t > OBSERVE:
epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE
x_t1, r_t, terminal = game_state.get_state(a_t)
x_t1 = x_t1.reshape(1, x_t1.shape[0], x_t1.shape[1], 1)
s_t1 = np.append(x_t1, s_t[:, :, :, :3], axis=3)
if len(D) > 50000:
D.pop()
D.append((s_t, action_index, r_t, s_t1, terminal))
epsilon 확률에 따라 무작위 행동(탐색) 또는 모델 예측에 따른 행동(최적)을 선택한다
그 후 epsilon을 점진적으로 감소시켜 훈련이 진행될수록 무작위 활동을 줄이고 모델 예측에 의존하도록 한다 (충분히 탐색했기 때문)
행동 후 state를 업데이트 하고, 이를 모델의 형태에 맞게 변형한다
그 후, 기존에 기록된 행동 s_t와 새롭게 관측된 state인 x_t1을 활용하여 새로운 상태 s_t1을 생성한다
리플레이 메모리 D가 일정 크기를 넘어설 때, 가장 오래된 경험을 제거하기 시작한다 (불필요)
if t > OBSERVE:
minibatch = random.sample(D, 16)
inputs = np.zeros((16, s_t.shape[1], s_t.shape[2], s_t.shape[3]))
targets = np.zeros((16, 2))
for i in range(16):
state_t, action_t, reward_t, state_t1, terminal = minibatch[i]
inputs[i:i + 1] = state_t
target = model(torch.tensor(state_t).float()).detach().numpy()[0]
Q_sa = model(torch.tensor(state_t1).float()).detach().numpy()[0]
if terminal:
target[action_t] = reward_t
else:
target[action_t] = reward_t + 0.99 * np.max(Q_sa)
targets[i] = target
outputs = model(torch.tensor(inputs).float())
loss = loss_fn(outputs, torch.tensor(targets).float())
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item()
s_t = s_t1 if not terminal else s_t
t += 1
if t % SAVE_INTERVAL == 0:
game_state._game.pause()
torch.save(model.state_dict(), os.path.join(MODEL_DIR, f"episode_{t}.pth"))
torch.save(model.state_dict(), "./latest.pth")
save_params({"D": D, "time": t, "epsilon": epsilon})
game_state._game.resume()
print(f'timestep: {t}, epsilon: {round(epsilon, 3)}, action: {action_index}, reward: {r_t}, loss: {round(loss_sum, 3)}')
모델 학습을 위해 리플레이 메모리에서 미니배치를 샘플링한다
현재 상태의 모델 예측 값을 target에 저장하고, 다음 상태의 모델 예측 값을 Q_sa에 저장한다
종료 상태일 경우 reward만을 할당하지만
게임이 실행되고 있는 경우는 다음의 모델 예측 값을 포함하여 보상을 업데이트 해준다
그 아래는 loss 계산 및 역전파 파라미터 업데이트
파라미터 업데이트까지 끝났으면 다음 상태를 현재 상태로 옮겨 다음훈련단계를 밟을 준비를 한다
그 밑은 모델저장 및 출력부
Run
def play_game(observe=False):
params = {"D": deque(maxlen=50000), "time": 0, "epsilon": 0.001}
save_params(params)
game = Game()
agent = DinoAgent(game)
game_state = GameState(agent, game)
try:
model = DinoNet()
model = load_model(model)
train_network(model, game_state, observe)
except StopIteration:
game.end()
play_game(observe=False)
위에서 선언한 함수들을 실제로 실행하는 부분
결과
스크린샷에서는 700점인데 실제로 훈련을 더 많이 시켜보니 최고점수가 1300점 이상까지 나왔다.
realtime 훈련임에도 불구하고 수 시간 내에 human level까지 학습이 완료되었다
현재의 행동이 미래의 행동에 영향이 거의 없다고 볼 수 있으나,
장애물이 오고있는데 쓸데없는 점프를 하는 경우에는 현재의 행동이 미래에 영향이 있어지기 때문에 이를 해결하기 위해
점프할 시 reward를 아주 조금 줄이는 방향으로 수정하였다
위 이슈를 제외하면 현재 시점의 행동은 미래에 잠재적인 영향이 없는
매우 단순한 게임이기 때문에 학습을 진행하면 할수록 더 높은 점수가 나올 것이라고 예상한다
Dueling DQN과 같은 더 효율적인 알고리즘을 사용한다면 빠른 학습이 가능할 것이다