Commit 9ec4e9a

PR feedback

Parent: 5e6d715

3 files changed (+8, -23 lines)

examples/run_rl.py

Lines changed: 4 additions & 9 deletions
@@ -16,7 +16,6 @@
     plot_curve,
 )
 
-
 def train(args):
 
     # Check whether gpu is available
@@ -37,9 +36,7 @@ def train(args):
     if args.algorithm == 'dqn':
         from rlcard.agents import DQNAgent
         if args.load_checkpoint_path != "":
-            dict = torch.load(args.load_checkpoint_path)
-            agent = DQNAgent.from_checkpoint(checkpoint = dict)
-            del dict
+            agent = DQNAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
         else:
             agent = DQNAgent(
                 num_actions=env.num_actions,
@@ -53,9 +50,7 @@ def train(args):
     elif args.algorithm == 'nfsp':
         from rlcard.agents import NFSPAgent
         if args.load_checkpoint_path != "":
-            dict = torch.load(args.load_checkpoint_path)
-            agent = NFSPAgent.from_checkpoint(checkpoint = dict)
-            del dict
+            agent = NFSPAgent.from_checkpoint(checkpoint=torch.load(args.load_checkpoint_path))
         else:
             agent = NFSPAgent(
                 num_actions=env.num_actions,
@@ -64,7 +59,7 @@ def train(args):
                 q_mlp_layers=[64,64],
                 device=device,
                 save_path=args.log_dir,
-                save_every=500
+                save_every=args.save_every
             )
     agents = [agent]
     for _ in range(1, env.num_players):
@@ -111,7 +106,7 @@ def train(args):
     torch.save(agent, save_path)
     print('Model saved in', save_path)
 
-if __name__ == '__main__':
+if __name__ == '__main__':
     parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
     parser.add_argument(
         '--env',
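The checkpoint-loading change above collapses three lines (load into a temporary, construct the agent, delete the temporary) into a single call, and save_every now comes from the command line instead of a hard-coded 500. A minimal usage sketch of the new load path; the checkpoint path is an illustrative assumption, and a --save_every parser argument is assumed to be defined elsewhere, since this diff only shows args.save_every being read:

# Sketch only: resume DQN training from a checkpoint, mirroring the
# one-liner introduced by this commit. The path below is an assumed
# example, not taken from the diff.
import torch
from rlcard.agents import DQNAgent

checkpoint_path = 'experiments/leduc_holdem_dqn_result/checkpoint_dqn.pt'  # assumed path
agent = DQNAgent.from_checkpoint(checkpoint=torch.load(checkpoint_path))

From the shell, the updated script would then be invoked along the lines of: python examples/run_rl.py --env leduc-holdem --algorithm dqn --load_checkpoint_path <path> --save_every 1000 (flag spelling assumed from args.save_every).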

rlcard/agents/dqn_agent.py

Lines changed: 3 additions & 13 deletions
@@ -109,7 +109,7 @@ def __init__(self,
 
         # The epsilon decay scheduler
         self.epsilons = np.linspace(epsilon_start, epsilon_end, epsilon_decay_steps)
-
+
         # Create estimators
         self.q_estimator = Estimator(num_actions=num_actions, learning_rate=learning_rate, state_shape=state_shape, \
             mlp_layers=mlp_layers, device=self.device)
@@ -226,7 +226,7 @@ def train(self):
         if self.train_t % self.update_target_estimator_every == 0:
             self.target_estimator = deepcopy(self.q_estimator)
             print("\nINFO - Copied model parameters to target network.")
-
+
         self.train_t += 1
 
         if self.save_path and self.train_t % self.save_every == 0:
@@ -312,17 +312,7 @@ def from_checkpoint(cls, checkpoint):
 
 
         return agent_instance
-
-
-
-    def save(self, path):
-        ''' Save the model (q_estimator weights only)
-
-        Args:
-            path (str): the path to save the model
-        '''
-        torch.save(self.q_estimator.model.state_dict(), path)
-
+
     def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
         ''' Save the model checkpoint (all attributes)

rlcard/agents/nfsp_agent.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def __init__(self,
 
         # Total timesteps
         self.total_t = 0
-
+
         # Total training step
         self.train_t = 0
 
