Skip to content

Commit

Permalink
QOL update to the PerResidueEsmProbabilitiesMetric (RosettaCommons#26)
Browse files Browse the repository at this point in the history
This short PR:

*    Adds the proper citation for ESM
*    Adds some checks for NaNs/Infs to the Perplexity/softmax calculation
*    Fixes a typo in MIFST.cc
*    Prevents an error in SampleSequenceFromProbabilities when calling the SimpleThreadingMover without mutations
*    Fixes const behavior in SampleSequenceFromProbabilities

---------

Co-authored-by: moritzertelt <[email protected]>
  • Loading branch information
MoritzErtelt and moritzertelt authored Apr 16, 2024
1 parent d972b59 commit bc4dfa1
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 34 deletions.
40 changes: 40 additions & 0 deletions database/citations/rosetta_citations.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1408,3 +1408,43 @@
10.1101/2022.05.25.493516
[END_DOI]
[END_CITATION]

[BEGIN_CITATION]
[BEGIN_PRIMARY_AUTHORS]
"" "Lin" "Z"
[END_PRIMARY_AUTHORS]
[BEGIN_COAUTHORS]
"" "Akin" "H"
"" "Rao" "R"
"" "Hie" "B"
"" "Zhu" "Z"
"" "Lu" "W"
"" "Smetanin" "N"
"" "Verkuil" "R"
"" "Kabeli" "O"
"" "Shmueli" "Y"
"" "DosSantosCosta" "A"
"" "Fazel-Zarandi" "M"
"" "Sercu" "T"
"" "Candido" "S"
[END_COAUTHORS]
[BEGIN_SENIOR_AUTHORS]
"" "Rives" "A"
[END_SENIOR_AUTHORS]
[BEGIN_YEAR]
2023
[END_YEAR]
[BEGIN_TITLE]
Evolutionary-scale prediction of atomic-level protein structure with a language model
[END_TITLE]
[BEGIN_JOURNAL]
Science
[END_JOURNAL]
[BEGIN_VOLUME_ISSUE_PAGES]
379(6637):1123-1130
[END_VOLUME_ISSUE_PAGES]
[BEGIN_DOI]
10.1126/science.ade2574
[END_DOI]
[END_CITATION]

Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <utility/file/file_sys_util.hh>

// basic headers
#include <basic/citation_manager/CitationCollection.hh>
#include <basic/citation_manager/CitationManager.hh>
#include <basic/Tracer.hh>
#include <basic/database/open.hh>
#include <basic/execute.hh>
Expand Down Expand Up @@ -373,25 +375,26 @@ EsmPerplexityTensorflowProtocol::softmax(
for ( const auto& pair : logit_map ) {
core::Size selected_residue = pair.first;
utility::vector1< core::Real > logit_vec = pair.second;
// get maximum
core::Real max_val = std::numeric_limits<core::Real>::lowest();
for ( const auto& logit : logit_vec ) {
if ( logit > max_val ) {
max_val = logit;
}

core::Real max_val = *std::max_element(logit_vec.begin(), logit_vec.end());

for ( auto& logit : logit_vec ) {
logit -= max_val;
}
// calc sum of exponential values

core::Real sum_exp = 0.0;
for ( const auto& logit : logit_vec ) {
sum_exp += std::exp(logit - max_val);
for ( const auto& scaled_logit : logit_vec ) {
sum_exp += std::exp(scaled_logit);
}

utility::vector1< core::Real > softmax_vec;
// calc softmax of each value
for ( const auto& logit : logit_vec ) {
softmax_vec.push_back(std::exp(logit - max_val) / sum_exp);
for ( const auto& scaled_logit : logit_vec ) {
softmax_vec.push_back(std::exp(scaled_logit) / sum_exp);
}
softmax_map[ selected_residue ] = softmax_vec;

softmax_map[selected_residue] = softmax_vec;
}

}

/// @brief Downloads model from GitLab if the specified path does not exist or is missing crucial files
Expand Down Expand Up @@ -456,5 +459,20 @@ EsmPerplexityTensorflowProtocol::download_model_if_not_existing( std::string con
}
}

/// @brief Get the citation for ESM
/// @details Creates a CitationCollection labelled "ESM" with module type
/// NeuralNetwork, and adds to it the citation that the CitationManager
/// returns for DOI 10.1126/science.ade2574.
/// @return The populated collection, as a const owning pointer to its base class.
/*static*/
basic::citation_manager::CitationCollectionBaseCOP
EsmPerplexityTensorflowProtocol::get_ESM_neural_net_citation() {
	namespace cm = basic::citation_manager;

	cm::CitationCollectionOP esm_collection(
		utility::pointer::make_shared< cm::CitationCollection >( "ESM", cm::CitedModuleType::NeuralNetwork )
	);

	// Look the paper up by DOI in the citation manager and attach it to the collection.
	esm_collection->add_citation(
		cm::CitationManager::get_instance()->get_citation_by_doi( "10.1126/science.ade2574" )
	);

	return esm_collection;
}


} // esm_perplexity
} // protocols
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <core/sequence/SequenceProfile.hh>

// Basic headers
#include <basic/citation_manager/CitationCollectionBase.fwd.hh>
#include <basic/tensorflow_manager/RosettaTensorflowTensorContainer.hh>
#include <basic/tensorflow_manager/RosettaTensorflowProtocolBase.hh>
#include <basic/tensorflow_manager/RosettaTensorflowSessionContainer.hh>
Expand Down Expand Up @@ -85,6 +86,12 @@ public:

/// @brief Clone operation: make a copy of this object, and return an owning pointer to the copy.
basic::tensorflow_manager::RosettaTensorflowProtocolBaseOP clone() const override;

/// @brief Get the citation for ESM.
/// @details Declared static, so no protocol instance is needed to call it.
static
basic::citation_manager::CitationCollectionBaseCOP
get_ESM_neural_net_citation();

/// @brief The tensorflow session
#ifdef USE_TENSORFLOW
basic::tensorflow_manager::RosettaTensorflowSessionContainerCOP tensorflow_session_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,13 @@ PerResidueEsmProbabilitiesMetric::fill_return_map( std::map< core::Size, utility
prob_index < EsmPerplexityTensorflowProtocol::alphabet_.size(); ++prob_index ) {
char const curr_aa = EsmPerplexityTensorflowProtocol::alphabet_[prob_index];
core::chemical::AA const aa_enum = core::chemical::aa_from_oneletter_code(curr_aa);
return_map[selected_position][aa_enum] = softmax_vec[prob_index + 1];
core::Real probability = softmax_vec[prob_index + 1];
// check for NaN/inf to avoid problems later
if ( std::isnan(probability) || std::isinf(probability) ) {
probability = 0.0;
}

return_map[selected_position][aa_enum] = probability;
}
}
}
Expand Down Expand Up @@ -249,6 +255,7 @@ PerResidueEsmProbabilitiesMetric::provide_citation_info(basic::citation_manager:
"Wrote the PerResidueEsmProbabilitiesMetric."
)
);
citations.add( EsmPerplexityTensorflowProtocol::get_ESM_neural_net_citation() );
}
} //esm_perplexity
} //protocols
Expand Down
7 changes: 5 additions & 2 deletions source/src/protocols/esm_perplexity/PseudoPerplexityMetric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ PseudoPerplexityMetric::compute_perplexity(
}
core::Real aa_probability = position_pair.second.at( aa_type );
// if the probability is zero add a small constant to it to prevent -inf
if ( aa_probability == 0 ) {
aa_probability += std::numeric_limits< core::Real >::min();
if ( aa_probability == 0 || aa_probability < 0.00001 ) {
aa_probability += 0.00001;
}
if ( std::isnan(aa_probability) || std::isinf(aa_probability) ) {
aa_probability = 0.00001;
}
log_probabilities_sum += std::log( aa_probability );
}
Expand Down
2 changes: 1 addition & 1 deletion source/src/protocols/inverse_folding/MIFST.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ MIFST::set_defaults() {
/// @brief Setter for the auto-download flag.
/// @param[in] setting Value stored in auto_download_.
void
MIFST::auto_download( bool setting ) {
	auto_download_ = setting;
}

/// @brief Get the citation for ProteinMPNN
/// @brief Get the citation for MIF-ST
/*static*/
basic::citation_manager::CitationCollectionBaseCOP
MIFST::get_MIFST_neural_net_citation() {
Expand Down
1 change: 0 additions & 1 deletion source/src/protocols/inverse_folding/MIFST.hh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ public:
MIFST operator=( MIFST const & ) = delete;

/// @brief Get the citation for MIF-ST
/// @details TODO: fill in details for Yang et al.
static
basic::citation_manager::CitationCollectionBaseCOP
get_MIFST_neural_net_citation();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,30 @@ SampleSequenceFromProbabilities::apply( core::pose::Pose& pose ){
std::string modified_sequence = construct_modified_sequence(pose, mutations );
TR << " Done!" << std::endl;

// Call SimpleThreadingMover with sampled sequence
TR << "Threading sequence onto pose..." << std::endl;
SimpleThreadingMoverOP threader( utility::pointer::make_shared< SimpleThreadingMover >( modified_sequence, 1 ) );
threader->set_pack_neighbors( true );
// disable packing if specified by user
if ( !packing_ ) {
threader->set_pack_rounds(0);
// Check if the sequence is unmutated (only hyphens)
bool is_unmutated = true;
for ( std::size_t i = 0; i < modified_sequence.length(); i += 4 ) {
if ( i >= modified_sequence.length() || modified_sequence[i] != '-' ) {
is_unmutated = false;
break;
}
}
threader->set_sequence_mode( "threeletter" );
threader->apply( pose );
TR.Info << " Done!" << std::endl;

// Call SimpleThreadingMover with sampled sequence if mutated
if ( !is_unmutated ) {
TR << "Threading sequence onto pose..." << std::endl;
SimpleThreadingMoverOP threader(utility::pointer::make_shared<SimpleThreadingMover>(modified_sequence, 1));
threader->set_pack_neighbors(true);
// disable packing if specified by user
if ( !packing_ ) {
threader->set_pack_rounds(0);
}
threader->set_sequence_mode("threeletter");
threader->apply(pose);
TR.Info << " Done!" << std::endl;
} else {
TR.Warning << "No mutations match the thresholds set. Skipping SimpleThreadingMover." << std::endl;
}
}

void
Expand Down Expand Up @@ -312,7 +324,7 @@ SampleSequenceFromProbabilities::provide_citation_info(basic::citation_manager::

std::vector<core::Size>
SampleSequenceFromProbabilities::sample_positions(
std::map<core::Size, std::map<core::chemical::AA, core::Real>> const & probabilities,
std::map<core::Size, std::map<core::chemical::AA, core::Real>> & probabilities,
core::pose::Pose const & pose
) const {
using namespace core::chemical;
Expand Down Expand Up @@ -403,7 +415,7 @@ SampleSequenceFromProbabilities::is_aa_allowed_by_task( core::pack::task::Residu

utility::vector1<std::pair<core::Size, core::Real>>
SampleSequenceFromProbabilities::calculate_position_diffs(
std::map<core::Size, std::map<core::chemical::AA, core::Real>> const & probabilities,
std::map<core::Size, std::map<core::chemical::AA, core::Real>> & probabilities,
core::pose::Pose const& pose
) const {
using namespace core::chemical;
Expand All @@ -417,7 +429,7 @@ SampleSequenceFromProbabilities::calculate_position_diffs(
utility::vector1<std::pair<core::Size, core::Real>> position_diffs; // Contains pairs of position and their max_diff
for ( auto & pos_and_probs : probabilities ) {
core::Size position = pos_and_probs.first;
std::map<core::chemical::AA, core::Real> const & aa_probs = pos_and_probs.second;
std::map<core::chemical::AA, core::Real> & aa_probs = pos_and_probs.second;

// get the probability of the amino acid currently present in the pose at that position
AA current_aa = pose.residue(position).aa();
Expand All @@ -431,7 +443,7 @@ SampleSequenceFromProbabilities::calculate_position_diffs(
}

core::Real max_prob = 0.0;
for ( auto aa_and_prob : aa_probs ) {
for ( auto & aa_and_prob : aa_probs ) {
core::chemical::AA aa = aa_and_prob.first;
core::Real prob = aa_and_prob.second;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ private: // methods

///@brief Helper function to get the ranked positions based on maximum difference in probabilities.
std::vector<core::Size> sample_positions(
std::map<core::Size, std::map<core::chemical::AA, core::Real>> const & values,
std::map<core::Size, std::map<core::chemical::AA, core::Real>> & values,
core::pose::Pose const & pose
) const;

Expand All @@ -181,7 +181,7 @@ private: // methods

///@brief Helper function to calculate differences between current and other AAs, as well as disabling unwanted AAs
utility::vector1<std::pair<core::Size, core::Real>> calculate_position_diffs(
std::map<core::Size, std::map<core::chemical::AA, core::Real>> const & probabilities,
std::map<core::Size, std::map<core::chemical::AA, core::Real>> & probabilities,
core::pose::Pose const& pose
) const;

Expand Down

0 comments on commit bc4dfa1

Please sign in to comment.