Skip to content

Commit

Permalink
Add Preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinmanne committed Dec 4, 2020
1 parent c0095de commit 4bd3999
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 12 deletions.
10 changes: 2 additions & 8 deletions agent.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import numpy as np
import gym
import tensorflow as tf

from model import create_dqn_model
from util.circular_buffer import CircularBuffer
from util.environment import MainPreprocessing, FrameStack
from config import *


class DQNAgent:
def __init__(self):
self.env = gym.wrappers.FrameStack(
gym.wrappers.AtariPreprocessing(
gym.make(ENVIRONMENT),
scale_obs=True
),
num_stack=4
)
self.env = FrameStack(MainPreprocessing(gym.make(ENVIRONMENT)))
self.action_space = self.env.action_space
self.observation_shape = self.env.observation_space.shape
self.observation_dtype = self.env.observation_space.dtype
Expand Down
4 changes: 0 additions & 4 deletions main_rllib.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import sys
import ray

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ tensorflow
gym[atari]
numpy
ray[rllib]
cv2
53 changes: 53 additions & 0 deletions util/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import cv2
import gym
import numpy as np

from util.circular_buffer import CircularBuffer

HEIGHT, WIDTH = 84, 84


class MainPreprocessing(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.shape = (WIDTH, HEIGHT)
self.observation_space = gym.spaces.Box(
low=0.,
high=1.,
shape=self.shape,
dtype=np.float32
)

def observation(self, obs):
obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
obs = cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA)
return np.array(obs).astype(np.float32)/255


FRAME_STACK_COUNT = 4


class FrameStack(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.frame_count = FRAME_STACK_COUNT
self.frames = CircularBuffer(4, env.observation_space.shape)
self.observation_space = gym.spaces.Box(
low=np.array([env.observation_space.low]*4),
high=np.array([env.observation_space.high]*4),
shape=(self.frame_count, *env.observation_space.shape),
dtype=env.observation_space.dtype
)

def get_frames(self):
return self.frames[np.arange(self.frame_count)]

def reset(self):
obs = self.env.reset()
for _ in range(self.frame_count):
self.frames.append(obs)
return self.get_frames()

def observation(self, obs):
self.frames.append(obs)
return self.get_frames()

0 comments on commit 4bd3999

Please sign in to comment.