This repository was archived by the owner on Jan 15, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathsoft_actor_critic_inverted_double_pendulum.py
More file actions
324 lines (269 loc) · 13.6 KB
/
soft_actor_critic_inverted_double_pendulum.py
File metadata and controls
324 lines (269 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import typing as tp
from collections import OrderedDict, deque
import matplotlib.animation as anim
import random
from dataclasses import dataclass
from itertools import count
import torch
from torch import nn, Tensor
# Gymnasium environment id, used both for the training env and for the rendered evaluation episode.
ENV_NAME = "InvertedDoublePendulum-v5"
def update_scene(num, frames, patch):
    """Animation callback: display frame ``num`` of ``frames`` on the image artist ``patch``.

    Returns a one-element tuple holding the updated artist, the form in which
    ``FuncAnimation`` update functions conventionally report changed artists.
    """
    current_frame = frames[num]
    patch.set_data(current_frame)
    return (patch,)
def plot_animation(frames: list, save_path: tp.Optional[str] = None, repeat=False, interval=40, title: tp.Optional[str] = None):
    """Build a matplotlib animation that plays ``frames`` in a fresh figure.

    Args:
        frames: Sequence of images accepted by ``plt.imshow``. The first image
            initialises the artist that ``update_scene`` refreshes per frame.
        save_path: When not ``None``, the animation is also written to this path
            with the ``"pillow"`` writer at 20 fps.
        repeat: Forwarded to ``FuncAnimation``.
        interval: Forwarded to ``FuncAnimation`` as the frame interval.
        title: Optional title drawn above the frames.

    Returns:
        The ``FuncAnimation`` instance.
    """
    figure = plt.figure()
    image_artist = plt.imshow(frames[0])
    plt.axis('off')
    if title:
        plt.title(title)
    n_frames = len(frames)
    animation = anim.FuncAnimation(
        figure,
        update_scene,
        fargs=(frames, image_artist),
        frames=n_frames,
        repeat=repeat,
        interval=interval,
    )
    if save_path is not None:
        animation.save(save_path, writer="pillow", fps=20)
    return animation
def show_one_episode(action_sampler: tp.Callable, save_path: tp.Optional[str] = None, repeat=False, title: tp.Optional[str] = None):
    """Roll out one rendered episode of ``ENV_NAME`` and return it as an animation.

    Args:
        action_sampler: Called with the raw observation returned by the env; must
            return an action accepted by ``env.step``. Invoked under ``torch.no_grad()``.
        save_path: Forwarded to ``plot_animation`` (saves the animation when given).
        repeat: Forwarded to ``plot_animation``.
        title: Forwarded to ``plot_animation``.

    Returns:
        Whatever ``plot_animation`` returns for the collected frames.

    Prints the episode's summed reward and whether it ended by termination
    ("done") or truncation, plus the number of steps taken.
    """
    frames = []
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    # try/finally: the env (and its renderer) must be released even if the
    # sampler or env.step raises part-way through the episode.
    try:
        obs, info = env.reset()
        sum_rewards = 0.0  # rewards are floats; start the accumulator as one
        with torch.no_grad():
            for step in count(0):
                frames.append(env.render())
                action = action_sampler(obs)
                obs, reward, done, truncated, info = env.step(action)
                sum_rewards += reward
                if done or truncated:
                    print("Sum of Rewards:", sum_rewards)
                    print(f"{'done' if done else 'truncated'} at step", step+1)
                    break
    finally:
        env.close()
    return plot_animation(frames, repeat=repeat, save_path=save_path, title=title)
@dataclass
class xonfig:
    """Hyper-parameters for the SAC training script.

    The training code reads these as class attributes (it never instantiates the
    class), and overwrites ``alpha`` in place when ``adaptive_alpha`` is enabled.
    """

    # --- schedule ---
    num_episodes: int = 7000
    gamma: float = 0.99  # discount factor in the soft Bellman target
    device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    num_updates: int = 1  # gradient updates per update step; the author saw worse results with 10
    update_every_n_steps: int = 1  # run updates every n environment steps

    # --- entropy temperature ---
    adaptive_alpha: bool = True  # learn alpha towards a target entropy
    alpha: float = 0.2  # starting temperature; replaced during training if adaptive_alpha
    tau: float = 0.005  # target networks use EMA decay = 1 - tau

    # --- replay buffer / optimisation ---
    buffer_size: int = 500000
    batch_size: int = 64
    dqn_lr: float = 0.0005
    actor_lr: float = 0.0005
    alpha_lr: float = 0.0005
    weight_decay: float = 0.0
    hidden_dim: int = 64
class ActorPolicy(nn.Module):
    """Gaussian policy network: state -> (mean, std) of an unsquashed action distribution.

    Two ReLU hidden layers feed two linear heads, one producing the mean and one
    producing the log standard deviation (clamped, then exponentiated).
    """

    def __init__(
        self,
        state_dims: int,
        hidden_dim: int,
        action_dims: int,
    ):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which random weight init happens.
        self.l1 = nn.Linear(state_dims, hidden_dim)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.mu_mean = nn.Linear(hidden_dim, action_dims)
        self.sigma_log_std = nn.Linear(hidden_dim, action_dims)

    def forward(self, state: Tensor):
        """Return ``(mu, std)``, each with ``action_dims`` in the last dimension."""
        features = state
        for linear, activation in ((self.l1, self.relu1), (self.l2, self.relu2)):
            features = activation(linear(features))
        mean = self.mu_mean(features)
        # Clamp log-std to [-20, 2] before exponentiating (std in ~[2.06e-09, 7.389]):
        # - log_std < -20 makes std nearly zero -> extremely peaked distribution,
        #   risking exploding gradients / division by near-zero values;
        # - log_std > 2 makes std very large -> high-variance policy, excessive
        #   exploration and poor convergence.
        # This range is a common choice for continuous-control RL.
        log_std = self.sigma_log_std(features).clamp(-20, 2)
        return mean, log_std.exp()
class CriticActionValue(nn.Module):
    """Action-value network Q(s, a).

    The state and action are concatenated and passed through a two-hidden-layer
    ReLU MLP that ends in a single output unit.
    """

    def __init__(self, state_dims: int, hidden_dim: int, action_dims: int):
        super().__init__()
        # Attribute names and creation order fix the state_dict keys and the
        # order of random weight initialisation; keep them unchanged.
        self.l1 = nn.Linear(state_dims + action_dims, hidden_dim)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        state: Tensor,  # (B, state_dims)
        action: Tensor,  # (B, action_dims)
    ):
        """Return Q-values of shape ``(B, 1)``."""
        hidden = torch.cat((state, action), dim=-1)  # (B, state_dims + action_dims)
        for linear, activation in ((self.l1, self.relu1), (self.l2, self.relu2)):
            hidden = activation(linear(hidden))
        return self.l3(hidden)  # (B, 1)
