Skip to content

Commit

Permalink
Fix NEURON mechanism registration function for non POINT_PROCESSes (B…
Browse files Browse the repository at this point in the history
…lueBrain#1111)

* Added empty definitions for `nrn_{alloc,init,state,cur}` functions
* Makes generated code compilable and loadable for simple MOD files
  • Loading branch information
iomaganaris authored Dec 5, 2023
1 parent 474c3fe commit 7265ec9
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/codegen/codegen_naming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ static constexpr char NRN_STATE_METHOD[] = "nrn_state";
/// nrn_cur method in generated code
static constexpr char NRN_CUR_METHOD[] = "nrn_cur";

/// nrn_jacob method in generated code
static constexpr char NRN_JACOB_METHOD[] = "nrn_jacob";

/// nrn_watch_check method in generated c++ file
static constexpr char NRN_WATCH_CHECK_METHOD[] = "nrn_watch_check";

Expand All @@ -164,6 +167,9 @@ static constexpr char THREAD_ARGS_PROTO[] = "_threadargsproto_";
/// prefix for ion variable
static constexpr char ION_VARNAME_PREFIX[] = "ion_";

/// hoc_nrnpointerindex name
static constexpr char NRN_POINTERINDEX[] = "hoc_nrnpointerindex";


/// commonly used variables in verbatim block and how they
/// should be mapped to new code generation backends
Expand Down
86 changes: 79 additions & 7 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,19 @@ void CodegenNeuronCppVisitor::print_sdlists_init(bool print_initializers) {

void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_initializers) {
/// TODO: Print only global variables printed in NEURON
printer->add_line();
printer->add_newline(2);
printer->add_line("/* NEURON global variables */");
if (info.primes_size != 0) {
printer->fmt_line("static neuron::container::field_index _slist1[{0}], _dlist1[{0}];",
info.primes_size);
}
printer->add_line("static int mech_type;");

printer->fmt_line("static int {} = {};",
naming::NRN_POINTERINDEX,
info.pointer_variables.size() > 0
? static_cast<int>(info.pointer_variables.size())
: -1);
}


Expand Down Expand Up @@ -388,11 +395,29 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {
/// TODO: Write this according to NEURON
printer->add_newline(2);
printer->add_line("/** register channel with the simulator */");
printer->fmt_push_block("void _{}_reg()", info.mod_file);
printer->fmt_push_block("extern \"C\" void _{}_reg()", info.mod_file);
print_sdlists_init(true);
printer->add_newline();

const auto compute_functions_parameters =
breakpoint_exist()
? fmt::format("{}, {}, {}",
nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "nullptr",
method_name(naming::NRN_JACOB_METHOD),
nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "nullptr")
: "nullptr, nullptr, nullptr";
const auto register_mech_args = fmt::format("{}, {}, {}, {}, {}, {}",
get_channel_info_var_name(),
method_name(naming::NRN_ALLOC_METHOD),
compute_functions_parameters,
method_name(naming::NRN_INIT_METHOD),
naming::NRN_POINTERINDEX,
1 + info.thread_data_index);
printer->fmt_line("register_mech({});", register_mech_args);

// type related information
printer->add_newline();
printer->fmt_line("int mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name());
printer->fmt_line("mech_type = nrn_get_mechtype({}[1]);", get_channel_info_var_name());

// More things to add here
printer->add_line("_nrn_mechanism_register_data_fields(mech_type,");
Expand Down Expand Up @@ -425,6 +450,7 @@ void CodegenNeuronCppVisitor::print_mechanism_register() {


void CodegenNeuronCppVisitor::print_mechanism_range_var_structure(bool print_initializers) {
printer->add_newline(2);
printer->add_line("/* NEURON RANGE variables macro definitions */");
for (auto i = 0; i < codegen_float_variables.size(); ++i) {
const auto float_var = codegen_float_variables[i];
Expand Down Expand Up @@ -454,6 +480,34 @@ void CodegenNeuronCppVisitor::print_global_function_common_code(BlockType type,
}


void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
codegen = true;
printer->add_newline(2);
printer->add_line("/** initialize channel */");

printer->fmt_line(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_INIT_METHOD));

codegen = false;
}


void CodegenNeuronCppVisitor::print_nrn_jacob() {
codegen = true;
printer->add_newline(2);
printer->add_line("/** nrn_jacob function */");

printer->fmt_line(
"static void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* "
"_nt, Memb_list* _ml_arg, int _type) {{}}",
method_name(naming::NRN_JACOB_METHOD));

codegen = false;
}


/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_nrn_constructor() {
return;
Expand All @@ -468,7 +522,11 @@ void CodegenNeuronCppVisitor::print_nrn_destructor() {

/// TODO: Print the equivalent of `nrn_alloc_<mech_name>`
void CodegenNeuronCppVisitor::print_nrn_alloc() {
return;
printer->add_newline(2);
auto method = method_name(naming::NRN_ALLOC_METHOD);
printer->fmt_push_block("static void {}(Prop* _prop)", method);
printer->add_line("// do nothing");
printer->pop_block();
}


Expand All @@ -483,8 +541,13 @@ void CodegenNeuronCppVisitor::print_nrn_state() {
return;
}
codegen = true;
printer->add_newline(2);

printer->fmt_line(
"void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* "
"_ml_arg, int _type) {{}}",
method_name(naming::NRN_STATE_METHOD));

printer->add_line("void nrn_state() {}");
/// TODO: Fill in

codegen = false;
Expand Down Expand Up @@ -533,16 +596,21 @@ void CodegenNeuronCppVisitor::print_nrn_cur() {
}

codegen = true;
printer->add_newline(2);

printer->fmt_line(
"void {}(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, "
"int _type) {{}}",
method_name(naming::NRN_CUR_METHOD));

printer->add_line("void nrn_cur() {}");
/// TODO: Fill in

codegen = false;
}


/****************************************************************************************/
/* Main code printing entry points */
/* Main code printing entry points */
/****************************************************************************************/

void CodegenNeuronCppVisitor::print_headers_include() {
Expand Down Expand Up @@ -590,6 +658,7 @@ void CodegenNeuronCppVisitor::print_mechanism_variables_macros() {
using _nrn_model_sorted_token = neuron::model_sorted_token;
using _nrn_mechanism_cache_range = neuron::cache::MechanismRange<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container;
template <typename T>
using _nrn_mechanism_field = neuron::mechanism::field<T>;
template <typename... Args>
Expand Down Expand Up @@ -641,8 +710,10 @@ void CodegenNeuronCppVisitor::print_g_unused() const {

/// TODO: Edit for NEURON
void CodegenNeuronCppVisitor::print_compute_functions() {
print_nrn_init();
print_nrn_cur();
print_nrn_state();
print_nrn_jacob();
}


Expand All @@ -657,6 +728,7 @@ void CodegenNeuronCppVisitor::print_codegen_routines() {
print_prcellstate_macros();
print_mechanism_info();
print_data_structures(true);
print_nrn_alloc();
print_global_variables_for_hoc();
print_compute_functions(); // only nrn_cur and nrn_state
print_mechanism_register();
Expand Down
14 changes: 14 additions & 0 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
const std::string& function_name = "") override;


/**
* Print the \c nrn\_init function definition
* \param skip_init_check \c true to generate code executing the initialization conditionally
*/
void print_nrn_init(bool skip_init_check = true);


/**
* Print nrn_constructor function definition
*
Expand All @@ -382,6 +389,13 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_nrn_alloc() override;


/**
* Print nrn_jacob function definition
*
*/
void print_nrn_jacob();


/****************************************************************************************/
/* Print nrn_state routine */
/****************************************************************************************/
Expand Down
19 changes: 12 additions & 7 deletions test/unit/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ std::string get_neuron_cpp_code(const std::string& nmodl_text,
std::stringstream ss;
auto cvisitor = create_neuron_cpp_visitor(ast, nmodl_text, ss);
cvisitor->visit_program(*ast);
return reindent_text(ss.str());
return ss.str();
}


Expand Down Expand Up @@ -138,6 +138,7 @@ using _nrn_mechanism_std_vector = std::vector<T>;
using _nrn_model_sorted_token = neuron::model_sorted_token;
using _nrn_mechanism_cache_range = neuron::cache::MechanismRange<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_mechanism_cache_instance = neuron::cache::MechanismInstance<number_of_floating_point_variables, number_of_datum_variables>;
using _nrn_non_owning_id_without_container = neuron::container::non_owning_identifier_without_container;
template <typename T>
using _nrn_mechanism_field = neuron::mechanism::field<T>;
template <typename... Args>
Expand Down Expand Up @@ -213,26 +214,30 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
ContainsSubstring(reindent_and_trim_text(expected_hoc_global_variables)));
}
THEN("Placeholder nrn_cur function is printed") {
std::string expected_placeholder_nrn_cur = R"(void nrn_cur() {})";
std::string expected_placeholder_nrn_cur =
R"(void nrn_cur_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_cur)));
}
THEN("Placeholder nrn_state function is printed") {
std::string expected_placeholder_nrn_state = R"(void nrn_state() {})";
std::string expected_placeholder_nrn_state =
R"(void nrn_state_pas_test(_nrn_model_sorted_token const& _sorted_token, NrnThread* _nt, Memb_list* _ml_arg, int _type) {})";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_nrn_state)));
}
THEN("Placeholder registration function is printed") {
std::string expected_placeholder_reg = R"(/** register channel with the simulator */
void __test_reg() {
std::string expected_placeholder_reg = R"CODE(/** register channel with the simulator */
extern "C" void __test_reg() {
/* s */
_slist1[0] = {4, 0};
/* Ds */
_dlist1[0] = {7, 0};
int mech_type = nrn_get_mechtype(mechanism_info[1]);
register_mech(mechanism_info, nrn_alloc_pas_test, nrn_cur_pas_test, nrn_jacob_pas_test, nrn_state_pas_test, nrn_init_pas_test, hoc_nrnpointerindex, 1);
mech_type = nrn_get_mechtype(mechanism_info[1]);
_nrn_mechanism_register_data_fields(mech_type,
_nrn_mechanism_field<double>{"g"} /* 0 */,
_nrn_mechanism_field<double>{"e"} /* 1 */,
Expand All @@ -246,7 +251,7 @@ void _nrn_mechanism_register_data_fields(Args&&... args) {
_nrn_mechanism_field<double>{"g_unused"} /* 9 */
);
})";
})CODE";

REQUIRE_THAT(generated,
ContainsSubstring(reindent_and_trim_text(expected_placeholder_reg)));
Expand Down

0 comments on commit 7265ec9

Please sign in to comment.