
Commit 4cef006: "hw4"
1 parent e0b7aa0 commit 4cef006

12 files changed: +1844 -0 lines changed

hw4/.gitignore

+2
@@ -0,0 +1,2 @@
plots/
data/

hw4/half_cheetah_env.py

+80
@@ -0,0 +1,80 @@
import numpy as np
import tensorflow as tf
from gym import utils
from gym.envs.mujoco import mujoco_env

class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    def __init__(self):
        mujoco_env.MujocoEnv.__init__(self, 'half_cheetah.xml', 1)
        utils.EzPickle.__init__(self)

    def step(self, action):
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        ob = self._get_obs()
        reward_ctrl = -0.1 * np.square(action).sum()
        reward_run = (xposafter - xposbefore) / self.dt
        reward = reward_ctrl + reward_run
        done = False
        return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)

    def _get_obs(self):
        return np.concatenate([
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
            self.get_body_com("torso").flat,
            # self.get_body_comvel("torso").flat,
        ])

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 0.5

    @staticmethod
    def cost_fn(states, actions, next_states):
        is_tf = tf.contrib.framework.is_tensor(states)
        is_single_state = (len(states.get_shape()) == 1) if is_tf else (len(states.shape) == 1)

        if is_single_state:
            states = states[None, ...]
            actions = actions[None, ...]
            next_states = next_states[None, ...]

        scores = tf.zeros(actions.get_shape()[0].value) if is_tf else np.zeros(actions.shape[0])

        heading_penalty_factor = 10

        # don't move the front shin back so far that the cheetah tilts forward
        front_leg = states[:, 5]
        my_range = 0.2
        if is_tf:
            scores += tf.cast(front_leg >= my_range, tf.float32) * heading_penalty_factor
        else:
            scores += (front_leg >= my_range) * heading_penalty_factor

        front_shin = states[:, 6]
        my_range = 0
        if is_tf:
            scores += tf.cast(front_shin >= my_range, tf.float32) * heading_penalty_factor
        else:
            scores += (front_shin >= my_range) * heading_penalty_factor

        front_foot = states[:, 7]
        my_range = 0
        if is_tf:
            scores += tf.cast(front_foot >= my_range, tf.float32) * heading_penalty_factor
        else:
            scores += (front_foot >= my_range) * heading_penalty_factor

        scores -= (next_states[:, 17] - states[:, 17]) / 0.01

        if is_single_state:
            scores = scores[0]

        return scores
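
A minimal usage sketch (not part of this commit) of the NumPy branch of cost_fn, e.g. for scoring candidate transitions during random-shooting MPC. The batch size and the random placeholder arrays below are assumptions for illustration; in the homework the next states would come from the learned dynamics model rather than from noise.

import numpy as np
from half_cheetah_env import HalfCheetahEnv

# Hypothetical batch: 4096 transitions, 20-dim observations, 6-dim actions.
states = np.random.randn(4096, 20)
actions = np.random.uniform(-1.0, 1.0, size=(4096, 6))
next_states = states + 0.01 * np.random.randn(4096, 20)

costs = HalfCheetahEnv.cost_fn(states, actions, next_states)  # shape (4096,)
best_idx = int(np.argmin(costs))  # cheapest candidate under the hand-coded cost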

hw4/logger.py

+163
@@ -0,0 +1,163 @@
import os
from collections import defaultdict
import logging
from colorlog import ColoredFormatter

import pandas
import numpy as np

from tabulate import tabulate


class LoggerClass(object):
    GLOBAL_LOGGER_NAME = '_global_logger'

    _color_formatter = ColoredFormatter(
        "%(asctime)s %(log_color)s%(name)-10s %(levelname)-8s%(reset)s %(white)s%(message)s",
        datefmt='%m-%d %H:%M:%S',
        reset=True,
        log_colors={
            'DEBUG': 'cyan',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red,bg_white',
        },
        secondary_log_colors={},
        style='%'
    )

    _normal_formatter = logging.Formatter(
        '%(asctime)s %(name)-10s %(levelname)-8s %(message)s',
        datefmt='%m-%d %H:%M:%S',
        style='%'
    )

    def __init__(self):
        self._dir = None
        self._logger = None
        self._log_path = None
        self._csv_path = None
        self._tabular = defaultdict(list)
        self._curr_recorded = list()
        self._num_dump_tabular_calls = 0

    @property
    def dir(self):
        return self._dir

    #############
    ### Setup ###
    #############

    def setup(self, display_name, log_path, lvl):
        self._dir = os.path.dirname(log_path)
        self._logger = self._get_logger(LoggerClass.GLOBAL_LOGGER_NAME,
                                        log_path,
                                        lvl=lvl,
                                        display_name=display_name)
        self._csv_path = os.path.splitext(log_path)[0] + '.csv'

        ### load csv if exists
        if os.path.exists(self._csv_path):
            self._tabular = {k: list(v) for k, v in pandas.read_csv(self._csv_path).items()}
            self._num_dump_tabular_calls = len(tuple(self._tabular.values())[0])

    def _get_logger(self, name, log_path, lvl=logging.INFO, display_name=None):
        if isinstance(lvl, str):
            lvl = lvl.lower().strip()
            if lvl == 'debug':
                lvl = logging.DEBUG
            elif lvl == 'info':
                lvl = logging.INFO
            elif lvl == 'warn' or lvl == 'warning':
                lvl = logging.WARN
            elif lvl == 'error':
                lvl = logging.ERROR
            elif lvl == 'fatal' or lvl == 'critical':
                lvl = logging.CRITICAL
            else:
                raise ValueError('unknown logging level')

        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(LoggerClass._normal_formatter)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(lvl)
        console_handler.setFormatter(LoggerClass._color_formatter)
        if display_name is None:
            display_name = name
        logger = logging.getLogger(display_name)
        logger.setLevel(logging.DEBUG)
        logger.addHandler(console_handler)
        logger.addHandler(file_handler)

        return logger

    ###############
    ### Logging ###
    ###############

    def debug(self, s):
        assert (self._logger is not None)
        self._logger.debug(s)

    def info(self, s):
        assert (self._logger is not None)
        self._logger.info(s)

    def warn(self, s):
        assert (self._logger is not None)
        self._logger.warn(s)

    def error(self, s):
        assert (self._logger is not None)
        self._logger.error(s)

    def critical(self, s):
        assert (self._logger is not None)
        self._logger.critical(s)

    ####################
    ### Data logging ###
    ####################

    def record_tabular(self, key, val):
        assert (str(key) not in self._curr_recorded)
        self._curr_recorded.append(str(key))

        if key in self._tabular:
            self._tabular[key].append(val)
        else:
            self._tabular[key] = [np.nan] * self._num_dump_tabular_calls + [val]

    def dump_tabular(self, print_func=None):
        if len(self._curr_recorded) == 0:
            return ''

        ### reset
        self._curr_recorded = list()
        self._num_dump_tabular_calls += 1

        ### make sure all same length
        for k, v in self._tabular.items():
            if len(v) == self._num_dump_tabular_calls:
                pass
            elif len(v) == self._num_dump_tabular_calls - 1:
                self._tabular[k].append(np.nan)
            else:
                raise ValueError('key {0} should not have {1} items when {2} calls have been made'.format(
                    k, len(v), self._num_dump_tabular_calls))

        ### print
        if print_func is not None:
            log_str = tabulate(sorted([(k, v[-1]) for k, v in self._tabular.items()], key=lambda kv: kv[0]))
            for line in log_str.split('\n'):
                print_func(line)

        ### write to file
        tabular_pandas = pandas.DataFrame({k: pandas.Series(v) for k, v in self._tabular.items()})
        tabular_pandas.to_csv(self._csv_path)


logger = LoggerClass()
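
A short usage sketch (an assumption for illustration, not part of the commit) of the module-level logger singleton: setup() is called once with a log path, then record_tabular()/dump_tabular() are called once per iteration, which is what keeps the CSV columns aligned with _num_dump_tabular_calls.

import os
from logger import logger

log_dir = '/tmp/logger_demo'  # hypothetical output directory
os.makedirs(log_dir, exist_ok=True)
logger.setup(display_name='demo', log_path=os.path.join(log_dir, 'log.txt'), lvl='debug')

for itr in range(3):
    logger.info('iteration {0}'.format(itr))
    logger.record_tabular('Itr', itr)
    logger.record_tabular('ReturnAvg', 0.0)      # placeholder value
    logger.dump_tabular(print_func=logger.info)  # prints the table and rewrites log.csv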

hw4/main.py

+44
@@ -0,0 +1,44 @@
import os
import argparse
import time

from half_cheetah_env import HalfCheetahEnv
from logger import logger
from model_based_rl import ModelBasedRL

parser = argparse.ArgumentParser()
parser.add_argument('question', type=str, choices=('q1', 'q2', 'q3'))
parser.add_argument('--exp_name', type=str, default=None)
parser.add_argument('--env', type=str, default='HalfCheetah', choices=('HalfCheetah',))
parser.add_argument('--render', action='store_true')
parser.add_argument('--mpc_horizon', type=int, default=15)
parser.add_argument('--num_random_action_selection', type=int, default=4096)
parser.add_argument('--nn_layers', type=int, default=1)
args = parser.parse_args()

data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
exp_name = '{0}_{1}_{2}'.format(args.env,
                                args.question,
                                args.exp_name if args.exp_name else time.strftime("%d-%m-%Y_%H-%M-%S"))
exp_dir = os.path.join(data_dir, exp_name)
assert not os.path.exists(exp_dir), \
    'Experiment directory {0} already exists. Either delete the directory, or run the experiment with a different name'.format(exp_dir)
os.makedirs(exp_dir, exist_ok=True)
logger.setup(exp_name, os.path.join(exp_dir, 'log.txt'), 'debug')

env = {
    'HalfCheetah': HalfCheetahEnv()
}[args.env]

mbrl = ModelBasedRL(env=env,
                    render=args.render,
                    mpc_horizon=args.mpc_horizon,
                    num_random_action_selection=args.num_random_action_selection,
                    nn_layers=args.nn_layers)

run_func = {
    'q1': mbrl.run_q1,
    'q2': mbrl.run_q2,
    'q3': mbrl.run_q3
}[args.question]
run_func()
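
For reference (and assuming model_based_rl.py from this commit is importable alongside main.py), the script is invoked with one of the three question names, for example: python main.py q1 --exp_name test_run, optionally overriding --mpc_horizon, --num_random_action_selection, or --nn_layers. Logs and the CSV of tabular results land under hw4/data/<env>_<question>_<exp_name>.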
