-
Notifications
You must be signed in to change notification settings - Fork 3
/
MCCFRNode_trim.java
151 lines (140 loc) · 5.36 KB
/
MCCFRNode_trim.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import java.util.Arrays;
/**
 * Per-information-set state for a regret-based solver (MCCFR, judging by the class name):
 * cumulative regrets, cumulative strategy weights, and per-player utility statistics that
 * decide whether the node may be "trimmed".
 *
 * <p>Regret and strategy sums are kept in {@link #ITERATION_MOD} rotating slots indexed by
 * {@code iteration % ITERATION_MOD}. Updates made during iteration {@code t} are written to the
 * slots of iterations {@code t+1} and {@code t+2}; the slot of iteration {@code t} itself is not
 * written, so {@link #getStrategy(int)} for iteration {@code t} keeps reading the same values
 * while that iteration is in progress.
 *
 * <p>Iteration numbers passed to this class must be non-negative (they are reduced with
 * {@code %} and used as array indices). The class performs no synchronization.
 */
public class MCCFRNode_trim {
    /** Number of utility samples per stability window. */
    public static final int UTILITY_HISTORY_LENGTH = 500;
    /** A window whose variance estimate is below this value marks the node as trimmable. */
    public static final double CUTOFF_THRESHOLD = 0.0001;
    /** Number of players for which utility statistics are tracked. */
    public static final int NUM_PLAYERS = 2;

    /**
     * Probability that a node flagged as stable is actually reported as trimmable.
     * With the value 1 the check in {@link #can_trim(int)} always succeeds, because
     * {@code Math.random()} is strictly less than 1.
     */
    private static final double TRIM_EPSILON = 1;

    /**
     * Number of rotating slots. Updates write to the slots of the next two iterations, so
     * three slots are needed for the current iteration's slot to stay untouched.
     */
    private static final int ITERATION_MOD = 3;

    private final double[][] regretSum;    // [slot][action] cumulative regret
    private final double[] strategy;       // [action] scratch buffer returned by getStrategy
    private final double[][] strategySum;  // [slot][action] cumulative strategy weight
    private final boolean[] is_valid;      // [action] whether the action is legal in this information set
    private final int total_game_actions;  // action count of the largest information set; sizes the arrays
    private final int num_valid_actions;   // number of legal actions in this information set
    private int current_iteration_mod_pointer = 0; // slot of the iteration seen by the latest update call

    private final double[] total_utility;        // [player] sum of all utilities ever recorded
    private final int[] utility_history_counter; // [player] number of utilities ever recorded
    private final boolean[] trim;                // [player] stable-node flag
    private final double[] mean_square_est;      // [player] sum of squared running means in the current window
    private final double[] mean_est;             // [player] sum of running means in the current window

    /** Prints the average strategy to standard output. */
    public void Print() {
        System.out.println(Arrays.toString(getAverageStrategy()));
    }

    /**
     * Creates the node state for the given decision node.
     *
     * @param h source of the action counts and of the per-action validity flags
     */
    MCCFRNode_trim(DecisionNode h) {
        total_game_actions = h.total_game_actions();
        num_valid_actions = h.num_valid_actions();
        regretSum = new double[ITERATION_MOD][total_game_actions];
        strategy = new double[total_game_actions];
        strategySum = new double[ITERATION_MOD][total_game_actions];
        is_valid = new boolean[total_game_actions];
        for (int a = 0; a < total_game_actions; a++) {
            is_valid[a] = h.action_valid(a);
        }
        // Freshly allocated arrays are already zero / false, so no explicit reset is needed.
        utility_history_counter = new int[NUM_PLAYERS];
        trim = new boolean[NUM_PLAYERS];
        total_utility = new double[NUM_PLAYERS];
        mean_square_est = new double[NUM_PLAYERS];
        mean_est = new double[NUM_PLAYERS];
    }

    /**
     * Adds {@code regret} to the cumulative regret of an action.
     *
     * @param action_index      index of the action
     * @param regret            regret to accumulate
     * @param current_iteration iteration in progress; its own slot is left untouched
     */
    public void updateRegretSum(int action_index, double regret, int current_iteration) {
        current_iteration_mod_pointer = current_iteration % ITERATION_MOD;
        int nextSlot = (current_iteration + 1) % ITERATION_MOD;
        int slotAfterNext = (current_iteration + 2) % ITERATION_MOD;
        regretSum[nextSlot][action_index] += regret;
        // Mirror the running total so the slot after next starts from the same value.
        regretSum[slotAfterNext][action_index] = regretSum[nextSlot][action_index];
    }