@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float):
    """Blend ``model``'s parameters into ``ema_model`` in place (exponential moving average).

    For each named parameter of ``model``: ``ema = decay * ema + (1 - decay) * param``.
    Parameters are matched by name, so a name of ``model`` that is missing from
    ``ema_model`` raises ``KeyError``; extra parameters of ``ema_model`` are left alone.
    """
    target_by_name = dict(ema_model.named_parameters())
    source_weight = 1 - decay
    for name, source in model.named_parameters():
        target = target_by_name[name]
        target.mul_(decay)
        target.add_(source.data, alpha=source_weight)
def sample_actions(state: Tensor, actions_max_bound: float, policy: tp.Optional[nn.Module] = None):  # (B, state_dims)
    """Sample tanh-squashed Gaussian actions together with their log-probabilities.

    Args:
        state: Batch of states, ``(B, state_dims)``.
        actions_max_bound: Scale ``c`` applied after ``tanh``; actions lie in ``[-c, c]``.
        policy: Module mapping ``state`` to ``(mu, std)``. Defaults to the module-level
            ``policy_net`` created in the ``__main__`` section (what all existing
            call sites use).

    Returns:
        ``(action, log_prob)`` with ``action`` of shape ``(B, action_dims)`` and
        ``log_prob`` of shape ``(B, 1)``. Outside ``torch.no_grad()`` both are
        differentiable w.r.t. the policy parameters (reparameterised sample).
    """
    net = policy_net if policy is None else policy
    mu, std = net(state)  # (B, action_dims), (B, action_dims)
    dist = torch.distributions.Normal(mu, std)
    # rsample() draws u = mu + std * eps (reparameterisation trick), so gradients flow
    # back into mu/std; dist.sample() would return a value detached from the graph.
    unbound_action = dist.rsample()  # (B, action_dims)
    squashed = torch.tanh(unbound_action)  # [-1, 1]
    action = squashed * actions_max_bound  # [-1, 1] * max => [-max, max]

    # Tanh correction (change of variables). With a = c * tanh(u):
    #   log pi(a|s) = log N(u; mu, std) - sum_i log |da_i/du_i|,   da/du = c * (1 - tanh(u)^2)
    # The Jacobian must be built from tanh(u), not from the scaled action: the two only
    # coincide when c == 1 (otherwise 1 - a^2 can go negative and the log becomes NaN).
    # The 1e-6 keeps the log finite when tanh saturates at +-1.
    log_prob = dist.log_prob(unbound_action) - torch.log((1 - squashed.pow(2)) * actions_max_bound + 1e-6)  # (B, action_dims)

    # Each action dimension is sampled independently, so the joint density is the product
    # of the per-dimension densities: p(a) = p_1(a_1) * ... * p_d(a_d).
    # In log-space that product becomes a sum: log p(a) = log p_1(a_1) + ... + log p_d(a_d),
    # hence the sum over the action dimension to get one log-prob per sample.
    log_prob = log_prob.sum(dim=-1, keepdim=True)  # (B, 1)
    return action, log_prob
