Skip to content

Commit a38218d

Browse files
committed
Change x-axis in figure to episode
1 parent fc08187 commit a38218d

File tree

6 files changed

+12
-12
lines changed

6 files changed

+12
-12
lines changed

docs/toy-examples.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ def train(args):
175175
# Evaluate the performance. Play with random agents.
176176
if episode % args.evaluate_every == 0:
177177
logger.log_performance(
178-
env.timestep,
178+
episode,
179179
tournament(
180180
env,
181181
args.num_eval_games,
@@ -356,7 +356,7 @@ def train(args):
356356
if episode % args.evaluate_every == 0:
357357
agent.save() # Save model
358358
logger.log_performance(
359-
env.timestep,
359+
episode,
360360
tournament(
361361
eval_env,
362362
args.num_eval_games

examples/pettingzoo/run_rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def train(args):
9494
# Evaluate the performance. Play with random agents.
9595
if episode % args.evaluate_every == 0:
9696
average_rewards = tournament_pettingzoo(env, agents, args.num_eval_games)
97-
logger.log_performance(num_timesteps, average_rewards[learning_agent_name])
97+
logger.log_performance(episode, average_rewards[learning_agent_name])
9898

9999
# Get the paths
100100
csv_path, fig_path = logger.csv_path, logger.fig_path

examples/run_cfr.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def train(args):
5959
if episode % args.evaluate_every == 0:
6060
agent.save() # Save model
6161
logger.log_performance(
62-
env.timestep,
62+
episode,
6363
tournament(
6464
eval_env,
6565
args.num_eval_games

examples/run_rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def train(args):
7777
# Evaluate the performance. Play with random agents.
7878
if episode % args.evaluate_every == 0:
7979
logger.log_performance(
80-
env.timestep,
80+
episode,
8181
tournament(
8282
env,
8383
args.num_eval_games,

rlcard/utils/logger.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def __enter__(self):
2323

2424
self.txt_file = open(self.txt_path, 'w')
2525
self.csv_file = open(self.csv_path, 'w')
26-
fieldnames = ['timestep', 'reward']
26+
fieldnames = ['episode', 'reward']
2727
self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
2828
self.writer.writeheader()
2929

@@ -38,16 +38,16 @@ def log(self, text):
3838
self.txt_file.flush()
3939
print(text)
4040

41-
def log_performance(self, timestep, reward):
41+
def log_performance(self, episode, reward):
4242
''' Log a point in the curve
4343
Args:
44-
timestep (int): the timestep of the current point
44+
episode (int): the episode of the current point
4545
reward (float): the reward of the current point
4646
'''
47-
self.writer.writerow({'timestep': timestep, 'reward': reward})
47+
self.writer.writerow({'episode': episode, 'reward': reward})
4848
print('')
4949
self.log('----------------------------------------')
50-
self.log(' timestep | ' + str(timestep))
50+
self.log(' episode | ' + str(episode))
5151
self.log(' reward | ' + str(reward))
5252
self.log('----------------------------------------')
5353

rlcard/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -232,11 +232,11 @@ def plot_curve(csv_path, save_path, algorithm):
232232
xs = []
233233
ys = []
234234
for row in reader:
235-
xs.append(int(row['timestep']))
235+
xs.append(int(row['episode']))
236236
ys.append(float(row['reward']))
237237
fig, ax = plt.subplots()
238238
ax.plot(xs, ys, label=algorithm)
239-
ax.set(xlabel='timestep', ylabel='reward')
239+
ax.set(xlabel='episode', ylabel='reward')
240240
ax.legend()
241241
ax.grid()
242242

0 commit comments

Comments
 (0)