forked from jkomiyama/banditlib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolicy_egreedy.hpp
50 lines (46 loc) · 1.26 KB
/
policy_egreedy.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#pragma once
#include "policy.hpp"
namespace bandit{
/**
 * Epsilon-greedy bandit policy with a decaying exploration rate.
 *
 * Each round, with probability epsilon = epsilonCoef / n (n = total number of
 * pulls recorded so far) an arm is drawn uniformly at random from [0, K-1].
 * Otherwise the policy plays the first arm that has never been pulled, or, if
 * every arm has been pulled, the arm chosen by vectorMaxIndex over the
 * empirical mean rewards Gi[k]/Ni[k].
 *
 * Precondition: K >= 1. K is not validated here; with K == 0 the uniform
 * integer distribution would be constructed with an invalid range.
 *
 * NOTE(review): vectorSum, vectorMaxIndex, randomEngine and dtos are defined
 * elsewhere in the project (presumably reachable through policy.hpp).
 */
class EGreedyPolicy : public Policy{
  const uint K;            // number of arms
  std::vector<int> Ni;     // Ni[k]: number of times arm k has been pulled
  std::vector<double> Gi;  // Gi[k]: cumulative reward collected from arm k
  // Exploration coefficient:
  //   epsilon (random-play probability) = epsilonCoef / t, t = current round.
  //   Original author's note: epsilon base = cK/d^2.
  const double epsilonCoef;
public:
  /**
   * @param K            number of arms (expected to be >= 1)
   * @param epsilonCoef  numerator of the decaying exploration probability
   */
  EGreedyPolicy(int K, double epsilonCoef=0.1): K(K), epsilonCoef(epsilonCoef) {
    reset();
  }

  /** Forgets all observations: every pull count and reward sum goes back to zero. */
  void reset(){
    Ni.assign(K, 0);
    Gi.assign(K, 0.0);
  }

  /**
   * Chooses the arm to play next.
   * @return arm index in [0, K-1]
   */
  virtual int selectNextArm(){
    const double n = vectorSum(Ni);
    // The coin is drawn on every call, even when the outcome is already
    // decided, so the random stream is consumed exactly once per decision.
    const double coin = std::uniform_real_distribution<double>(0.0,1.0)(randomEngine);

    // Before the first pull n is 0 and epsilonCoef/n would divide by zero.
    // Treat that round as "always explore" when the coefficient is positive
    // (the limit of epsilonCoef/n) and "never explore" otherwise.
    const bool explore = (n > 0.0) ? (epsilonCoef/n > coin) : (epsilonCoef > 0.0);

    if(explore){ //random choice
      return std::uniform_int_distribution<int>(0, static_cast<int>(K)-1)(randomEngine);
    }

    // Greedy choice: an arm without any observation is played first, since
    // its empirical mean is undefined (0/0).
    std::vector<double> eExpectations(K, 0.0);
    for(uint k=0;k<K;++k){
      if(Ni[k]==0){
        return static_cast<int>(k);
      }
      eExpectations[k] = Gi[k]/Ni[k];
    }
    return vectorMaxIndex(eExpectations);
  }

  /**
   * Records the outcome of one pull.
   * @param k  index of the arm that was played
   * @param r  reward received from that arm
   */
  virtual void updateState(int k, double r){
    Ni[k]+=1;
    Gi[k]+=r;
  }

  /** @return a human-readable description of the policy and its coefficient. */
  virtual std::string toString(){
    std::string str="Egreedy Policy with epsilonCoef=";
    str+=dtos(epsilonCoef);
    return str;
  }
};
} //namespace