Skip to content

Commit

Permalink
one-night mcts mapper!
Browse files Browse the repository at this point in the history
  • Loading branch information
gulang2019 committed Apr 21, 2023
1 parent 081de8a commit 6e2dae9
Show file tree
Hide file tree
Showing 17 changed files with 1,097 additions and 172 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "3rdparty/timeloop"]
path = 3rdparty/timeloop
url = [email protected]:gulang2019/timeloop.git
[submodule "3rdparty/mcts"]
path = 3rdparty/mcts
url = [email protected]:Konijnendijk/cpp-mcts.git
1 change: 1 addition & 0 deletions 3rdparty/mcts
Submodule mcts added at fc8f93
2 changes: 2 additions & 0 deletions include/tileflow/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#define TILEFLOW_WARNING(msg) do{std::cerr << "[WARNING]: " << msg << std::endl;}while(0)

#define TILEFLOW_LOG(msg) do{std::cerr << "[LOG]: " << msg << std::endl;}while(0)

#define TILEFLOW_COND_WARNING(cond, msg) do{if(!(cond)) {std::cerr << "[WARNING]: " << msg << std::endl;}}while(0)

#include "compound-config/compound-config.hpp"
Expand Down
35 changes: 21 additions & 14 deletions include/tileflow/loop-analysis/nest-analysis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,20 @@ namespace TileFlow {
NestAnalysis(const problem::TileFlow::Workloads& workloads_,
const mapping::TileFlow::Mapping& mapping_,
const model::Engine::Specs& arch_specs_,
const model::Topology& topology_,
const SymbolTable* symb_table_ = nullptr);
const model::Topology& topology_);

void set_symbol_table(const SymbolTable* symbol_table) {symb_table_ = symbol_table;}

