Skip to content

Commit

Permalink
Refactor: ASR: Define and use handle_attribute()
Browse files Browse the repository at this point in the history
This combines duplicated logic for attribute handling in visit_Call() and visit_Expr()
  • Loading branch information
Shaikh-Ubaid committed Jun 20, 2023
1 parent 3096d24 commit 9e0caa2
Showing 1 changed file with 92 additions and 122 deletions.
214 changes: 92 additions & 122 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6316,43 +6316,11 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
return;
}
} else if (AST::is_a<AST::Attribute_t>(*c->m_func)) {
Vec<ASR::call_arg_t> args;
parse_args(*c, args);
AST::Attribute_t *at = AST::down_cast<AST::Attribute_t>(c->m_func);
if (AST::is_a<AST::Name_t>(*at->m_value)) {
std::string value = AST::down_cast<AST::Name_t>(at->m_value)->m_id;
ASR::symbol_t *t = current_scope->resolve_symbol(value);
if (!t) {
throw SemanticError("'" + value + "' is not defined in the scope",
x.base.base.loc);
}
if (ASR::is_a<ASR::Module_t>(*t)) {
std::string call_name = at->m_attr;
std::string call_name_store = "__" + value + "_" + call_name;
ASR::Module_t *m = ASR::down_cast<ASR::Module_t>(t);
call_name_store = ASRUtils::get_mangled_name(m, call_name_store);
ASR::symbol_t *st = current_scope->resolve_symbol(call_name_store);
if (!st) {
st = import_from_module(al, m, current_scope, value,
call_name, call_name_store, x.base.base.loc);
current_scope->add_symbol(call_name_store, st);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, c->n_args);
visit_expr_list(c->m_args, c->n_args, args);
tmp = make_call_helper(al, st, current_scope, args,
call_name, x.base.base.loc);
return;
}
Vec<ASR::expr_t*> elements;
elements.reserve(al, c->n_args);
for (size_t i = 0; i < c->n_args; ++i) {
visit_expr(*c->m_args[i]);
elements.push_back(al, ASRUtils::EXPR(tmp));
}
ASR::expr_t *te = ASR::down_cast<ASR::expr_t>(
ASR::make_Var_t(al, x.base.base.loc, t));
handle_builtin_attribute(te, at->m_attr, x.base.base.loc, elements);
return;
}
handle_attribute(at, args, x.base.base.loc);
return;
} else {
throw SemanticError("Only Name/Attribute supported in Call",
x.base.base.loc);
Expand Down Expand Up @@ -7024,6 +6992,92 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type);
}

void handle_attribute(AST::Attribute_t* at, Vec<ASR::call_arg_t> &args, const Location &loc) {
if (AST::is_a<AST::Name_t>(*at->m_value)) {
AST::Name_t *n = AST::down_cast<AST::Name_t>(at->m_value);
std::string mod_name = n->m_id;
std::string call_name = at->m_attr;
std::string call_name_store = "__" + mod_name + "_" + call_name;
ASR::symbol_t *st = nullptr;
if (current_scope->resolve_symbol(call_name_store) != nullptr) {
st = current_scope->get_symbol(call_name_store);
} else {
st = current_scope->resolve_symbol(mod_name);
if (!st) {
throw SemanticError("NameError: '" + mod_name + "' is not defined", n->base.base.loc);
}
if( ASR::is_a<ASR::Module_t>(*st) ) {
ASR::Module_t *m = ASR::down_cast<ASR::Module_t>(st);
call_name_store = ASRUtils::get_mangled_name(m, call_name_store);
st = import_from_module(al, m, current_scope, mod_name,
call_name, call_name_store, loc);
current_scope->add_symbol(call_name_store, st);
} else if( ASR::is_a<ASR::StructType_t>(*st) ) {
st = get_struct_member(st, call_name, loc);
} else if ( ASR::is_a<ASR::Variable_t>(*st)) {
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(st);
if (ASR::is_a<ASR::Struct_t>(*var->m_type)) {
// call to struct member function
ASR::Struct_t* var_struct = ASR::down_cast<ASR::Struct_t>(var->m_type);
st = get_struct_member(var_struct->m_derived_type, call_name, loc);
} else {
// this case when we have variable and attribute
st = current_scope->resolve_symbol(mod_name);
Vec<ASR::expr_t*> eles;
eles.reserve(al, args.size());
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
ASR::expr_t *se = ASR::down_cast<ASR::expr_t>(
ASR::make_Var_t(al, loc, st));
if (ASR::is_a<ASR::Character_t>(*(ASRUtils::expr_type(se)))) {
handle_string_attributes(se, args, at->m_attr, loc);
return;
}
handle_builtin_attribute(se, at->m_attr, loc, eles);
return;
}
}
}
tmp = make_call_helper(al, st, current_scope, args, call_name, loc);
return;
} else if (AST::is_a<AST::UnaryOp_t>(*at->m_value)) {
AST::UnaryOp_t* uop = AST::down_cast<AST::UnaryOp_t>(at->m_value);
visit_UnaryOp(*uop);
Vec<ASR::expr_t*> eles;
eles.reserve(al, args.size());
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
ASR::expr_t* expr = ASR::down_cast<ASR::expr_t>(tmp);
handle_builtin_attribute(expr, at->m_attr, loc, eles);
return;
} else if (AST::is_a<AST::ConstantInt_t>(*at->m_value)) {
if (std::string(at->m_attr) == std::string("bit_length")) {
//bit_length() attribute:
if(args.size() != 0) {
throw SemanticError("int.bit_length() takes no arguments", loc);
}
AST::ConstantInt_t *n = AST::down_cast<AST::ConstantInt_t>(at->m_value);
int64_t int_val = std::abs(n->m_value);
int32_t res = 0;
for(; int_val; int_val >>= 1, res++);
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4));
tmp = ASR::make_IntegerConstant_t(al, loc, res, int_type);
return;
} else {
throw SemanticError("'int' object has no attribute '" + std::string(at->m_attr) + "'", loc);
}
} else if (AST::is_a<AST::ConstantStr_t>(*at->m_value)) {
AST::ConstantStr_t *n = AST::down_cast<AST::ConstantStr_t>(at->m_value);
std::string res = n->m_value;
handle_constant_string_attributes(res, args, at->m_attr, loc);
return;
} else {
throw SemanticError("Only Name type and constant integers supported in Call", loc);
}
}

