// colvar_arithmeticpath.h -- from a fork of Colvars/colvars (137 lines, 127 loc, 5.53 KB)
#ifndef ARITHMETICPATHCV_H
#define ARITHMETICPATHCV_H
#include "colvarmodule.h"
#include <vector>
#include <cmath>
#include <limits>
#include <string>
#include <algorithm>
namespace ArithmeticPathCV {
using std::vector;
/// Arithmetic path collective variables over a sequence of reference frames.
///
/// For each frame i the class evaluates
///     e_i = exp(-lambda * sum_j w_j^2 * d_ij^2)
/// from the caller-supplied per-frame, per-element distances d_ij, and derives
///     s = (1 / (N - 1)) * (sum_i i * e_i) / (sum_i e_i)
///     z = -(1 / lambda) * log(sum_i e_i)
/// where N is the number of frames and i is the zero-based frame index.
///
/// Typical call order: initialize(), optionally reComputeLambda(), then
/// computeValue() followed by computeDerivatives() for the same distances.
///
/// All data members carry default initializers so that a default-constructed
/// object holds well-defined (zero / empty) state until initialize() is called.
///
/// \tparam scalar_type floating-point type used for lambda, weights, s and z.
template <typename scalar_type>
class ArithmeticPathBase {
public:
    ArithmeticPathBase() = default;
    // NOTE(review): the destructor is non-virtual although the class exposes
    // protected state for derived classes; do not delete derived objects
    // through a pointer to this base.
    ~ArithmeticPathBase() = default;

    /// Set the problem size, lambda and the per-element weights (stored squared),
    /// and reset the values cached by computeValue().
    void initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda, const vector<scalar_type>& p_weights);

    /// Replace lambda by the inverse of the mean squared distance between
    /// consecutive reference frames.
    void reComputeLambda(const vector<scalar_type>& rmsd_between_refs);

    /// Compute s and z from frame_element_distances[frame][element].
    /// Either output pointer may be null. The exponentials, their sum and s
    /// are cached for a subsequent computeDerivatives() call.
    template <typename element_type>
    void computeValue(const vector<vector<element_type>>& frame_element_distances, scalar_type *s = nullptr, scalar_type *z = nullptr);