    /**
     * Records one utility sample for {@code player} and, at the end of every window of
     * {@link #UTILITY_HISTORY_LENGTH} samples, flags the node as trimmable when the variance of
     * the running mean over that window is below {@link #CUTOFF_THRESHOLD}.
     *
     * <p>Each window starts from zeroed accumulators. The accumulators of a window that was
     * judged stable are kept until the next sample arrives, so {@link #get_mean_est(int)} still
     * reports that window's value in the meantime.
     *
     * @param utility utility observed at this node
     * @param player  player index, {@code 0 <= player < NUM_PLAYERS}
     */
    public void updateUtility(double utility, int player) {
        // First sample of a new window: drop whatever the previous window accumulated.
        // Without this, sums carried over from a stable window would span several windows
        // while still being divided by a single window length in get_var().
        if (utility_history_counter[player] % UTILITY_HISTORY_LENGTH == 0) {
            resetWindow(player);
        }

        utility_history_counter[player]++;
        total_utility[player] += utility;

        double runningMean = get_mean(player);
        mean_square_est[player] += runningMean * runningMean;
        mean_est[player] += runningMean;

        if (utility_history_counter[player] % UTILITY_HISTORY_LENGTH != 0) {
            return; // window not complete yet
        }
        if (get_var(player) < CUTOFF_THRESHOLD) {
            trim[player] = true;
        } else {
            resetWindow(player);
        }
    }

    /** Clears the per-window accumulators of {@code player}. */
    private void resetWindow(int player) {
        mean_square_est[player] = 0;
        mean_est[player] = 0;
    }

    /**
     * Reports whether the node may be trimmed for {@code player}. If the node is flagged as
     * stable, it is reported as trimmable with probability {@code TRIM_EPSILON}; otherwise the
     * flag is cleared.
     *
     * @param player player index
     * @return {@code true} if the node is flagged as stable and the random draw succeeds
     */
    public boolean can_trim(int player) {
        if (!trim[player]) {
            return false;
        }
        if (Math.random() < TRIM_EPSILON) {
            return true;
        }
        trim[player] = false;
        return false;
    }

    /**
     * Returns the average of all utilities recorded for {@code player}.
     *
     * @param player player index
     * @return the mean utility; {@code NaN} if nothing has been recorded yet (0.0 / 0)
     */
    public double get_mean(int player) {
        return total_utility[player] / utility_history_counter[player];
    }

    /**
     * Returns the accumulated running means of the current window divided by
     * {@link #UTILITY_HISTORY_LENGTH}. This equals the average of the running means only once
     * the window is complete.
     *
     * @param player player index
     * @return window sum of running means divided by the window length
     */
    public double get_mean_est(int player) {
        return mean_est[player] / UTILITY_HISTORY_LENGTH;
    }

    /**
     * Variance estimate of the running mean over the current window, computed as
     * E[m^2] - (E[m])^2 with both expectations taken over {@link #UTILITY_HISTORY_LENGTH}
     * samples.
     */
    private double get_var(int player) {
        double meanOfMeans = get_mean_est(player);
        return mean_square_est[player] / UTILITY_HISTORY_LENGTH - meanOfMeans * meanOfMeans;
    }

    /**
     * Adds the current strategy weight of an action (as left in the internal buffer by the most
     * recent {@link #getStrategy(int)} call) to its cumulative strategy sum.
     *
     * @param action_index      index of the action
     * @param current_iteration iteration in progress; its own slot is left untouched
     */
    public void updateStrategySum(int action_index, int current_iteration) {
        current_iteration_mod_pointer = current_iteration % ITERATION_MOD;
        int nextSlot = (current_iteration + 1) % ITERATION_MOD;
        int slotAfterNext = (current_iteration + 2) % ITERATION_MOD;
        strategySum[nextSlot][action_index] += strategy[action_index];
        // Mirror the running total so the slot after next starts from the same value.
        strategySum[slotAfterNext][action_index] = strategySum[nextSlot][action_index];
    }

    /**
     * Computes the strategy for the given iteration: each valid action is weighted by the
     * positive part of its cumulative regret, normalised to sum to one. If no valid action has
     * positive regret, all valid actions get equal weight.
     *
     * @param current_iteration iteration whose regret slot is read
     * @return the internal strategy buffer (not a copy); entries of invalid actions are not written
     */
    public double[] getStrategy(int current_iteration) {
        double[] regrets = regretSum[current_iteration % ITERATION_MOD];
        double positiveTotal = 0.0;
        for (int a = 0; a < total_game_actions; a++) {
            if (!is_valid[a]) {
                continue;
            }
            strategy[a] = regrets[a] > 0 ? regrets[a] : 0;
            positiveTotal += strategy[a];
        }

        boolean hasPositiveRegret = positiveTotal > 0;
        for (int a = 0; a < total_game_actions; a++) {
            if (!is_valid[a]) {
                continue;
            }
            if (hasPositiveRegret) {
                strategy[a] /= positiveTotal;
            } else {
                strategy[a] = 1.0 / num_valid_actions;
            }
        }
        return strategy;
    }

    /**
     * Computes the average strategy from the cumulative strategy sums in the slot following the
     * one recorded by the latest update call. If the sums of the valid actions do not add up to
     * a positive number, all valid actions get equal weight.
     *
     * @return a new array; entries of invalid actions are 0
     */
    public double[] getAverageStrategy() {
        double[] avgStrategy = new double[total_game_actions];
        double[] sums = strategySum[(current_iteration_mod_pointer + 1) % ITERATION_MOD];

        double total = 0.0;
        for (int a = 0; a < total_game_actions; a++) {
            if (is_valid[a]) {
                total += sums[a];
            }
        }

        for (int a = 0; a < total_game_actions; a++) {
            if (!is_valid[a]) {
                continue;
            }
            avgStrategy[a] = total > 0 ? sums[a] / total : 1.0 / num_valid_actions;
        }
        return avgStrategy;
    }
}