ASR::symbol_t* get_struct_member(ASR::symbol_t* struct_type_sym, std::string &call_name, const Location &loc) {
ASR::StructType_t* struct_type = ASR::down_cast<ASR::StructType_t>(struct_type_sym);
std::string struct_var_name = struct_type->m_name;
Expand Down Expand Up @@ -7083,92 +7137,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if (AST::is_a<AST::Attribute_t>(*x.m_func)) {
parse_args(x, args);
AST::Attribute_t *at = AST::down_cast<AST::Attribute_t>(x.m_func);
if (AST::is_a<AST::Name_t>(*at->m_value)) {
AST::Name_t *n = AST::down_cast<AST::Name_t>(at->m_value);
std::string mod_name = n->m_id;
call_name = at->m_attr;
std::string call_name_store = "__" + mod_name + "_" + call_name;
ASR::symbol_t *st = nullptr;
if (current_scope->resolve_symbol(call_name_store) != nullptr) {
st = current_scope->get_symbol(call_name_store);
} else {
st = current_scope->resolve_symbol(mod_name);
if (!st) {
throw SemanticError("NameError: '" + mod_name + "' is not defined", n->base.base.loc);
}
if( ASR::is_a<ASR::Module_t>(*st) ) {
ASR::Module_t *m = ASR::down_cast<ASR::Module_t>(st);
call_name_store = ASRUtils::get_mangled_name(m, call_name_store);
st = import_from_module(al, m, current_scope, mod_name,
call_name, call_name_store, x.base.base.loc);
current_scope->add_symbol(call_name_store, st);
} else if( ASR::is_a<ASR::StructType_t>(*st) ) {
st = get_struct_member(st, call_name, x.base.base.loc);
} else if ( ASR::is_a<ASR::Variable_t>(*st)) {
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(st);
if (ASR::is_a<ASR::Struct_t>(*var->m_type)) {
// call to struct member function
ASR::Struct_t* var_struct = ASR::down_cast<ASR::Struct_t>(var->m_type);
st = get_struct_member(var_struct->m_derived_type, call_name, x.base.base.loc);
} else {
// this case when we have variable and attribute
st = current_scope->resolve_symbol(mod_name);
Vec<ASR::expr_t*> eles;
eles.reserve(al, x.n_args);
for (size_t i=0; i<x.n_args; i++) {
eles.push_back(al, args[i].m_value);
}
ASR::expr_t *se = ASR::down_cast<ASR::expr_t>(
ASR::make_Var_t(al, x.base.base.loc, st));
if (ASR::is_a<ASR::Character_t>(*(ASRUtils::expr_type(se)))) {
handle_string_attributes(se, args, at->m_attr, x.base.base.loc);
return;
}
handle_builtin_attribute(se, at->m_attr, x.base.base.loc, eles);
return;
}
}
}
tmp = make_call_helper(al, st, current_scope, args, call_name, x.base.base.loc);
return;
} else if (AST::is_a<AST::UnaryOp_t>(*at->m_value)) {
AST::UnaryOp_t* uop = AST::down_cast<AST::UnaryOp_t>(at->m_value);
visit_UnaryOp(*uop);
Vec<ASR::expr_t*> eles;
eles.reserve(al, x.n_args);
for (size_t i=0; i<x.n_args; i++) {
eles.push_back(al, args[i].m_value);
}
ASR::expr_t* expr = ASR::down_cast<ASR::expr_t>(tmp);
handle_builtin_attribute(expr, at->m_attr, x.base.base.loc, eles);
return;
} else if (AST::is_a<AST::ConstantInt_t>(*at->m_value)) {
if (std::string(at->m_attr) == std::string("bit_length")) {
//bit_length() attribute:
if(args.size() != 0) {
throw SemanticError("int.bit_length() takes no arguments",
x.base.base.loc);
}
AST::ConstantInt_t *n = AST::down_cast<AST::ConstantInt_t>(at->m_value);
int64_t int_val = std::abs(n->m_value);
int32_t res = 0;
for(; int_val; int_val >>= 1, res++);
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4));
tmp = ASR::make_IntegerConstant_t(al, x.base.base.loc, res, int_type);
return;
} else {
throw SemanticError("'int' object has no attribute '" + std::string(at->m_attr) + "'",
x.base.base.loc);
}
} else if (AST::is_a<AST::ConstantStr_t>(*at->m_value)) {
AST::ConstantStr_t *n = AST::down_cast<AST::ConstantStr_t>(at->m_value);
std::string res = n->m_value;
handle_constant_string_attributes(res, args, at->m_attr, x.base.base.loc);
return;
} else {
throw SemanticError("Only Name type and constant integers supported in Call",
x.base.base.loc);
}
handle_attribute(at, args, x.base.base.loc);
return;
} else if( call_name == "" ) {
throw SemanticError("Only Name or Attribute type supported in Call",
x.base.base.loc);
Expand Down

0 comments on commit 9e0caa2

Please sign in to comment.