
Commit e0b7aa0

Merge of 2 parents: a47cf24 + 6711b8b

12 files changed: 2484 insertions(+), 0 deletions(-)

hw3/README.md

+20
@@ -0,0 +1,20 @@
# CS294-112 HW 3: Q-Learning

Dependencies:
* Python **3.5**
* Numpy version **1.14.5**
* TensorFlow version **1.10.5**
* MuJoCo version **1.50** and mujoco-py **1.50.1.56**
* OpenAI Gym version **0.10.5**
* seaborn
* Box2D==**2.3.2**
* OpenCV
* ffmpeg
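The pinned versions above matter for reproducing the starter code's behavior. A minimal sanity-check sketch (not part of the starter code; it only assumes the packages are importable under their usual names):

```python
# Sketch: confirm the pinned dependency versions before starting
# (hypothetical helper, not part of the starter code).
import numpy as np
import tensorflow as tf
import gym

expected = {"numpy": "1.14.5", "tensorflow": "1.10.5", "gym": "0.10.5"}
installed = {"numpy": np.__version__,
             "tensorflow": tf.__version__,
             "gym": gym.__version__}

for name, want in expected.items():
    have = installed[name]
    flag = "ok" if have == want else "MISMATCH"
    print("%-12s expected %-10s installed %-10s %s" % (name, want, have, flag))
```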
Before doing anything, first replace `gym/envs/box2d/lunar_lander.py` with the provided `lunar_lander.py` file.
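One way to do this replacement is to locate the installed gym package and copy the provided file over gym's bundled one. A minimal sketch, assuming the provided `lunar_lander.py` sits next to this README in the `hw3` directory (adjust `src` if you keep it elsewhere):

```python
# Sketch: overwrite gym's bundled lunar_lander.py with the provided file.
# Assumes the provided lunar_lander.py is in the current (hw3) directory.
import os
import shutil
import gym

src = os.path.abspath("lunar_lander.py")
dst = os.path.join(os.path.dirname(gym.__file__), "envs", "box2d", "lunar_lander.py")

print("Copying %s -> %s" % (src, dst))
shutil.copyfile(src, dst)
```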
The only files that you need to look at are `dqn.py` and `train_ac_f18.py`, which you will implement.

See the [HW3 PDF](http://rail.eecs.berkeley.edu/deeprlcourse/static/homeworks/hw3.pdf) for further instructions.

The starter code was based on an implementation of Q-learning for Atari generously provided by Szymon Sidor from OpenAI.

hw3/atari_wrappers.py

+152
@@ -0,0 +1,152 @@
#import sys
#sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')

import cv2
import numpy as np
from collections import deque
import gym
from gym import spaces


class NoopResetEnv(gym.Wrapper):
    def __init__(self, env=None, noop_max=30):
        """Sample initial states by taking a random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        super(NoopResetEnv, self).__init__(env)
        self.noop_max = noop_max
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def _reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        noops = np.random.randint(1, self.noop_max + 1)
        for _ in range(noops):
            obs, _, _, _ = self.env.step(0)
        return obs

class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Take action on reset for environments that are fixed until firing."""
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        self.env.reset()
        obs, _, _, _ = self.env.step(1)
        obs, _, _, _ = self.env.step(2)
        return obs

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        super(EpisodicLifeEnv, self).__init__(env)
        self.lives = 0
        self.was_real_done = True
        self.was_real_reset = False

    def _step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert we sometimes stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 here, so that we only
            # reset once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def _reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
            self.was_real_reset = True
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
            self.was_real_reset = False
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def _step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break

        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def _reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

def _process_frame84(frame):
    img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
    img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
    resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_LINEAR)
    x_t = resized_screen[18:102, :]
    x_t = np.reshape(x_t, [84, 84, 1])
    return x_t.astype(np.uint8)

class ProcessFrame84(gym.Wrapper):
    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))

    def _step(self, action):
        obs, reward, done, info = self.env.step(action)
        return _process_frame84(obs), reward, done, info

    def _reset(self):
        return _process_frame84(self.env.reset())

class ClippedRewardsWrapper(gym.Wrapper):
    def _step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, np.sign(reward), done, info

def wrap_deepmind_ram(env):
    env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ClippedRewardsWrapper(env)
    return env

def wrap_deepmind(env):
    assert 'NoFrameskip' in env.spec.id
    env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ClippedRewardsWrapper(env)
    return env