@torch.compile()
def sac_train_step(
    states: Tensor,
    actions: Tensor,
    next_states: Tensor,
    rewards: Tensor,
    is_terminal: Tensor,
):
    """Run one Soft Actor-Critic update: both critics, then the policy, then (optionally) alpha.

    * `states`: `(B, state_dim)`
    * `actions`: `(B, action_dim)`
    * `next_states`: `(B, state_dim)`
    * `rewards`: `(B,)`
    * `is_terminal`: `(B,)` -- 1.0 where the episode terminated (no bootstrapping), else 0.0

    Operates on module-level objects created in the ``__main__`` section:
    ``dqn1``/``dqn2`` and their optimizers, ``dqn_target1``/``dqn_target2``,
    ``policy_net`` (through ``sample_actions``) and ``policy_optimizer``,
    ``ACTION_BOUNDS``, and -- when ``xonfig.adaptive_alpha`` is set -- ``log_alpha``,
    ``target_entropy`` and ``alpha_optimizer``, writing the new temperature back
    to ``xonfig.alpha``. The target networks are not updated here.
    """
    rewards, is_terminal = rewards.unsqueeze(-1), is_terminal.unsqueeze(-1)  # (B,) => (B, 1)

    # ---- Critics ----
    ## a_next ~ pi(s_next)
    ## y = r + gamma * ( min(Q_t1, Q_t2)(s_next, a_next) - alpha * log pi(a_next|s_next) ) * (1 - is_terminal)
    ## L1 = MSE(Q1(s, a), y), L2 = MSE(Q2(s, a), y)
    with torch.no_grad():
        actions_next, log_prob_next = sample_actions(next_states, ACTION_BOUNDS)
        q_next1, q_next2 = dqn_target1(next_states, actions_next), dqn_target2(next_states, actions_next)  # (B, 1), (B, 1)
        # Taking the min of two Q estimates counters overestimation bias (clipped double-Q,
        # TD3 arXiv:1802.09477, also used by SAC arXiv:1812.05905).
        q_next: Tensor = torch.min(q_next1, q_next2) - xonfig.alpha * log_prob_next  # (B, 1)
        q_next_target: Tensor = rewards + xonfig.gamma * q_next * (1 - is_terminal)  # (B, 1)
    dqn1_loss = nn.functional.mse_loss(dqn1(states, actions), q_next_target, reduction="mean")
    dqn2_loss = nn.functional.mse_loss(dqn2(states, actions), q_next_target, reduction="mean")
    (dqn1_loss + dqn2_loss).backward()  # the two losses touch disjoint parameters
    dqn1_optimizer.step(); dqn2_optimizer.step()
    dqn1_optimizer.zero_grad(); dqn2_optimizer.zero_grad()

    # ---- Policy ----
    # Freeze the critics so the policy loss only moves the policy weights (through the
    # reparameterised actions), not the Q networks.
    dqn1.requires_grad_(False); dqn2.requires_grad_(False)
    policy_actions, log_probs = sample_actions(states, ACTION_BOUNDS)
    ## minimise alpha * log pi - min Q  <=>  maximise soft value (Q + entropy bonus)
    pi_loss: Tensor = (xonfig.alpha * log_probs - torch.min(dqn1(states, policy_actions), dqn2(states, policy_actions))).mean()
    pi_loss.backward()
    policy_optimizer.step()
    policy_optimizer.zero_grad()
    dqn1.requires_grad_(True); dqn2.requires_grad_(True)

    # ---- Temperature ----
    if xonfig.adaptive_alpha:
        # Use the log-probs of the actions just sampled for `states` (the batch the policy
        # was optimised on), detached so that only log_alpha receives a gradient.
        alpha_loss = -log_alpha * (log_probs.detach() + target_entropy).mean()
        alpha_loss.backward()
        alpha_optimizer.step()
        alpha_optimizer.zero_grad()
        xonfig.alpha = log_alpha.exp().item()