const tiling::CompoundTile& get_tile(const Node* node) const
{
if (tiles_.count(node) == 0) {
std::cerr << "ERROR node not found:" << std::endl;
node->display("", false);
}
return tiles_.at(node);
{
if (tiles_.count(node) == 0) {
std::cerr << "ERROR node not found:" << std::endl;
node->display("", false);
}

return tiles_.at(node);
}

void reset();
void analyze();
void Print();
void Report();
Expand Down Expand Up @@ -307,12 +309,14 @@ namespace TileFlow {

class Displayer: public mapping::TileFlow::Visitor {
NestAnalysis & analysis_;
const SymbolTable* symbol_table_;
void visitTile(const TileNode*) override;
void visitScope(const ScopeNode*) override;
void visitOp(const OpNode*) override;
std::string prefix_;
public:
Displayer(NestAnalysis& analysis): analysis_(analysis){}
Displayer(NestAnalysis& analysis, const SymbolTable* symbol_table_ = nullptr)
: analysis_(analysis), symbol_table_(symbol_table_) {}
void display() {analysis_.mapping_.root->accept(this);}
};

Expand All @@ -330,13 +334,16 @@ namespace TileFlow {
};

class LoopNestConstructor: public mapping::TileFlow::Visitor {
NestAnalysis& analysis_;
const SymbolTable* symbol_table_;

void visitTile(const TileNode* node) override {
analysis_.configs[node].loop_nest = node->constructLoopNest();
analysis_.configs[node].loop_nest = node->constructLoopNest(symbol_table_);
for (auto child: node->get_children()) {child->accept(this);}
}
NestAnalysis& analysis_;
public:
LoopNestConstructor(NestAnalysis& analysis): analysis_(analysis){}
LoopNestConstructor(NestAnalysis& analysis, const SymbolTable*symbol_table_)
: analysis_(analysis), symbol_table_(symbol_table_){}
void construct(const Node* root) {root->accept(this);}
};

Expand All @@ -347,7 +354,7 @@ namespace TileFlow {
NestAnalysis& analysis_;
public:
StorageLevelCalculator(NestAnalysis& analysis): analysis_(analysis){
if (verbose_level) {
if (verbose_level > 1) {
std::cout << "Begin storge level calculation..." << std::endl;
std::cout << "\tfanoutX: ";
for (auto& kv: analysis_.mapping_.fanoutX_map)
Expand Down
13 changes: 1 addition & 12 deletions include/tileflow/mapper/checker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,7 @@ using mapping::TileFlow::ScopeNode;
using mapping::TileFlow::Visitor;

namespace TileFlow {

struct Constraint {
enum {
LOOPCOUNT,
MEM,
SPATIAL
}type_;
std::shared_ptr<Expr> expr;
std::string msg;
std::string short_msg = "";
};


class ShapeConstraintParser: public Visitor {
void visitTile(const TileNode*) override;
void visitOp(const OpNode*) override;
Expand Down
114 changes: 98 additions & 16 deletions include/tileflow/mapper/expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,100 @@
#include <string>
#include <vector>
#include <memory>
#include <set>
#include <functional>

namespace TileFlow {
typedef size_t num_t;

struct Constraint;

struct Entry {
std::string name_;
num_t value_;
int idx_;
std::set<num_t> candidates_;
bool fixed_ = false;
};

std::ostream& operator<< (std::ostream& o, const Entry&);

class SymbolTable {
std::unordered_map<std::string, int> name2idx_;
std::unordered_map<int, Entry> idx2values_;
int idx = 0;
bool failed_ = false;
bool fail_check(const std::vector<Constraint>& constraints_);

public:
Entry lookup(const std::string& key) const {return idx2values_.at(name2idx_.at(key));}
Entry lookup(int key) const {return idx2values_.at(key);}
int idx = 0;
const Entry& lookup(const std::string& key) const {return idx2values_.at(name2idx_.at(key));}
const Entry& lookup(int key) const {return idx2values_.at(key);}
int count(int key) const {return idx2values_.count(key);}
bool is_terminated() const {
for(auto& kv: idx2values_)
if(!kv.second.fixed_) return false;
return true;
}
int get_num_variables() const {return -idx;}
int count_unfixed() const {int ret = idx2values_.size(); for (auto& kv: idx2values_) ret -= kv.second.fixed_; return ret;}
int get_next_var() const;
void show_brief(std::ostream& o) const;


Entry& operator[](int key) {return idx2values_[key];}
int insert(const std::string name = "") {
std::string name_ = name;
for (int i = 0; name2idx_.count(name_); i++){
name_ = name + std::to_string(i);
}
idx--;
name2idx_[name_] = idx;
idx2values_[idx] = {name_, 0, idx};
idx2values_[idx] = {name_, 0, idx, {}, false};
return idx;
}
int get_num_variables() const {return -idx;}


void init(const std::vector<Constraint>& constraints_);
void fix_and_update(int index, num_t value, const std::vector<Constraint>& constraints_);
};

std::ostream& operator<< (std::ostream& o, const SymbolTable&);

extern SymbolTable global_symbol_table_;

struct PairExpr;
struct PairSumExpr;
struct PairMaxExpr;
struct PairCondExpr;
struct ProductExpr;
struct VariableExpr;
struct ParameterExpr;
struct CondExpr;
struct SumExpr;

struct ExprVisitor {
virtual void visitPairExpr(const PairExpr*);
virtual void visitPairSumExpr(const PairSumExpr*);
virtual void visitPairMaxExpr(const PairMaxExpr*);
virtual void visitPairCondExpr(const PairCondExpr*);
virtual void visitProductExpr(const ProductExpr*);
virtual void visitSumExpr(const SumExpr*);
virtual void visitVariableExpr(const VariableExpr*);
virtual void visitParameterExpr(const ParameterExpr*);
virtual void visitCondExpr(const CondExpr*);
};

struct Expr {
virtual num_t eval(const SymbolTable&) = 0;
virtual void display(const SymbolTable&) = 0;
virtual void accept(ExprVisitor*) const = 0;
};

struct ResourceExpr: public Expr {
virtual num_t eval(const SymbolTable& ) override {return 0;}
virtual void display(const SymbolTable& ) override {}
// given y's limit, compute minimum required x
virtual std::pair<int, int> eval_pair(const SymbolTable& symb_table, int limit_y) = 0;
virtual void accept(ExprVisitor* visitor) const = 0;
};

struct PairExpr: public ResourceExpr {
Expand All @@ -56,6 +107,7 @@ namespace TileFlow {
const std::shared_ptr<Expr>& y): x_(x), y_(y) {}
void display(const SymbolTable& symb_table) override;
std::pair<int, int> eval_pair(const SymbolTable& symb_table, int limit_y) override;
void accept(ExprVisitor* visitor) const override {visitor->visitPairExpr(this);}
};

struct PairSumExpr: public ResourceExpr {
Expand All @@ -64,6 +116,7 @@ namespace TileFlow {
operands_(operands){}
void display(const SymbolTable& symb_table) override;
std::pair<int, int> eval_pair(const SymbolTable& symb_table, int limit_y) override;
void accept(ExprVisitor* visitor) const override {visitor->visitPairSumExpr(this);}
};

struct PairMaxExpr: public ResourceExpr {
Expand All @@ -72,6 +125,7 @@ namespace TileFlow {
operands_(operands){}
void display(const SymbolTable& symb_table) override;
std::pair<int, int> eval_pair(const SymbolTable& symb_table, int limit_y) override;
void accept(ExprVisitor* visitor) const override {visitor->visitPairMaxExpr(this);}
};

struct PairCondExpr: public ResourceExpr {
Expand All @@ -87,6 +141,7 @@ namespace TileFlow {
void display(const SymbolTable& symb_table) override;
num_t eval(const SymbolTable& symb_table) override;
std::pair<int, int> eval_pair(const SymbolTable& symb_table, int limit_y) override;
void accept(ExprVisitor* visitor) const override {visitor->visitPairCondExpr(this);}
};

struct SumExpr: public Expr {
Expand All @@ -95,18 +150,9 @@ namespace TileFlow {
operands_(operands){}
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
void accept(ExprVisitor* visitor) const override {visitor->visitSumExpr(this);}
};

template <typename T>
struct MaxExpr: public Expr {
std::vector<std::shared_ptr<T> > operands_;
MaxExpr(const std::vector<std::shared_ptr<T> >& operands):
operands_(operands){}
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
};


struct ProductExpr: public Expr {
std::vector<std::shared_ptr<Expr> > operands_;
ProductExpr(const std::vector<std::shared_ptr<Expr> >& operands):
Expand All @@ -117,19 +163,22 @@ namespace TileFlow {
ProductExpr(const std::pair<num_t, std::vector<int> >& operands);
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
void accept(ExprVisitor* visitor) const override {visitor->visitProductExpr(this);}
};
struct VariableExpr: public Expr {
int idx_;
VariableExpr(int idx): idx_(idx){}
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
void accept(ExprVisitor* visitor) const override {visitor->visitVariableExpr(this);}
};

struct ParameterExpr: public Expr {
num_t value_;
ParameterExpr(num_t value): value_(value){}
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
void accept(ExprVisitor* visitor) const override {visitor->visitParameterExpr(this);}
};

struct CondExpr: public Expr {
Expand All @@ -146,6 +195,39 @@ namespace TileFlow {
op_(op), left_(left), right_(right){}
num_t eval(const SymbolTable& symb_table) override;
void display(const SymbolTable& symb_table) override;
void accept(ExprVisitor* visitor) const override {visitor->visitCondExpr(this);}
};

struct VariableCollector: public ExprVisitor {
std::set<int> variables_;
std::function<bool(int)> _func;

void visitVariableExpr(const VariableExpr* expr) override {
if (_func(expr->idx_))
variables_.insert(expr->idx_);
}

std::set<int> operator()(const Expr* root, std::function<bool(int)>func) {
_func = func;
root->accept(this);
return std::move(variables_);
}
std::set<int> operator()(const Expr* root) {
_func = [](int){return true;};
root->accept(this);
return std::move(variables_);
}
};

struct Constraint {
enum {
LOOPCOUNT,
MEM,
SPATIAL
}type_;
std::shared_ptr<Expr> expr;
std::string msg;
std::string short_msg = "";
};

} // namespace TileFlow
Loading

0 comments on commit 6e2dae9

Please sign in to comment.