Skip to content

Commit 58794c7

Browse files
committed
Update hyper params and set seeds
1 parent d627981 commit 58794c7

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@
9191
"cpu"
9292
)
9393

94+
# set the seeds for reproducibility
95+
seed = 42
96+
random.seed(seed)
97+
torch.manual_seed(seed)
98+
env.reset(seed=seed)
99+
env.action_space.seed(seed)
100+
env.observation_space.seed(seed)
101+
if torch.cuda.is_available():
102+
torch.cuda.manual_seed(seed)
103+
94104

95105
######################################################################
96106
# Replay Memory
@@ -253,13 +263,14 @@ def forward(self, x):
253263
# EPS_DECAY controls the rate of exponential decay of epsilon; a higher value means a slower decay
254264
# TAU is the update rate of the target network
255265
# LR is the learning rate of the ``AdamW`` optimizer
266+
256267
BATCH_SIZE = 128
257268
GAMMA = 0.99
258-
EPS_START = 0.9
259-
EPS_END = 0.05
260-
EPS_DECAY = 1000
269+
EPS_START = 1
270+
EPS_END = 0.01
271+
EPS_DECAY = 2500
261272
TAU = 0.005
262-
LR = 1e-4
273+
LR = 5e-4
263274

264275
# Get number of actions from gym action space
265276
n_actions = env.action_space.n

0 commit comments

Comments
 (0)