
Commit e0d799a

DQN loop Transformer, stub
1 parent d20fb59 commit e0d799a

18 files changed: +220 -0 lines changed
File renamed without changes.

DQN_logic_pyTorch.py → DQN_logic.py

File renamed without changes.
File renamed without changes.
File renamed without changes.

DQN_loop_Transformer.py  (+220 lines)
@@ -0,0 +1,220 @@
"""
First multi-step experiment.
RL will output some "intermediate" results that aren't actions.
actions 0-8  = tic-tac-toe actions
actions 9-17 = intermediate thoughts
These will be put into a special area of the "state".
For more explanations see: README-RL-with-autoencoder.md

Using:
PyTorch: 1.9.1+cpu
gym: 0.8.0
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.distributions import Normal

import random
import numpy as np
np.random.seed(7)
torch.manual_seed(7)
device = torch.device("cpu")
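
# --- Illustrative sketch (not in the original file): the state layout assumed here.
# Following the encoding used in DQN.visualize_q() below, a state is a 36-vector of
# 18 (value, position) pairs:
#   indices  0..17  board area:   [symbol, i-4] for each square i (symbol = 1, -1, or 0)
#   indices 18..35  memory area:  [-2, i-4] if "thought" i has been emitted, else [2, 0]
# A hypothetical helper showing how a thought action (9-17) could be written back
# into the memory area of such a state vector:

def _example_record_thought(vec, action):
    """Hypothetical helper (for exposition only): mark thought (action - 9) in the memory area."""
    i = action - 9
    vec[18 + 2 * i]     = -2        # memory slot flagged as "thought present"
    vec[18 + 2 * i + 1] = i - 4     # position code, same convention as the board area
    return vec

# Example: an empty board with no thoughts yet, then record thought action 9:
# _example_record_thought([v for i in range(9) for v in (0, i - 4)] + [2, 0] * 9, 9)
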
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def last_reward(self):
        return self.buffer[self.position - 1][2]

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = \
            map(np.stack, zip(*batch))  # stack for each element
        '''
        the * serves as unpack: sum(a,b) <=> batch=(a,b), sum(*batch) ;
        zip: a=[1,2], b=[2,3], zip(a,b) => [(1, 2), (2, 3)] ;
        map applies the function to each element: map(square, [2,3]) => [4, 9] ;
        np.stack((1,2)) => array([1, 2])
        '''
        # print("sampled state=", state)
        # print("sampled action=", action)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

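# --- Illustrative usage sketch of ReplayBuffer (not in the original file) ---
# Transitions are stored as (state, action, reward, next_state, done) tuples and
# sample() stacks each field into a numpy array for batched training.
def _example_replay_buffer():
    buf = ReplayBuffer(capacity=100)
    s  = [0.0] * 36                      # a 36-dim state, as described above
    s2 = [0.0] * 36
    buf.push(s, 4, 0.0, s2, False)       # play the centre square, no reward yet
    buf.push(s2, 9, 0.0, s2, False)      # emit "thought" action 9
    states, actions, rewards, next_states, dones = buf.sample(batch_size=2)
    return states.shape, actions.shape   # -> (2, 36), (2,)
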
class DQN():

    def __init__(
            self,
            action_dim,
            state_dim,
            learning_rate = 3e-4,
            gamma = 1.0 ):
        super(DQN, self).__init__()

        self.action_dim = action_dim
        self.state_dim = state_dim
        self.lr = learning_rate
        self.gamma = gamma

        self.replay_buffer = ReplayBuffer(int(1e6))

        hidden_dim = 9

        self._build_net()

        self.q_criterion = nn.MSELoss()
        self.q_optimizer = optim.Adam(self.trm.parameters(), lr=self.lr)

    def _build_net(self):
        encoder_layer = nn.TransformerEncoderLayer(d_model=3, nhead=1)
        self.trm = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # W maps one encoder output vector to a 9-vector (probability distribution
        # over actions); its first dimension must therefore equal d_model.
        # NOTE (stub): d_model above is 3, forward() splits the state into 2-wide
        # tokens, and W is 2x9 -- these widths still need to be reconciled.
        # NOTE: self.W is not in self.trm.parameters(), so the Adam optimizer
        # created in __init__ does not update it.
        self.W = Variable(torch.randn(2, 9), requires_grad=True)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        # input dim = n_features = 9 x 2 x 2 = 36
        # First we need to split the input into 18 parts:
        # print("x =", x)
        xs = torch.stack(torch.split(x, 2, 1), 1)
        # print("xs =", xs)
        # There is a question of how these are stacked, 9x3 or 3x9?
        # It has to conform with the Transformer's d_model = 3.
        ys = self.trm(xs)   # no need to split results, already in 9x3 chunks
        # print("ys =", ys)
        # it seems that only the last 3-dim vector is useful
        u = torch.matmul( ys.select(1, 8), self.W )
        # *** sum the probability distributions together:
        # z = torch.stack(zs, dim=1)
        # u = torch.sum(z, dim=1)
        # v = self.softmax(u)
        # print("v =", v)
        return u

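    # --- Shape sketch for forward() (illustrative, assuming a batch of N states of
    #     dimension 36; not in the original file) ---
    #   x                       : (N, 36)
    #   torch.split(x, 2, 1)    -> 18 tensors of shape (N, 2)
    #   xs = torch.stack(.., 1) -> (N, 18, 2)
    #   ys = self.trm(xs)       -> same shape as its input, but nn.TransformerEncoder
    #                              requires the last dimension to equal d_model
    #                              (3 in _build_net, vs. the split width of 2 here)
    #   ys.select(1, 8)         -> (N, 2)   the token at index 8
    #   matmul with self.W      -> (N, 9)   one logit per tic-tac-toe square
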
    def choose_action(self, state, deterministic=True):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        logits = self.forward(state)        # Q-values (one per square) from the Transformer net
        probs = torch.softmax(logits, dim=1)
        dist = Categorical(probs)
        action = dist.sample().numpy()[0]

        # print("chosen action=", action)
        return action

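    # --- Illustrative note (not in the original file) ---
    # Categorical(probs=softmax(logits)) is equivalent to Categorical(logits=logits),
    # i.e. the action is sampled from the Boltzmann distribution of the Q-values.
    # The `deterministic` flag is currently unused; a greedy variant would be e.g.:
    #   action = torch.argmax(logits, dim=1).item()
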
    def update(self, batch_size, reward_scale, gamma=1.0):
        # alpha = 1.0  # trade-off between exploration (max entropy) and exploitation (max Q); not used now

        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        # print('sample (state, action, reward, next state, done):', state, action, reward, next_state, done)

        state      = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action     = torch.LongTensor(action).to(device)
        reward     = torch.FloatTensor(reward).to(device)   # reward is a single value per sample
        done       = torch.BoolTensor(done).to(device)

        logits      = self.forward(state)
        next_logits = self.forward(next_state)

        # **** Train the deep Q function; this is just the Bellman equation:
        # Q(s_t, a_t) += η [ R + γ max_a Q(s_{t+1}, a) - Q(s_t, a_t) ]
        # tabular form: DQN[s, action] += self.lr * ( reward + self.gamma * np.max(DQN[next_state, :]) - DQN[s, action] )
        # The hard max is not directly available here, but something close can be done.
        # The DQN outputs probs; how do probs relate to Q?  The Boltzmann (softmax)
        # distribution of Q gives the probs (this is the SAC approach).
        # This implies that Q = logits.
        # logits[at] += self.lr * ( reward + self.gamma * np.max(logits[next_state, next_a]) - logits[at] )
        q = logits[range(logits.shape[0]), action]
        # maxq = torch.softmax(next_logits, 1, keepdim=False).values
        softmaxQ = torch.log(torch.sum(torch.exp(next_logits), 1))   # log-sum-exp over actions, a "soft" max
        # print("softmaxQ:", softmaxQ.shape)
        # q = q + self.lr * ( reward + self.gamma * m - q )
        # torch.where: if condition then arg2 else arg3
        target_q = torch.where(done, reward, reward + self.gamma * softmaxQ)
        # print("q, target_q:", q.shape, target_q.shape)
        q_loss = self.q_criterion(q, target_q.detach())

        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        return

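    # --- Illustrative note on the target above (not in the original file) ---
    # softmaxQ = log( sum_a exp( Q(s', a) ) ) equals torch.logsumexp(next_logits, dim=1),
    # a smooth upper bound on max_a Q(s', a) (the "soft" value used in soft
    # Q-learning / SAC with temperature 1).  The update therefore regresses
    #   Q(s, a)  towards  r                        if done,
    #   Q(s, a)  towards  r + gamma * softmaxQ     otherwise,
    # instead of the hard-max target of standard DQN.
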
    def visualize_q(self, board, memory):
        # convert board vector + memory vector to a state vector
        vec = []
        for i in range(9):
            symbol = board[i]
            vec += [symbol, i - 4]
        for i in range(9):
            if memory[i] == 1:
                vec += [-2, i - 4]
            else:
                vec += [2, 0]
        state = torch.FloatTensor(vec).unsqueeze(0).to(device)
        logits = self.forward(state)
        probs = torch.softmax(logits, dim=1)
        return probs.squeeze(0)

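    # --- Illustrative usage of visualize_q (not in the original file) ---
    # board  : list of 9 symbols (±1 = the two players' marks, 0 = empty)
    # memory : list of 9 flags   (1 = thought i has been emitted)
    # e.g. probs = agent.visualize_q([0]*9, [0]*9) would return the 9 action
    # probabilities for the empty board with an empty memory area.
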
    def net_info(self):
        config_h = "(2)-9-9"
        config_g = "9-9-(9)"
        total = 0
        neurons = config_h.split('-')
        last_n = 3
        for n in neurons[1:]:
            n = int(n)
            total += last_n * n
            last_n = n
        total *= 9

        neurons = config_g.split('-')
        for n in neurons[1:-1]:
            n = int(n)
            total += last_n * n
            last_n = n
        total += last_n * 9
        return (config_h + ':' + config_g, total)

    def play_random(self, state, action_space):
        # Select an action (0-8) randomly
        # NOTE: the random player never chooses occupied squares
        empties = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        # Find and collect all empty squares:
        # scan through the 9 board propositions, each proposition is a 2-vector
        for i in range(0, 18, 2):
            # 'proposition' is a 2-element slice: [symbol, position code]
            proposition = state[i : i + 2]
            sym = proposition[0]
            if sym == 1 or sym == -1:
                x = proposition[1]
                j = x + 4
                empties.remove(j)
        # Select an available square randomly
        action = random.sample(empties, 1)[0]
        return action

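    # --- Worked example for play_random (not in the original file) ---
    # A proposition [1, -1] means: symbol 1 sits at position code -1, i.e. at
    # square j = -1 + 4 = 3, so square 3 is removed from `empties`.  Propositions
    # with symbol 0 (empty squares) stay in the candidate list.
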
    def save_net(self, fname):
        # NOTE: only the Transformer encoder's weights are saved; self.W is not included.
        torch.save(self.trm.state_dict(), \
            "PyTorch_models/" + fname + ".dict")
        print("Model saved.")

    def load_net(self, fname):
        self.trm.load_state_dict(torch.load("PyTorch_models/" + fname + ".dict"))
        self.trm.eval()
        print("Model loaded.")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

PG_full_pyTorch.py → PG_full.py

File renamed without changes.

PG_logic_pyTorch.py → PG_logic.py

File renamed without changes.

PG_symNN_pyTorch.py → PG_symNN.py

File renamed without changes.

RL_DQN_pyTorch.py → RL_DQN.py

File renamed without changes.

SAC_full_pyTorch.py → SAC_full.py

File renamed without changes.

0 commit comments
