-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kechao CAI
committed
Jan 2, 2018
0 parents
commit b01c563
Showing
18 changed files
with
908 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
classdef (Abstract) Arm < handle | ||
|
||
properties | ||
end | ||
|
||
methods | ||
function reward_pull = pull(self) | ||
end | ||
|
||
function reward_expect = getExpectReward(self) | ||
end | ||
|
||
function reset(self) | ||
end | ||
|
||
function arm_info = getArmInfo(self) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
classdef L2Arm < Arm | ||
|
||
properties | ||
mu1 | ||
mu2 | ||
% total_rounds | ||
pulled_times = 0; | ||
end | ||
|
||
methods | ||
function self = L2Arm(mu1_, mu2_) | ||
assert(mu1_ >= 0 && mu1_ <= 1, 'L1 mean reward should be in [0,1]') | ||
assert(mu2_ >= 0 && mu2_ <= 1, 'L2 mean reward should be in [0,1]') | ||
self.mu1 = mu1_; | ||
self.mu2 = mu2_; | ||
% self.total_rounds = t_rounds; | ||
end | ||
|
||
function reward_pull = pull(self) | ||
% using the Bernoulli distribution in both levels. | ||
l1_reward = random('bino', 1, self.mu1); | ||
l2_reward = random('bino', 1, self.mu2); | ||
reward_pull = L2Reward(l1_reward, l2_reward); | ||
self.pulled_times = self.pulled_times + 1; | ||
end | ||
|
||
function reward_expect = getExpectReward(self) | ||
reward_expect = L2Reward(self.mu1, self.mu2); | ||
end | ||
|
||
function reset(self) | ||
self.pulled_times = 0; | ||
end | ||
|
||
function arm_info = getArmInfo(self) | ||
formatSpec = 'Two-level arm with: mu1 = %0.3f mu2 = %0.3f'; | ||
arm_info = sprintf(formatSpec, self.mu1, self.mu2); | ||
end | ||
|
||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
classdef L2Reward | ||
|
||
properties | ||
l1 | ||
l2 | ||
compound | ||
end | ||
|
||
methods | ||
function self = L2Reward(l1_, l2_) | ||
self.l1 = l1_; | ||
self.l2 = l2_; | ||
self.compound = l1_ * l2_; | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function indices = L_max(l, vec) | ||
% | ||
% Find the indices of the largest l elements in vec. | ||
assert(l < length(vec), 'L should be less than the vector length') | ||
[~, sortIndex] = sort(vec, 'descend'); | ||
indices = sortIndex(1:l); | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function indices = L_random(L, K) | ||
|
||
data = 1:K; | ||
|
||
% the indices should not be repeated. | ||
indices = datasample(data, L, 'Replace', false); | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
classdef LogWriter < handle | ||
|
||
properties | ||
file_name | ||
logger | ||
K | ||
L | ||
h | ||
T | ||
P | ||
arm_paramters | ||
arm_names % cell | ||
policy_names % cell | ||
opt_reward % optimal reward in each round | ||
end | ||
|
||
methods | ||
function self = LogWriter(logger_, L_, h_, arm_para_, optr_, arm_names_, policy_names_) | ||
% self.file_name = f_; | ||
self.logger = logger_; % (P_, T_, K_, 2); | ||
log_size = size(self.logger.roundwise_idx_reward); | ||
self.L = L_; | ||
self.h = h_; | ||
self.P = log_size(1); | ||
self.T = log_size(2); | ||
self.K = log_size(3); | ||
self.arm_paramters = arm_para_; % K * 2 matrix | ||
self.arm_names = arm_names_; | ||
self.policy_names = policy_names_; | ||
self.opt_reward = optr_; | ||
end | ||
|
||
function dump(self, filename_, level_) | ||
if nargin > 2 | ||
default_level = level_; | ||
else | ||
default_level = 0; | ||
end | ||
self.file_name = filename_; | ||
file_id = fopen(self.file_name, 'w'); | ||
fprintf(file_id, '#Reward and violation log of %d arms in each round\n', self.K); | ||
|
||
% arm names | ||
for i = 1:self.K | ||
fprintf(file_id, '#arm%d %s\n', i, self.arm_names{i}); | ||
end | ||
% policy names | ||
for i = 1:length(self.policy_names) | ||
fprintf(file_id, '#policy%d %s\n', i, self.policy_names{i}); | ||
end | ||
|
||
fprintf(file_id, '#K: %d\n', self.K); | ||
fprintf(file_id, '#L: %d\n', self.L); | ||
fprintf(file_id, '#h: %0.2f\n', self.h); | ||
fprintf(file_id, '#T: %d\n', self.T); | ||
|
||
% arm parameter | ||
fprintf(file_id, '#A1: '); | ||
for i = 1:self.K | ||
fprintf(file_id, '%0.2f ', self.arm_paramters(i, 1)); | ||
end | ||
fprintf(file_id, '\n'); | ||
|
||
fprintf(file_id, '#A2: '); | ||
for i = 1:self.K | ||
fprintf(file_id, '%0.2f ', self.arm_paramters(i, 2)); | ||
end | ||
fprintf(file_id, '\n'); | ||
if default_level ~= 0 | ||
fprintf(file_id, '#t: l1-l2 reward groups of %d policies ', self.P); | ||
end | ||
fprintf(file_id, '# optimal-reward '); | ||
for i = 1:self.P | ||
fprintf(file_id, 'cumreward%d ', i); | ||
end | ||
for i = 1:self.P | ||
fprintf(file_id, 'cumviolation%d ', i); | ||
end | ||
fprintf(file_id, '\n'); | ||
% out put the data | ||
round_cum_reward = cumsum(self.logger.roundwise_tot_reward, 2); | ||
for t = 1:self.T | ||
fprintf(file_id, '%d ', t); | ||
if default_level ~= 0 | ||
for p = 1:self.P | ||
for k = 1:self.K | ||
fprintf(file_id, '%0.2f %0.2f ', ... | ||
self.logger.roundwise_idx_reward(p, t, k, 1), ... | ||
self.logger.roundwise_idx_reward(p, t, k, 2)); | ||
end | ||
end | ||
end | ||
% optimal cum reward | ||
fprintf(file_id, '%0.2f ', t* self.opt_reward); | ||
% reward p | ||
for p = 1:self.P | ||
fprintf(file_id, '%0.2f ', round_cum_reward(p, t)); | ||
end | ||
% violation p | ||
for p = 1:self.P | ||
fprintf(file_id, '%0.2f ', self.logger.roundwise_tot_violation(p, t)); | ||
end | ||
fprintf(file_id, '\n'); | ||
end | ||
fclose(file_id); | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
classdef MbanditSimulator < handle | ||
|
||
properties | ||
arms % array | ||
policies % cell | ||
K | ||
L | ||
h | ||
end | ||
|
||
methods | ||
function self = MbanditSimulator(arms_, policies_, K_, L_, h_) | ||
self.arms = arms_; | ||
self.policies = policies_; | ||
self.K = K_; | ||
self.L = L_; | ||
self.h = h_; | ||
end | ||
|
||
function run_simulation(self, T_, logger_) | ||
for p = 1:length(self.policies) | ||
for t = 1:T_ | ||
if mod(t,1) == 0 | ||
fprintf('Policy %d at round %d ...\n', p, t) | ||
end | ||
self.run_single_round(logger_, p, t); | ||
end | ||
end | ||
% reset the arms, not necessary here. | ||
for iarm = self.arms | ||
iarm.reset(); | ||
end | ||
end | ||
|
||
function run_single_round(self, logger_, p_, t_) | ||
l_indices = self.policies{p_}.selectNextArms(); | ||
l_rewards = []; | ||
for idx = l_indices | ||
l_rewards = [l_rewards, self.arms(idx).pull()]; | ||
end | ||
self.policies{p_}.updateState(l_indices, l_rewards); | ||
logger_.record_reward(p_, t_, l_indices, l_rewards); | ||
logger_.record_violation(p_, t_, self.h) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
classdef (Abstract) Policy < handle | ||
|
||
properties | ||
end | ||
|
||
methods | ||
function selected_arms = selectNextArms(self) | ||
end | ||
function updateState(self, l_indices, l_rewards) | ||
end | ||
function info = getPolicyInfo(self) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
classdef Policy_EXP3M < Policy | ||
|
||
properties | ||
K | ||
L | ||
gamma | ||
wvec % weight vector | ||
pvec % prob vector | ||
S0t | ||
level | ||
beta | ||
end | ||
|
||
methods | ||
function self = Policy_EXP3M(K_, L_, gamma_,level_) | ||
self.K = K_; | ||
self.L = L_; | ||
self.gamma = gamma_; | ||
self.wvec = ones(1, K_); | ||
self.pvec = zeros(1, K_); | ||
self.S0t = []; | ||
self.level = level_; | ||
self.beta = (1 / self.L - self.gamma / self.K) / (1 - self.gamma); | ||
end | ||
|
||
function selected_arms = selectNextArms(self) | ||
|
||
sum_w = sum(self.wvec); | ||
wvec_prime = self.wvec; | ||
sorted_w = sort(self.wvec, 'descend'); | ||
|
||
|
||
th = self.beta * sum_w; | ||
|
||
if sorted_w(1) > th % find alpha | ||
alpha_t = getAlpha(self.beta, sorted_w); | ||
bool_idx = self.wvec > alpha_t; | ||
real_idx = 1:self.K; | ||
self.S0t = real_idx(bool_idx); | ||
wvec_prime(self.S0t) = alpha_t; | ||
else | ||
self.S0t = []; | ||
end | ||
|
||
sum_w_prime = sum(wvec_prime); | ||
for i = 1:self.K | ||
wi_prime = wvec_prime(i); | ||
self.pvec(i) = self.L * ((1 - self.gamma) * wi_prime / sum_w_prime + self.gamma / self.K); | ||
end | ||
|
||
selected_arms = depRound(self.L, self.pvec); | ||
|
||
end | ||
|
||
function updateState(self, l_indices, l_rewards) | ||
% Fix: can be simplified using matrix class op. | ||
assert(length(l_indices) == self.L, 'EXP3M: L-indices do not match the number of selected arms.') | ||
assert(length(l_rewards) == self.L, 'EXP3M: L-rewards do not match the number of selected arms.') | ||
xhatvec = zeros(1, self.K); | ||
% for i = 1:self.L | ||
% arm_idx = l_indices(i); | ||
% arm_reward = l_rewards(i); | ||
% arm_prob = self.pvec(arm_idx); | ||
if self.level == 1 | ||
% xhatvec(arm_idx) = arm_reward.l1 / arm_prob; | ||
xhatvec(l_indices) = [l_rewards.l1]./ self.pvec(l_indices); | ||
else | ||
% xhatvec(arm_idx) = arm_reward.compound / arm_prob; | ||
xhatvec(l_indices) = [l_rewards.compound]./ self.pvec(l_indices); | ||
end | ||
% end | ||
for j = 1:self.K | ||
if ~ismember(j, self.S0t) | ||
self.wvec(j) = self.wvec(j) * exp(self.L*self.gamma*xhatvec(j)/self.K); | ||
end | ||
end | ||
end | ||
|
||
function info = getPolicyInfo(self) | ||
formatSpec = 'EXP3M policy: K = %d L = %d gamma = %0.3f level = %d'; | ||
info = sprintf(formatSpec, self.K, self.L, self.gamma, self.level); | ||
end | ||
|
||
function reset(self) | ||
self.wvec = ones(1, self.K); | ||
self.pvec = zeros(1, self.K); | ||
end | ||
|
||
end | ||
|
||
end |
Oops, something went wrong.