-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathShapeScalarJob.h
73 lines (49 loc) · 1.98 KB
/
ShapeScalarJob.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#pragma once
#include <atomic>

#include <Data/Session.h>
#include <Job/Job.h>
#include <ParticleShapeStatistics.h>
#include <QPixmap>
namespace shapeworks {

class Project;

/**
 * @brief Job relating particle-based shape data to a named scalar feature.
 *
 * Configured via the constructor (session, target feature, target particles,
 * and a JobType) plus the setters below; the work happens in run().
 * Results are exposed through get_plot() and get_prediction().
 */
class ShapeScalarJob : public Job {
  Q_OBJECT

 public:
  //! Task selected at construction and stored in job_type_.
  //! NOTE(review): presumably run() dispatches on this value — confirm in the .cpp.
  enum class JobType { Find_Components, MSE_Plot, Predict };

  //! Prediction direction; a new job defaults to Direction::To_Scalar.
  enum class Direction { To_Shape, To_Scalar };

  /**
   * @param session          Session the job reads its data from.
   * @param target_feature   Name of the scalar feature to relate to shape.
   * @param target_particles Particle data to operate on (layout not visible here — TODO confirm).
   * @param job_type         Which task this job performs.
   */
  ShapeScalarJob(QSharedPointer<Session> session, QString target_feature, Eigen::MatrixXd target_particles,
                 JobType job_type);

  //! Executes the configured task (overrides Job::run).
  void run() override;

  //! Human-readable name of this job (overrides Job::name).
  QString name() override;

  //! Returns the plot produced by the job.
  //! NOTE(review): presumably only populated for JobType::MSE_Plot — confirm.
  QPixmap get_plot();

  //! Sets the number of components to use (default 3).
  void set_number_of_components(int num_components) { num_components_ = num_components; }

  //! Sets the number of folds (default 5); presumably cross-validation folds — confirm.
  void set_number_of_folds(int num_folds) { num_folds_ = num_folds; }

  //! Sets the upper bound on the number of components (default 20).
  void set_max_number_of_components(int num) { max_components_ = num; }

  //! Returns a copy of the prediction stored by the job.
  Eigen::VectorXd get_prediction() const { return prediction_; }

  //! Static entry point for predicting scalars from particles.
  //! NOTE(review): presumably forwards to predict() with Direction::To_Scalar — confirm.
  static Eigen::VectorXd predict_scalars(QSharedPointer<Session> session, QString target_feature,
                                         Eigen::MatrixXd target_particles);

  //! Static entry point for predicting shape.
  //! NOTE(review): presumably forwards to predict() with Direction::To_Shape — confirm.
  static Eigen::VectorXd predict_shape(QSharedPointer<Session> session, QString target_feature,
                                       Eigen::MatrixXd target_particles);

  //! Sets the class-wide needs_clear_ flag. The flag is a std::atomic<bool>,
  //! so setting it does not race with concurrent reads of the flag itself.
  static void clear_model() { needs_clear_ = true; }

  //! Sets the prediction direction (default Direction::To_Scalar).
  void set_direction(Direction direction) { direction_ = direction; }

 private:
  void prep_data();
  void run_fit();
  void run_prediction();

  //! Shared implementation behind the static predict_* entry points.
  static Eigen::VectorXd predict(QSharedPointer<Session> session, QString target_feature,
                                 Eigen::MatrixXd target_particles, Direction direction);

  QSharedPointer<Session> session_;
  ParticleShapeStatistics stats_;
  QString target_feature_;
  QPixmap plot_;
  Eigen::MatrixXd all_particles_;
  Eigen::MatrixXd all_scalars_;
  Eigen::MatrixXd target_values_;
  Eigen::VectorXd prediction_;

  // Must be int: it holds a component count. As a bool, every value passed to
  // set_number_of_components() (and the default of 3) collapsed to 0/1.
  int num_components_ = 3;
  int num_folds_ = 5;
  int max_components_ = 20;
  Direction direction_{Direction::To_Scalar};
  JobType job_type_;

  // Class-wide flag set by clear_model(); defined out of line.
  static std::atomic<bool> needs_clear_;
};

}  // namespace shapeworks