Skip to content

Commit

Permalink
Yuhsuan/150 2d fitting (#1055)
Browse files Browse the repository at this point in the history
* add image fitting basic feature

* adjust protobuf; remove ImageFitter

* add image fitter using GSL

* reformat

* reformat

* add error handler

* minor change

* adjust SetResults, SetLog

* nan pixel handling

* calculate errors

* add struct FitStatus; check if image cache is filled before fitting

* fix bug in Gaussian(); modify image generator for testing

* remove protobuf message from image fitter

* add unit tests

* update protobuf

* fix _image_cache

* reformat

* parallelize Gaussian computation

* update protobuf

* handle undefined unit

* change to nullptr; avoid copying image data

* adjust unit test

* adjust cmake; adjust error estimate; handle special case of reaching max iter

* update unit test

* update protobuf

* update image data pointer to avoid wrong data after animation

* reformat

* improve FuncF algorithm

* update protobuf; adjust pa definition; add residual variance to log

* adjust response message; rename to _fit_values

* change to double to increase precision

* update changelog

* update protobuf

* adjust test/CMakeLists

* update protobuf
  • Loading branch information
YuHsuan-Hwang authored Apr 26, 2022
1 parent edc2561 commit a7f87d0
Show file tree
Hide file tree
Showing 14 changed files with 467 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
* Added data type to file info and open complex image with amplitude expression ([#520](https://github.com/CARTAvis/carta-backend/issues/520)).
* Added ability to set a custom rest frequency for saving subimages. ([#918](https://github.com/CARTAvis/carta-backend/issues/918)).
* Added image fitter for multiple 2D Gaussian component fitting ([#150](https://github.com/CARTAvis/carta-backend/issues/150)).

## [3.0.0-beta.2]

Expand Down
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ endif ()
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

FIND_PACKAGE(CURL REQUIRED)
FIND_PACKAGE(GSL REQUIRED)
FIND_PACKAGE(ZFP CONFIG REQUIRED)
FIND_PACKAGE(PkgConfig REQUIRED)

Expand Down Expand Up @@ -168,6 +169,7 @@ set(LINK_LIBS
pugixml
curl
wcs
gsl
casa_casa
casa_coordinates
casa_tables
Expand Down Expand Up @@ -236,7 +238,8 @@ set(SOURCE_FILES
src/Util/Image.cc
src/Util/Message.cc
src/Util/String.cc
src/Util/Token.cc)
src/Util/Token.cc
src/ImageFitter/ImageFitter.cc)

add_definitions(-DHAVE_HDF5)
add_executable(carta_backend ${SOURCE_FILES})
Expand Down
16 changes: 16 additions & 0 deletions src/Frame/Frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,22 @@ void Frame::StopMomentCalc() {
}
}

bool Frame::FitImage(const CARTA::FittingRequest& fitting_request, CARTA::FittingResponse& fitting_response) {
if (!_image_fitter) {
_image_fitter = std::make_unique<ImageFitter>(_width, _height);
}

bool success = false;
if (_image_fitter) {
FillImageCache();
std::vector<CARTA::GaussianComponent> initial_values(
fitting_request.initial_values().begin(), fitting_request.initial_values().end());
success = _image_fitter->FitImage(_image_cache.get(), initial_values, fitting_response);
}

return success;
}

// Export modified image to file, for changed range of channels/stokes and chopped region
// Input root_folder as target path
// Input save_file_msg as requesting parameters
Expand Down
8 changes: 8 additions & 0 deletions src/Frame/Frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <carta-protobuf/contour.pb.h>
#include <carta-protobuf/defs.pb.h>
#include <carta-protobuf/fitting_request.pb.h>
#include <carta-protobuf/raster_tile.pb.h>
#include <carta-protobuf/region_histogram.pb.h>
#include <carta-protobuf/region_requirements.pb.h>
Expand All @@ -33,6 +34,7 @@
#include "DataStream/Contouring.h"
#include "DataStream/Tile.h"
#include "ImageData/FileLoader.h"
#include "ImageFitter/ImageFitter.h"
#include "ImageGenerators/ImageGenerator.h"
#include "ImageGenerators/MomentGenerator.h"
#include "ImageStats/BasicStatsCalculator.h"
Expand Down Expand Up @@ -187,6 +189,9 @@ class Frame {
const CARTA::MomentRequest& moment_request, CARTA::MomentResponse& moment_response, std::vector<GeneratedImage>& collapse_results);
void StopMomentCalc();

// Image fitting
bool FitImage(const CARTA::FittingRequest& fitting_request, CARTA::FittingResponse& fitting_response);

// Save as a new file or export sub-image to CASA/FITS format
void SaveFile(const std::string& root_folder, const CARTA::SaveFile& save_file_msg, CARTA::SaveFileAck& save_file_ack,
std::shared_ptr<Region> image_region);
Expand Down Expand Up @@ -307,6 +312,9 @@ class Frame {

// Moment generator
std::unique_ptr<MomentGenerator> _moment_generator;

// Image fitter
std::unique_ptr<ImageFitter> _image_fitter;
};

} // namespace carta
Expand Down
230 changes: 230 additions & 0 deletions src/ImageFitter/ImageFitter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/* This file is part of the CARTA Image Viewer: https://github.com/CARTAvis/carta-backend
Copyright 2018, 2019, 2020, 2021 Academia Sinica Institute of Astronomy and Astrophysics (ASIAA),
Associated Universities, Inc. (AUI) and the Inter-University Institute for Data Intensive Astronomy (IDIA)
SPDX-License-Identifier: GPL-3.0-or-later
*/

#define SQ_FWHM_TO_SIGMA 1 / 8 / log(2)
#define DEG_TO_RAD M_PI / 180.0

#include "ImageFitter.h"

#include <omp.h>

using namespace carta;

ImageFitter::ImageFitter(size_t width, size_t height) {
_fit_data.width = width;
_fit_data.n = width * height;

_fdf.f = FuncF;
_fdf.df = nullptr; // internally computed using finite difference approximations of f when set to NULL
_fdf.fvv = nullptr;
_fdf.params = &_fit_data;
_fdf.n = _fit_data.n;

// avoid GSL default error handler calling abort()
gsl_set_error_handler(&ErrorHandler);
}

bool ImageFitter::FitImage(
float* image, const std::vector<CARTA::GaussianComponent>& initial_values, CARTA::FittingResponse& fitting_response) {
bool success = false;
_fit_data.data = image;
CalculateNanNum();
SetInitialValues(initial_values);

spdlog::info("Fitting image ({} data points) with {} Gaussian component(s).", _fit_data.n, _num_components);
int status = SolveSystem();

if (status == GSL_EMAXITER && _fit_status.num_iter < _max_iter) {
fitting_response.set_message("fit did not converge");
} else if (status) {
fitting_response.set_message(gsl_strerror(status));
}

if (!status || (status == GSL_EMAXITER && _fit_status.num_iter == _max_iter)) {
success = true;
spdlog::info("Writing fitting results and log.");
for (size_t i = 0; i < _num_components; i++) {
fitting_response.add_result_values();
*fitting_response.mutable_result_values(i) = GetGaussianComponent(_fit_values, i);
fitting_response.add_result_errors();
*fitting_response.mutable_result_errors(i) = GetGaussianComponent(_fit_errors, i);
}
fitting_response.set_log(GetLog());
}
fitting_response.set_success(success);

gsl_vector_free(_fit_values);
gsl_vector_free(_fit_errors);
return success;
}

void ImageFitter::CalculateNanNum() {
_fit_data.n = _fdf.n;
for (size_t i = 0; i < _fit_data.n; i++) {
if (isnan(_fit_data.data[i])) {
_fit_data.n--;
}
}
}

void ImageFitter::SetInitialValues(const std::vector<CARTA::GaussianComponent>& initial_values) {
_num_components = initial_values.size();

size_t p = _num_components * 6;
_fit_values = gsl_vector_alloc(p);
_fit_errors = gsl_vector_alloc(p);
for (size_t i = 0; i < _num_components; i++) {
CARTA::GaussianComponent component(initial_values[i]);
gsl_vector_set(_fit_values, i * 6, component.center().x());
gsl_vector_set(_fit_values, i * 6 + 1, component.center().y());
gsl_vector_set(_fit_values, i * 6 + 2, component.amp());
gsl_vector_set(_fit_values, i * 6 + 3, component.fwhm().x());
gsl_vector_set(_fit_values, i * 6 + 4, component.fwhm().y());
gsl_vector_set(_fit_values, i * 6 + 5, component.pa());
}
_fdf.p = p;
}

int ImageFitter::SolveSystem() {
gsl_multifit_nlinear_parameters fdf_params = gsl_multifit_nlinear_default_parameters();
const gsl_multifit_nlinear_type* T = gsl_multifit_nlinear_trust;
const double xtol = 1.0e-8;
const double gtol = 1.0e-8;
const double ftol = 1.0e-8;
const bool print_iter = false;
const size_t n = _fdf.n;
const size_t p = _fdf.p;
gsl_multifit_nlinear_workspace* work = gsl_multifit_nlinear_alloc(T, &fdf_params, n, p);
gsl_vector* f = gsl_multifit_nlinear_residual(work);
gsl_vector* y = gsl_multifit_nlinear_position(work);
gsl_matrix* covar = gsl_matrix_alloc(p, p);

gsl_multifit_nlinear_init(_fit_values, &_fdf, work);
gsl_blas_ddot(f, f, &_fit_status.chisq0);
int status =
gsl_multifit_nlinear_driver(_max_iter, xtol, gtol, ftol, print_iter ? Callback : nullptr, nullptr, &_fit_status.info, work);
gsl_blas_ddot(f, f, &_fit_status.chisq);
gsl_multifit_nlinear_rcond(&_fit_status.rcond, work);
gsl_vector_memcpy(_fit_values, y);

gsl_matrix* jac = gsl_multifit_nlinear_jac(work);
gsl_multifit_nlinear_covar(jac, 0.0, covar);
const double c = sqrt(_fit_status.chisq / (_fit_data.n - p));
for (size_t i = 0; i < p; i++) {
gsl_vector_set(_fit_errors, i, c * sqrt(gsl_matrix_get(covar, i, i)));
}

_fit_status.method = fmt::format("{}/{}", gsl_multifit_nlinear_name(work), gsl_multifit_nlinear_trs_name(work));
_fit_status.num_iter = gsl_multifit_nlinear_niter(work);

gsl_multifit_nlinear_free(work);
gsl_matrix_free(covar);
return status;
}

std::string ImageFitter::GetLog() {
std::string info;
switch (_fit_status.info) {
case 1:
info = "small step size";
break;
case 2:
info = "small gradient";
break;
case 0:
default:
info = "exceeded max number of iterations";
break;
}

std::string log = fmt::format("Gaussian fitting with {} component(s)\n", _num_components);
log += fmt::format("summary from method '{}':\n", _fit_status.method);
log += fmt::format("number of iterations = {}\n", _fit_status.num_iter);
log += fmt::format("function evaluations = {}\n", _fdf.nevalf);
log += fmt::format("Jacobian evaluations = {}\n", _fdf.nevaldf);
log += fmt::format("reason for stopping = {}\n", info);
log += fmt::format("initial |f(x)| = {:.12e}\n", sqrt(_fit_status.chisq0));
log += fmt::format("final |f(x)| = {:.12e}\n", sqrt(_fit_status.chisq));
log += fmt::format("initial cost = {:.12e}\n", _fit_status.chisq0);
log += fmt::format("final cost = {:.12e}\n", _fit_status.chisq);
log += fmt::format("residual variance = {:.12e}\n", _fit_status.chisq / (_fit_data.n - _fdf.p));
log += fmt::format("final cond(J) = {:.12e}\n", 1.0 / _fit_status.rcond);

return log;
}

int ImageFitter::FuncF(const gsl_vector* fit_values, void* fit_data, gsl_vector* f) {
struct FitData* d = (struct FitData*)fit_data;

for (size_t k = 0; k < fit_values->size; k += 6) {
const double center_x = gsl_vector_get(fit_values, k);
const double center_y = gsl_vector_get(fit_values, k + 1);
const double amp = gsl_vector_get(fit_values, k + 2);
const double fwhm_x = gsl_vector_get(fit_values, k + 3);
const double fwhm_y = gsl_vector_get(fit_values, k + 4);
const double pa = gsl_vector_get(fit_values, k + 5);

const double dbl_sq_std_x = 2 * fwhm_x * fwhm_x * SQ_FWHM_TO_SIGMA;
const double dbl_sq_std_y = 2 * fwhm_y * fwhm_y * SQ_FWHM_TO_SIGMA;
const double theta_radian = (pa - 90.0) * DEG_TO_RAD; // counterclockwise rotation
const double a = cos(theta_radian) * cos(theta_radian) / dbl_sq_std_x + sin(theta_radian) * sin(theta_radian) / dbl_sq_std_y;
const double dbl_b = 2 * (sin(2 * theta_radian) / (2 * dbl_sq_std_x) - sin(2 * theta_radian) / (2 * dbl_sq_std_y));
const double c = sin(theta_radian) * sin(theta_radian) / dbl_sq_std_x + cos(theta_radian) * cos(theta_radian) / dbl_sq_std_y;

#pragma omp parallel for
for (size_t i = 0; i < d->n; i++) {
float data_i = d->data[i];
if (!isnan(data_i)) {
double dx = i % d->width - center_x;
double dy = i / d->width - center_y;
float data = amp * exp(-(a * dx * dx + dbl_b * dx * dy + c * dy * dy));
if (k == 0) {
gsl_vector_set(f, i, data_i - data);
} else {
gsl_vector_set(f, i, gsl_vector_get(f, i) - data);
}
} else {
gsl_vector_set(f, i, 0);
}
}
}

return GSL_SUCCESS;
}

void ImageFitter::Callback(const size_t iter, void* params, const gsl_multifit_nlinear_workspace* w) {
gsl_vector* f = gsl_multifit_nlinear_residual(w);
gsl_vector* x = gsl_multifit_nlinear_position(w);
double avratio = gsl_multifit_nlinear_avratio(w);
double rcond;
gsl_multifit_nlinear_rcond(&rcond, w);

spdlog::debug("iter {}, |a|/|v| = {:.4f} cond(J) = {:8.4f}, |f(x)| = {:.4f}", iter, avratio, 1.0 / rcond, gsl_blas_dnrm2(f));
for (int k = 0; k < x->size / 6; ++k) {
spdlog::debug("component {}: ({:.12f}, {:.12f}, {:.12f}, {:.12f}, {:.12f}, {:.12f})", k + 1, gsl_vector_get(x, k * 6),
gsl_vector_get(x, k * 6 + 1), gsl_vector_get(x, k * 6 + 2), gsl_vector_get(x, k * 6 + 3), gsl_vector_get(x, k * 6 + 4),
gsl_vector_get(x, k * 6 + 5));
}
}

void ImageFitter::ErrorHandler(const char* reason, const char* file, int line, int gsl_errno) {
spdlog::error("gsl error: {} line{}: {}", file, line, reason);
}

CARTA::GaussianComponent ImageFitter::GetGaussianComponent(gsl_vector* value_vector, size_t index) {
CARTA::GaussianComponent component;
CARTA::DoublePoint center;
center.set_x(gsl_vector_get(value_vector, index * 6));
center.set_y(gsl_vector_get(value_vector, index * 6 + 1));
*component.mutable_center() = center;
component.set_amp(gsl_vector_get(value_vector, index * 6 + 2));
CARTA::DoublePoint fwhm;
fwhm.set_x(gsl_vector_get(value_vector, index * 6 + 3));
fwhm.set_y(gsl_vector_get(value_vector, index * 6 + 4));
*component.mutable_fwhm() = fwhm;
component.set_pa(gsl_vector_get(value_vector, index * 6 + 5));
return component;
}
63 changes: 63 additions & 0 deletions src/ImageFitter/ImageFitter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* This file is part of the CARTA Image Viewer: https://github.com/CARTAvis/carta-backend
Copyright 2018, 2019, 2020, 2021 Academia Sinica Institute of Astronomy and Astrophysics (ASIAA),
Associated Universities, Inc. (AUI) and the Inter-University Institute for Data Intensive Astronomy (IDIA)
SPDX-License-Identifier: GPL-3.0-or-later
*/

#ifndef CARTA_BACKEND_IMAGEFITTER_IMAGEFITTER_H_
#define CARTA_BACKEND_IMAGEFITTER_IMAGEFITTER_H_

#include <gsl/gsl_blas.h>
#include <gsl/gsl_multifit_nlinear.h>
#include <gsl/gsl_vector.h>
#include <string>
#include <vector>

#include <carta-protobuf/fitting_request.pb.h>

#include "Logger/Logger.h"

namespace carta {

struct FitData {
float* data;
size_t width;
size_t n; // number of pixels excluding nan pixels
};

struct FitStatus {
std::string method;
size_t num_iter;
int info;
double chisq0, chisq, rcond;
};

class ImageFitter {
public:
ImageFitter(size_t width, size_t height);
bool FitImage(float* image, const std::vector<CARTA::GaussianComponent>& initial_values, CARTA::FittingResponse& fitting_response);

private:
FitData _fit_data;
size_t _num_components;
gsl_vector* _fit_values;
gsl_vector* _fit_errors;
gsl_multifit_nlinear_fdf _fdf;
FitStatus _fit_status;
const size_t _max_iter = 200;

void CalculateNanNum();
void SetInitialValues(const std::vector<CARTA::GaussianComponent>& initial_values);
int SolveSystem();
void SetResults();
std::string GetLog();

static int FuncF(const gsl_vector* fit_params, void* fit_data, gsl_vector* f);
static void Callback(const size_t iter, void* params, const gsl_multifit_nlinear_workspace* w);
static void ErrorHandler(const char* reason, const char* file, int line, int gsl_errno);
static CARTA::GaussianComponent GetGaussianComponent(gsl_vector* value_vector, size_t index);
};

} // namespace carta

#endif // CARTA_BACKEND_IMAGEFITTER_IMAGEFITTER_H_
2 changes: 2 additions & 0 deletions src/Session/OnMessageTask.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class GeneralMessageTask : public OnMessageTask {
_session->OnCatalogFileList(_message, _request_id);
} else if constexpr (std::is_same_v<T, CARTA::PvRequest>) {
_session->OnPvRequest(_message, _request_id);
} else if constexpr (std::is_same_v<T, CARTA::FittingRequest>) {
_session->OnFittingRequest(_message, _request_id);
} else {
spdlog::warn("Bad event type for GeneralMessageTask!");
}
Expand Down
Loading

0 comments on commit a7f87d0

Please sign in to comment.