Skip to content

Commit

Permalink
Add save/load for CFR
Browse files Browse the repository at this point in the history
  • Loading branch information
daochenzha committed Nov 11, 2019
1 parent c7ea916 commit 65a429d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ We provide some running examples. We will update the examples if we achieve bett
* `uno_nfsp.py`: train NFSP on UNO.
* `uno_random.py`: run random agents on UNO.
* `uno_single.py`: train DQN on UNO as single-agent environment.
* **Save models**: refer to `leduc_holdem_nfsp_save_model.py`.
* **Save models**: refer to `leduc_holdem_nfsp_save_model.py` and `leduc_holdem_cfr.py` for CFR.
* **Load models**: refer to [rlcard/models/pretrained\_models.py](../rlcard/models/pretrained_models.py)
3 changes: 3 additions & 0 deletions examples/leduc_holdem_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

# Initialize CFR Agent
agent = CFRAgent(env)
agent.load() # If we have saved model, we first load the model

# Evaluate CFR against pre-trained NFSP
eval_env.set_agents([agent, models.load('leduc-holdem-nfsp').agents[0]])
Expand All @@ -41,6 +42,8 @@
print('\rIteration {}'.format(episode), end='')
# Evaluate the performance. Play with NFSP agents.
if episode % evaluate_every == 0:
agent.save() # Save model

reward = 0
for eval_episode in range(evaluate_num):
_, payoffs = eval_env.run(is_training=False)
Expand Down
53 changes: 51 additions & 2 deletions rlcard/agents/cfr_agent.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import numpy as np
import collections

import os
import pickle

from rlcard.utils.utils import *

class CFRAgent():
''' Implement CFR algorithm
'''

def __init__(self, env):
def __init__(self, env, model_path='./cfr_model'):
''' Initilize Agent
Args:
env (Env): Env class
'''
self.env = env

self.model_path = model_path

# A policy is a dict state_str -> action probabilities
self.policy = collections.defaultdict(list)
self.average_policy =collections.defaultdict(np.array)
Expand Down Expand Up @@ -165,3 +169,48 @@ def get_state(self, player_id):
'''
state = self.env.get_state(player_id)
return state['obs'].tostring(), state['legal_actions']

def save(self):
    ''' Save model to ``self.model_path``.

    Creates the directory if it does not exist, then pickles the four
    pieces of CFR state (policy, average policy, regrets, iteration
    counter) into separate files so they can be restored by ``load``.
    '''
    # exist_ok avoids a race between the existence check and makedirs.
    os.makedirs(self.model_path, exist_ok=True)

    # Context managers guarantee the file handles are closed even if
    # pickling raises (the original open()/close() pairs leaked on error).
    with open(os.path.join(self.model_path, 'policy.pkl'), 'wb') as policy_file:
        pickle.dump(self.policy, policy_file)

    with open(os.path.join(self.model_path, 'average_policy.pkl'), 'wb') as average_policy_file:
        pickle.dump(self.average_policy, average_policy_file)

    with open(os.path.join(self.model_path, 'regrets.pkl'), 'wb') as regrets_file:
        pickle.dump(self.regrets, regrets_file)

    with open(os.path.join(self.model_path, 'iteration.pkl'), 'wb') as iteration_file:
        pickle.dump(self.iteration, iteration_file)

def load(self):
    ''' Load model state from ``self.model_path``.

    Restores the policy, average policy, regrets, and iteration counter
    written by ``save``. If the model directory does not exist, returns
    silently so training can start from scratch on a first run.

    NOTE(review): ``pickle.load`` can execute arbitrary code from the
    files; only load model directories from trusted sources.
    '''
    if not os.path.exists(self.model_path):
        return

    # Context managers guarantee the file handles are closed even if
    # unpickling raises (the original open()/close() pairs leaked on error).
    with open(os.path.join(self.model_path, 'policy.pkl'), 'rb') as policy_file:
        self.policy = pickle.load(policy_file)

    with open(os.path.join(self.model_path, 'average_policy.pkl'), 'rb') as average_policy_file:
        self.average_policy = pickle.load(average_policy_file)

    with open(os.path.join(self.model_path, 'regrets.pkl'), 'rb') as regrets_file:
        self.regrets = pickle.load(regrets_file)

    with open(os.path.join(self.model_path, 'iteration.pkl'), 'rb') as iteration_file:
        self.iteration = pickle.load(iteration_file)

17 changes: 17 additions & 0 deletions tests/agents/test_cfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,20 @@ def test_train(self):
action = agent.eval_step(state)

self.assertIn(action, [0, 2])

def test_save_and_load(self):
    ''' Train briefly, save the CFR agent, and verify a fresh agent
    restores the same policy, average policy, regrets, and iteration.
    '''
    # Local import keeps the dependency scoped to this test.
    import tempfile

    env = rlcard.make('leduc-holdem', allow_step_back=True)

    # Use a temporary directory instead of the default './cfr_model'
    # so the test does not litter the working directory or pick up
    # stale state from a previous run.
    with tempfile.TemporaryDirectory() as model_path:
        agent = CFRAgent(env, model_path=model_path)

        for _ in range(100):
            agent.train()

        agent.save()

        new_agent = CFRAgent(env, model_path=model_path)
        new_agent.load()
        self.assertEqual(len(agent.policy), len(new_agent.policy))
        self.assertEqual(len(agent.average_policy), len(new_agent.average_policy))
        self.assertEqual(len(agent.regrets), len(new_agent.regrets))
        self.assertEqual(agent.iteration, new_agent.iteration)

0 comments on commit 65a429d

Please sign in to comment.