Skip to content

Commit

Permalink
bugfix for no variable case
Browse files Browse the repository at this point in the history
  • Loading branch information
gulang2019 committed Apr 22, 2023
1 parent c42adc0 commit 7c3ff7c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 37 deletions.
3 changes: 1 addition & 2 deletions include/tileflow/mapper/mapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ struct Env {
bool terminated_ = false;
bool expanded_ = false;
double best_reward_ = -1e9;
double reward_;

public:
Env(const std::vector<Constraint>& constraints,
Expand All @@ -104,7 +103,7 @@ struct Env {

bool is_terminated() const {return terminated_;}
bool is_expanded() const {return expanded_;}
double get_reward() const {return reward_;}
double get_reward();
State* get_curr_state() const {return curr_state_;}
const std::vector<Constraint>& get_constraints() const {return constraints_;}
const SymbolTable* get_best_symbol_table() const {
Expand Down
78 changes: 43 additions & 35 deletions src/mapper/mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ const SymbolTable* Mapper::search() {
std::string obj = obj_ == Objective::CYCLE? "cycle" : "energy";
TILEFLOW_LOG("Optimize " << obj << "...");
analysis::TileFlow::NestAnalysis analyzer(workloads_, mapping_, arch_specs_, topology_);
if (!global_symbol_table_.count_unfixed()){
optimum_ = global_symbol_table_;
return &optimum_;
}
Env env(constraints_, global_symbol_table_, analyzer, obj_);
MCTS mcts(&env, timeout_);
mcts.search();
Expand Down Expand Up @@ -70,40 +74,44 @@ Action Env::step(bool random) {
std::tie(act, expanded_) = curr_state_->select_action(random);
curr_state_ = curr_state_->take_action(act, constraints_);
terminated_ = curr_state_->is_terminated();
if (terminated_) {
if (curr_state_->n_visit > 0) {
reward_ = curr_state_->ave_reward;
}
else if (curr_state_->is_error_out()) {
reward_ = punish_;
return act;
}

double Env::get_reward(){
assert(terminated_ && curr_state_->is_terminated());
double reward;
if (curr_state_->n_visit > 0) {
reward = curr_state_->ave_reward;
}
else if (curr_state_->is_error_out()) {
reward = punish_;
}
else
{
analyzer_.set_symbol_table(&curr_state_->symbol_table_);
if (verbose_level > 1) {
std::cout << "begin analyze..." << std::endl;
std::cout << "symbol table: ";
curr_state_->symbol_table_.show_brief(std::cout);
std::cout << std::endl;
analyzer_.Print();
}
else
{
analyzer_.set_symbol_table(&curr_state_->symbol_table_);
if (verbose_level > 1) {
std::cout << "begin analyze..." << std::endl;
std::cout << "symbol table: ";
curr_state_->symbol_table_.show_brief(std::cout);
std::cout << std::endl;
analyzer_.Print();
}
analyzer_.analyze();
double value;
if (obj_ == Objective::CYCLE) value = (double)(analyzer_.get_cycle());
else if (obj_ == Objective::ENERGY) value = (double)(analyzer_.get_energy());
reward_ = -std::log10(value);

if (reward_ > best_reward_) {
TILEFLOW_LOG("Update best "; curr_state_->symbol_table_.show_brief(std::cerr); std::cerr
<< " value: " << value);
best_reward_ = reward_;
best_symbol_table_ = &curr_state_->symbol_table_;
}
analyzer_.analyze();
double value;
if (obj_ == Objective::CYCLE) value = (double)(analyzer_.get_cycle());
else if (obj_ == Objective::ENERGY) value = (double)(analyzer_.get_energy());
reward = -std::log10(value);

if (reward > best_reward_) {
TILEFLOW_LOG("Update best "; curr_state_->symbol_table_.show_brief(std::cerr); std::cerr
<< " value: " << value);
best_reward_ = reward;
best_symbol_table_ = &curr_state_->symbol_table_;
}
if (punish_ + 2 > reward_)
punish_ = reward_ - 2;
}
return act;
if (punish_ + 2 > reward)
punish_ = reward - 2;
return reward;
}

State* MCTS::select_state() {
Expand Down Expand Up @@ -149,14 +157,14 @@ void MCTS::back_prop(State* state, double reward){
void MCTS::search() {
start_timer();
for (int i = 0; i < n_iteration; i++ ) {
auto state = select_state();
double reward = rollout(state);
back_prop(state, reward);
if (!n_unexplored) {
if (verbose_level)
std::cout << "MCTS: finished searching after " << i << std::endl;
std::cout << "MCTS: finished searching after " << i + 1 << " rounds" << std::endl;
return;
}
auto state = select_state();
double reward = rollout(state);
back_prop(state, reward);
if (get_elapsed_time() > timeout_) {
TILEFLOW_LOG("MCTS exit becuase of timeout.");
return;
Expand Down

0 comments on commit 7c3ff7c

Please sign in to comment.