    /// Compute element-wise derivatives of s and z and store the derivative for
    /// frame i, element j in (*dsdx)[i][j] and (*dzdx)[i][j].
    /// Must only be called after computeValue(), because it reuses the cached
    /// exponentials, their sum and s. The output containers are indexed but
    /// never resized here, so the caller must size them beforehand.
    /// Either output pointer may be null.
    template <typename element_type>
    void computeDerivatives(const vector<vector<element_type>>& frame_element_distances, vector<vector<element_type>> *dsdx = nullptr, vector<vector<element_type>> *dzdx = nullptr);

protected:
    /// Coefficient multiplying the weighted squared distance in the exponent.
    scalar_type lambda = scalar_type();
    /// Squares of the per-element weights passed to initialize().
    vector<scalar_type> squared_weights;
    /// Number of elements per frame.
    size_t num_elements = 0;
    /// Number of reference frames.
    size_t total_frames = 0;
    /// Per-frame exponent, overwritten by exp(exponent - max_exponent) in computeValue().
    vector<scalar_type> exponents;
    /// Largest per-frame exponent found by the last computeValue() call.
    scalar_type max_exponent = scalar_type();
    /// Sum of the shifted exponentials from the last computeValue() call.
    scalar_type saved_exponent_sum = scalar_type();
    /// 1 / (total_frames - 1), set by initialize().
    scalar_type normalization_factor = scalar_type();
    /// Value of s from the last computeValue() call.
    scalar_type saved_s = scalar_type();
};
/// (Re)initialize the path variable.
///
/// \param p_num_elements  number of elements per frame
/// \param p_total_frames  number of reference frames
/// \param p_lambda        coefficient used in the exponent of computeValue()
/// \param p_weights       per-element weights; their squares are stored
///
/// The cached s, exponent sum and maximum exponent are reset to zero.
template <typename scalar_type>
void ArithmeticPathBase<scalar_type>::initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda, const vector<scalar_type>& p_weights) {
    lambda = p_lambda;
    // Rebuild the squared weights from scratch: appending to the existing
    // vector would keep stale weights at the front when initialize() is
    // called more than once, and those are the entries computeValue() reads.
    squared_weights.clear();
    squared_weights.reserve(p_weights.size());
    for (size_t i = 0; i < p_weights.size(); ++i) {
        squared_weights.push_back(p_weights[i] * p_weights[i]);
    }
    num_elements = p_num_elements;
    total_frames = p_total_frames;
    exponents.resize(total_frames);
    // NOTE: the divisor is zero for a single frame and wraps around for zero
    // frames (unsigned arithmetic); callers are expected to pass >= 2 frames.
    normalization_factor = 1.0 / static_cast<scalar_type>(total_frames - 1);
    saved_s = scalar_type();
    saved_exponent_sum = scalar_type();
    max_exponent = scalar_type();
}
/// Evaluate s and z for the given distances.
///
/// \param frame_element_distances  distances indexed as [frame][element]; the
///                                 first total_frames rows and num_elements
///                                 columns are read
/// \param s  if non-null, receives s
/// \param z  if non-null, receives z
///
/// Side effects: overwrites exponents, max_exponent, saved_exponent_sum and
/// saved_s, which computeDerivatives() relies on.
template <typename scalar_type>
template <typename element_type>
void ArithmeticPathBase<scalar_type>::computeValue(
    const vector<vector<element_type>>& frame_element_distances,
    scalar_type *s, scalar_type *z)
{
    // Pass 1: per-frame exponent -lambda * sum_j w_j^2 * d_j^2, tracking the
    // largest one so the exponentials below can be shifted by it.
    for (size_t frame = 0; frame < total_frames; ++frame) {
        scalar_type weighted_sq_dist = scalar_type();
        for (size_t elem = 0; elem < num_elements; ++elem) {
            weighted_sq_dist += squared_weights[elem] * frame_element_distances[frame][elem] * frame_element_distances[frame][elem];
        }
        exponents[frame] = weighted_sq_dist * -1.0 * lambda;
        const bool is_new_maximum = (frame == 0) || (exponents[frame] > max_exponent);
        if (is_new_maximum) {
            max_exponent = exponents[frame];
        }
    }

    // Pass 2: shifted exponentials and their plain / index-weighted sums.
    scalar_type sum_of_exp = scalar_type();
    scalar_type index_weighted_sum = scalar_type();
    for (size_t frame = 0; frame < total_frames; ++frame) {
        exponents[frame] = cvm::exp(exponents[frame] - max_exponent);
        sum_of_exp += exponents[frame];
        index_weighted_sum += frame * exponents[frame];
    }
    saved_exponent_sum = sum_of_exp;

    // Undo the shift in log space; the shift cancels in the ratio used for s.
    const scalar_type log_sum_of_exp = max_exponent + cvm::logn(sum_of_exp);
    const scalar_type log_index_weighted_sum = max_exponent + cvm::logn(index_weighted_sum);
    saved_s = normalization_factor * cvm::exp(log_index_weighted_sum - log_sum_of_exp);

    if (s) {
        *s = saved_s;
    }
    if (z) {
        *z = -1.0 / lambda * log_sum_of_exp;
    }
}
/// Recompute lambda as the inverse of the mean squared distance between
/// consecutive reference frames.
///
/// \param rmsd_between_refs  rmsd_between_refs[k] is the distance between
///                           frame k+1 and frame k+2 (one-based, as logged);
///                           at least total_frames - 1 entries are required
///
/// Each distance is written to the log. If fewer than two frames are
/// configured, or too few distances are supplied, a warning is logged and
/// lambda is left unchanged instead of dividing by zero or reading past the
/// end of the vector.
template <typename scalar_type>
void ArithmeticPathBase<scalar_type>::reComputeLambda(const vector<scalar_type>& rmsd_between_refs) {
    if (total_frames < 2 || rmsd_between_refs.size() < total_frames - 1) {
        cvm::log(std::string("Warning: cannot recompute lambda from ") + cvm::to_str(rmsd_between_refs.size()) + " distance(s) between " + cvm::to_str(total_frames) + " frame(s); lambda is left unchanged.\n");
        return;
    }
    scalar_type mean_square_displacements = 0.0;
    for (size_t i_frame = 1; i_frame < total_frames; ++i_frame) {
        cvm::log(std::string("Distance between frame ") + cvm::to_str(i_frame) + " and " + cvm::to_str(i_frame + 1) + " is " + cvm::to_str(rmsd_between_refs[i_frame - 1]) + std::string("\n"));
        mean_square_displacements += rmsd_between_refs[i_frame - 1] * rmsd_between_refs[i_frame - 1];
    }
    mean_square_displacements /= scalar_type(total_frames - 1);
    // NOTE: if all consecutive distances are zero this still divides by zero.
    lambda = 1.0 / mean_square_displacements;
}
/// Frame-wise, element-wise derivatives of s and z with respect to the
/// distances passed in (the original note describes this as being for frames
/// using optimal rotation).
///
/// \param frame_element_distances  distances indexed as [frame][element], the
///                                 same values given to computeValue()
/// \param dsdx  if non-null, (*dsdx)[frame][element] receives ds/dx
/// \param dzdx  if non-null, (*dzdx)[frame][element] receives dz/dx
///
/// Relies on exponents, saved_exponent_sum and saved_s as left behind by the
/// preceding computeValue() call. The output containers are indexed, never
/// resized, so they must already hold total_frames rows of num_elements each.
template <typename scalar_type>
template <typename element_type>
void ArithmeticPathBase<scalar_type>::computeDerivatives(
    const vector<vector<element_type>>& frame_element_distances,
    vector<vector<element_type>> *dsdx,
    vector<vector<element_type>> *dzdx)
{
    // Per-frame scalar factors, shared by every element of a frame:
    //   frame_share[f]  = e_f / sum_i e_i
    //   frame_offset[f] = (f - (N - 1) * s) / (N - 1)
    vector<scalar_type> frame_share(total_frames);
    vector<scalar_type> frame_offset(total_frames);
    for (size_t frame = 0; frame < total_frames; ++frame) {
        frame_share[frame] = exponents[frame] / saved_exponent_sum;
        frame_offset[frame] =
            (static_cast<scalar_type>(frame) -
             static_cast<scalar_type>(total_frames - 1) * saved_s) *
            normalization_factor;
    }

    if (dsdx) {
        vector<vector<element_type>>& ds_out = *dsdx;
        for (size_t frame = 0; frame < total_frames; ++frame) {
            for (size_t elem = 0; elem < num_elements; ++elem) {
                ds_out[frame][elem] =
                    -2.0 * squared_weights[elem] * lambda *
                    frame_element_distances[frame][elem] *
                    frame_share[frame] * frame_offset[frame];
            }
        }
    }

    if (dzdx) {
        vector<vector<element_type>>& dz_out = *dzdx;
        for (size_t frame = 0; frame < total_frames; ++frame) {
            for (size_t elem = 0; elem < num_elements; ++elem) {
                dz_out[frame][elem] =
                    2.0 * squared_weights[elem] * frame_share[frame] *
                    frame_element_distances[frame][elem];
            }
        }
    }
}
}
#endif // ARITHMETICPATHCV_H