Init commit
Kechao CAI committed Jan 2, 2018
0 parents commit b01c563
Showing 18 changed files with 908 additions and 0 deletions.
20 changes: 20 additions & 0 deletions Arm.m
@@ -0,0 +1,20 @@
classdef (Abstract) Arm < handle
% Abstract interface that every bandit arm must implement.

methods (Abstract)
reward_pull = pull(self) % draw one stochastic reward
reward_expect = getExpectReward(self) % expected reward of this arm
reset(self) % clear the pull statistics
arm_info = getArmInfo(self) % human-readable description
end

end
42 changes: 42 additions & 0 deletions L2Arm.m
@@ -0,0 +1,42 @@
classdef L2Arm < Arm

properties
mu1
mu2
pulled_times = 0;
end

methods
function self = L2Arm(mu1_, mu2_)
assert(mu1_ >= 0 && mu1_ <= 1, 'L1 mean reward should be in [0,1]')
assert(mu2_ >= 0 && mu2_ <= 1, 'L2 mean reward should be in [0,1]')
self.mu1 = mu1_;
self.mu2 = mu2_;
end

function reward_pull = pull(self)
% draw a Bernoulli reward at each level ('bino' with n = 1)
l1_reward = random('bino', 1, self.mu1);
l2_reward = random('bino', 1, self.mu2);
reward_pull = L2Reward(l1_reward, l2_reward);
self.pulled_times = self.pulled_times + 1;
end

function reward_expect = getExpectReward(self)
reward_expect = L2Reward(self.mu1, self.mu2);
end

function reset(self)
self.pulled_times = 0;
end

function arm_info = getArmInfo(self)
formatSpec = 'Two-level arm with: mu1 = %0.3f mu2 = %0.3f';
arm_info = sprintf(formatSpec, self.mu1, self.mu2);
end

end

end
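A quick check of the two-level arm (the pull outcome below is one possible draw):

arm = L2Arm(0.9, 0.5);
r = arm.pull(); % r.l1 and r.l2 are Bernoulli(0.9) and Bernoulli(0.5) draws
disp([r.l1 r.l2 r.compound]) % e.g. 1 0 0; compound pays only when both levels do
er = arm.getExpectReward(); % er.l1 = 0.9, er.l2 = 0.5, er.compound = 0.45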
16 changes: 16 additions & 0 deletions L2Reward.m
@@ -0,0 +1,16 @@
classdef L2Reward

properties
l1 % first-level reward
l2 % second-level reward
compound % l1 * l2: nonzero only when both levels pay
end

methods
function self = L2Reward(l1_, l2_)
self.l1 = l1_;
self.l2 = l2_;
self.compound = l1_ * l2_;
end
end
end
8 changes: 8 additions & 0 deletions L_max.m
@@ -0,0 +1,8 @@
function indices = L_max(l, vec)
% L_MAX  Return the indices of the l largest elements of vec.
assert(l < length(vec), 'L should be less than the vector length')
[~, sortIndex] = sort(vec, 'descend');
indices = sortIndex(1:l);
end
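% Example: L_max(2, [0.3 0.9 0.1 0.7]) returns [2 4].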

8 changes: 8 additions & 0 deletions L_random.m
@@ -0,0 +1,8 @@
function indices = L_random(L, K)
% L_RANDOM  Sample L distinct indices uniformly at random from 1:K.

data = 1:K;

% sample without replacement so no index repeats (datasample, Statistics Toolbox)
indices = datasample(data, L, 'Replace', false);

end
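% Example: L_random(3, 10) might return [7 2 9].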
108 changes: 108 additions & 0 deletions LogWriter.m
@@ -0,0 +1,108 @@
classdef LogWriter < handle

properties
file_name
logger
K
L
h
T
P
arm_parameters % K x 2 matrix of (mu1, mu2) means
arm_names % cell
policy_names % cell
opt_reward % optimal reward in each round
end

methods
function self = LogWriter(logger_, L_, h_, arm_para_, optr_, arm_names_, policy_names_)
self.logger = logger_; % logger.roundwise_idx_reward is P x T x K x 2
log_size = size(self.logger.roundwise_idx_reward);
self.L = L_;
self.h = h_;
self.P = log_size(1);
self.T = log_size(2);
self.K = log_size(3);
self.arm_parameters = arm_para_; % K x 2 matrix
self.arm_names = arm_names_;
self.policy_names = policy_names_;
self.opt_reward = optr_;
end

function dump(self, filename_, level_)
% Write the log to filename_; pass a nonzero level_ to also dump the
% per-arm (l1, l2) rewards of every policy in each round.
if nargin > 2
default_level = level_;
else
default_level = 0;
end
self.file_name = filename_;
file_id = fopen(self.file_name, 'w');
fprintf(file_id, '#Reward and violation log of %d arms in each round\n', self.K);

% arm names
for i = 1:self.K
fprintf(file_id, '#arm%d %s\n', i, self.arm_names{i});
end
% policy names
for i = 1:length(self.policy_names)
fprintf(file_id, '#policy%d %s\n', i, self.policy_names{i});
end

fprintf(file_id, '#K: %d\n', self.K);
fprintf(file_id, '#L: %d\n', self.L);
fprintf(file_id, '#h: %0.2f\n', self.h);
fprintf(file_id, '#T: %d\n', self.T);

% arm parameter
fprintf(file_id, '#A1: ');
for i = 1:self.K
fprintf(file_id, '%0.2f ', self.arm_parameters(i, 1));
end
fprintf(file_id, '\n');

fprintf(file_id, '#A2: ');
for i = 1:self.K
fprintf(file_id, '%0.2f ', self.arm_parameters(i, 2));
end
fprintf(file_id, '\n');
if default_level ~= 0
fprintf(file_id, '#t: l1-l2 reward groups of %d policies ', self.P);
end
fprintf(file_id, '# optimal-reward ');
for i = 1:self.P
fprintf(file_id, 'cumreward%d ', i);
end
for i = 1:self.P
fprintf(file_id, 'cumviolation%d ', i);
end
fprintf(file_id, '\n');
% output the data: one row per round
round_cum_reward = cumsum(self.logger.roundwise_tot_reward, 2);
for t = 1:self.T
fprintf(file_id, '%d ', t);
if default_level ~= 0
for p = 1:self.P
for k = 1:self.K
fprintf(file_id, '%0.2f %0.2f ', ...
self.logger.roundwise_idx_reward(p, t, k, 1), ...
self.logger.roundwise_idx_reward(p, t, k, 2));
end
end
end
% cumulative reward of the optimal arm selection
fprintf(file_id, '%0.2f ', t * self.opt_reward);
% cumulative reward of policy p
for p = 1:self.P
fprintf(file_id, '%0.2f ', round_cum_reward(p, t));
end
% violation of policy p
for p = 1:self.P
fprintf(file_id, '%0.2f ', self.logger.roundwise_tot_violation(p, t));
end
fprintf(file_id, '\n');
end
fclose(file_id);
end
end
end
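For orientation, a level-0 dump with two policies would begin roughly like this ('#' header lines, then one row per round: round, cumulative optimal reward, per-policy cumulative rewards, per-policy violations); the numbers are illustrative:

#Reward and violation log of 5 arms in each round
#arm1 Two-level arm with: mu1 = 0.900 mu2 = 0.500
...
#K: 5
#L: 2
#h: 0.50
#T: 1000
#A1: 0.90 0.80 0.70 0.60 0.50
#A2: 0.50 0.60 0.70 0.80 0.90
# optimal-reward cumreward1 cumreward2 cumviolation1 cumviolation2
1 0.97 0.81 0.72 0.00 0.00
2 1.94 1.75 1.42 0.00 1.00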
47 changes: 47 additions & 0 deletions MbanditSimulator.m
@@ -0,0 +1,47 @@
classdef MbanditSimulator < handle

properties
arms % array
policies % cell
K
L
h
end

methods
function self = MbanditSimulator(arms_, policies_, K_, L_, h_)
self.arms = arms_;
self.policies = policies_;
self.K = K_;
self.L = L_;
self.h = h_;
end

function run_simulation(self, T_, logger_)
for p = 1:length(self.policies)
for t = 1:T_
% progress printout; mod(t, 1) == 0 always holds, so this prints every
% round -- raise the modulus (e.g. mod(t, 100)) to report less often
if mod(t, 1) == 0
fprintf('Policy %d at round %d ...\n', p, t)
end
self.run_single_round(logger_, p, t);
end
end
% reset the arms (clears pulled_times); not strictly necessary here
for iarm = self.arms
iarm.reset();
end
end

function run_single_round(self, logger_, p_, t_)
l_indices = self.policies{p_}.selectNextArms();
l_rewards = [];
for idx = l_indices
l_rewards = [l_rewards, self.arms(idx).pull()];
end
self.policies{p_}.updateState(l_indices, l_rewards);
logger_.record_reward(p_, t_, l_indices, l_rewards);
logger_.record_violation(p_, t_, self.h);
end
end

end
15 changes: 15 additions & 0 deletions Policy.m
@@ -0,0 +1,15 @@
classdef (Abstract) Policy < handle
% Abstract interface that every arm-selection policy must implement.

methods (Abstract)
selected_arms = selectNextArms(self) % choose the L arms to pull next
updateState(self, l_indices, l_rewards) % update internal statistics
info = getPolicyInfo(self) % human-readable description
end

end
91 changes: 91 additions & 0 deletions Policy_EXP3M.m
@@ -0,0 +1,91 @@
classdef Policy_EXP3M < Policy

properties
K % number of arms
L % number of arms played per round
gamma % exploration parameter
wvec % weight vector
pvec % probability vector over arms
S0t % arms whose weights are capped at alpha in the current round
level % 1: learn from l1 rewards; otherwise learn from compound rewards
beta % capping ratio (see selectNextArms)
end

methods
function self = Policy_EXP3M(K_, L_, gamma_, level_)
self.K = K_;
self.L = L_;
self.gamma = gamma_;
self.wvec = ones(1, K_);
self.pvec = zeros(1, K_);
self.S0t = [];
self.level = level_;
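% beta is the weight share at which an arm's probability hits 1:
% L * ((1 - gamma) * beta + gamma / K) = 1.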
self.beta = (1 / self.L - self.gamma / self.K) / (1 - self.gamma);
end

function selected_arms = selectNextArms(self)

sum_w = sum(self.wvec);
wvec_prime = self.wvec;
sorted_w = sort(self.wvec, 'descend');

th = self.beta * sum_w;

if sorted_w(1) > th % some probability would exceed 1: cap the largest weights
alpha_t = getAlpha(self.beta, sorted_w); % cap value (helper not shown in this excerpt)
bool_idx = self.wvec > alpha_t;
real_idx = 1:self.K;
self.S0t = real_idx(bool_idx);
wvec_prime(self.S0t) = alpha_t;
else
self.S0t = [];
end

sum_w_prime = sum(wvec_prime);
for i = 1:self.K
wi_prime = wvec_prime(i);
self.pvec(i) = self.L * ((1 - self.gamma) * wi_prime / sum_w_prime + self.gamma / self.K);
end

selected_arms = depRound(self.L, self.pvec);
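% depRound draws a size-L subset whose inclusion probabilities match
% pvec (dependent rounding; the helper file is not shown in this excerpt)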

end

function updateState(self, l_indices, l_rewards)
assert(length(l_indices) == self.L, 'EXP3M: L-indices do not match the number of selected arms.')
assert(length(l_rewards) == self.L, 'EXP3M: L-rewards do not match the number of selected arms.')
% importance-weighted estimates: observed reward divided by selection prob
xhatvec = zeros(1, self.K);
if self.level == 1
xhatvec(l_indices) = [l_rewards.l1] ./ self.pvec(l_indices);
else
xhatvec(l_indices) = [l_rewards.compound] ./ self.pvec(l_indices);
end
% exponential weight update; capped arms (S0t) are left unchanged
for j = 1:self.K
if ~ismember(j, self.S0t)
self.wvec(j) = self.wvec(j) * exp(self.L * self.gamma * xhatvec(j) / self.K);
end
end
end

function info = getPolicyInfo(self)
formatSpec = 'EXP3M policy: K = %d L = %d gamma = %0.3f level = %d';
info = sprintf(formatSpec, self.K, self.L, self.gamma, self.level);
end

function reset(self)
self.wvec = ones(1, self.K);
self.pvec = zeros(1, self.K);
end

end

end
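Finally, a minimal driving script, as a sketch: it assumes a Logger class (among the commit's files not shown here) constructed as Logger(P, T, K, 2) and exposing the roundwise_* arrays and record_* methods used by MbanditSimulator and LogWriter above.

K = 5; L = 2; T = 1000; h = 0.5; gamma = 0.1;
mu1 = [0.9 0.8 0.7 0.6 0.5];
mu2 = [0.5 0.6 0.7 0.8 0.9];
arms = [];
for k = 1:K
arms = [arms, L2Arm(mu1(k), mu2(k))]; % grow the handle-object array
end
% one EXP3.M learner per reward level
policies = {Policy_EXP3M(K, L, gamma, 1), Policy_EXP3M(K, L, gamma, 2)};
logger = Logger(length(policies), T, K, 2); % hypothetical constructor, see above
sim = MbanditSimulator(arms, policies, K, L, h);
sim.run_simulation(T, logger);

% dump the results; the per-round optimal reward uses the L_max helper
arm_names = arrayfun(@(a) a.getArmInfo(), arms, 'UniformOutput', false);
policy_names = cellfun(@(p) p.getPolicyInfo(), policies, 'UniformOutput', false);
idx = L_max(L, mu1 .* mu2); % top-L arms by compound mean reward
optr = sum(mu1(idx) .* mu2(idx));
writer = LogWriter(logger, L, h, [mu1' mu2'], optr, arm_names, policy_names);
writer.dump('mbandit_log.txt');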