if __name__ == "__main__":
    import os  # only the script entry point touches the filesystem

    # ---- Reproducibility ----
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED+1)
    torch.manual_seed(SEED+2)
    torch.use_deterministic_algorithms(mode=True, warn_only=True)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # ---- Environment ----
    env = gym.make(ENV_NAME)
    STATE_DIMS = env.observation_space.shape[0]  # read from the env, not hard-coded
    ACTION_DIMS = env.action_space.shape[0]
    # The action box is treated as symmetric: the upper bound of the first dimension
    # is used as the tanh scale for every action dimension.
    ACTION_BOUNDS = env.action_space.high[0]

    # ---- Entropy temperature ----
    if xonfig.adaptive_alpha:
        # alpha is learned through log(alpha) so it stays positive. Created as float32 on
        # the training device so it matches the log-probs it is combined with
        # (np.log returns a float64, which would otherwise give a float64 CPU parameter).
        log_alpha = torch.nn.Parameter(
            torch.tensor(np.log(xonfig.alpha), dtype=torch.float32, device=xonfig.device),
            requires_grad=True,
        )
        # Target entropy of -|A| (one nat per action dimension), as in the SAC paper.
        target_entropy = torch.tensor(-ACTION_DIMS, device=xonfig.device, dtype=torch.float32)
        alpha_optimizer = torch.optim.AdamW([log_alpha], lr=xonfig.alpha_lr, weight_decay=0.0)

    # ---- Networks ----
    # The twin critics MUST be initialised independently: with identical initial weights
    # they see the same batches, targets and gradients, so they stay identical forever
    # and min(Q1, Q2) degenerates to a single Q estimate.
    dqn1 = CriticActionValue(STATE_DIMS, xonfig.hidden_dim, ACTION_DIMS).to(xonfig.device); dqn1.compile()
    dqn2 = CriticActionValue(STATE_DIMS, xonfig.hidden_dim, ACTION_DIMS).to(xonfig.device); dqn2.compile()
    dqn_target1 = deepcopy(dqn1).requires_grad_(False)
    dqn_target2 = deepcopy(dqn2).requires_grad_(False)
    dqn1_optimizer = torch.optim.AdamW(dqn1.parameters(), lr=xonfig.dqn_lr, weight_decay=xonfig.weight_decay)
    dqn2_optimizer = torch.optim.AdamW(dqn2.parameters(), lr=xonfig.dqn_lr, weight_decay=xonfig.weight_decay)

    policy_net = ActorPolicy(STATE_DIMS, xonfig.hidden_dim, ACTION_DIMS).to(xonfig.device); policy_net.compile()
    policy_optimizer = torch.optim.AdamW(policy_net.parameters(), lr=xonfig.actor_lr, weight_decay=xonfig.weight_decay)

    # ---- Training loop ----
    # Transitions are stored on CPU as (next_state, action, reward, state, done).
    replay_buffer = deque(maxlen=xonfig.buffer_size)
    sum_rewards_list = []
    num_timesteps_list = []
    num_steps_over = 1  # global environment-step counter driving update_every_n_steps
    try:
        for episode in range(1, xonfig.num_episodes+1):
            state, info = env.reset()
            state = torch.as_tensor(state, device=xonfig.device, dtype=torch.float32)
            sum_rewards = 0.0
            for tstep in count(1):
                # sample action from the current (stochastic) policy
                with torch.no_grad():
                    action, _log_prob = sample_actions(state.unsqueeze(0), ACTION_BOUNDS)  # (1, action_dims)
                action = action.squeeze(0)  # (action_dims,)

                next_state, reward, done, truncated, info = env.step(action.cpu().numpy())
                next_state = torch.as_tensor(next_state, dtype=torch.float32, device=xonfig.device)
                sum_rewards += reward

                # Only true termination (`done`) is stored as terminal; truncated episodes
                # still bootstrap from the next state.
                replay_buffer.append((
                    next_state.cpu(), action.cpu(), torch.as_tensor(reward, dtype=torch.float32),
                    state.cpu(), torch.as_tensor(done),
                ))

                # optimize networks once the buffer holds a few batches' worth of data
                if num_steps_over % xonfig.update_every_n_steps == 0 and len(replay_buffer) >= xonfig.batch_size*5:
                    for _ in range(xonfig.num_updates):
                        batched_samples = random.sample(replay_buffer, xonfig.batch_size)
                        next_states, actions, rewards, states, dones = (
                            torch.stack(column).to(device=xonfig.device, dtype=torch.float32)
                            for column in zip(*batched_samples)
                        )  # (B, state_dim), (B, action_dim), (B,), (B, state_dim), (B,)
                        sac_train_step(states, actions, next_states, rewards, dones)
                        update_ema(dqn_target1, dqn1, decay=1 - xonfig.tau)
                        update_ema(dqn_target2, dqn2, decay=1 - xonfig.tau)

                # count every environment step, including the last one of the episode
                num_steps_over += 1
                if done or truncated:
                    break
                state = next_state
            sum_rewards_list.append(sum_rewards)
            num_timesteps_list.append(tstep)
            print(f"|| Episode: {episode} || Sum of Rewards: {sum_rewards:.4f} || Timesteps: {tstep} ||")
    except KeyboardInterrupt:
        print("Training Interrupted")
    finally:
        env.close()

    # ---- Evaluation & plots ----
    adaptive_str = 'adaptive_alpha' if xonfig.adaptive_alpha else ''

    def get_deterministic_actions(state: Tensor, action_bounds=ACTION_BOUNDS):
        """Greedy action: tanh of the policy mean scaled by ``action_bounds`` (no sampling)."""
        state = state.unsqueeze(0)  # (1, state_dims)
        mu, _ = policy_net(state)
        action = torch.tanh(mu).mul(action_bounds)
        return action.squeeze(0).cpu().numpy()

    os.makedirs("images", exist_ok=True)  # both output files below are written into images/
    show_one_episode(
        lambda x: get_deterministic_actions(torch.as_tensor(x, dtype=torch.float32, device=xonfig.device), ACTION_BOUNDS),
        save_path=f"images/sac_{ENV_NAME.lower()}_.gif",
        title=f"SAC Trained Agent {f'{adaptive_str} ' if adaptive_str else ''}"
    ); plt.close()

    # Nothing to plot if training was interrupted before the first episode finished.
    if sum_rewards_list:
        plt.plot(sum_rewards_list, label="Sum of Rewards")
        plt.plot(num_timesteps_list, label="Timesteps")
        mx = max(sum_rewards_list)
        # A tick step of mx // 10 is zero for maxima below 10, which np.arange rejects.
        if mx >= 10:
            plt.yticks(np.arange(0, mx, mx//10).tolist())
        plt.xlabel("Episodes")
        plt.grid(True)
        plt.legend()
        plt.ylabel("Sum of Rewards")
        plt.title("Sum of Rewards per Episode")
        plt.savefig(f"images/sac_rewards_{adaptive_str}_{ENV_NAME}.png")
        plt.show()
        plt.close()