SourceXtractorPlusPlus 0.19.2
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
FlexibleModelFittingIterativeTask.cpp
Go to the documentation of this file.
1
22
26
28
36
40
41namespace SourceXtractor {
42
43using namespace ModelFitting;
44
45static auto logger = Elements::Logging::getLogger("FlexibleModelFitting");
46
62
65
66namespace {
67
// Builds the pixel rectangle used for fitting on `frame_index`, starting from the
// source's MeasurementFrameRectangle. An empty PixelRectangle is returned when that
// rectangle has no positive width/height on this frame.
// NOTE(review): the function signature is on a line not shown in this listing.
69 auto fitting_rect = source.getProperty<MeasurementFrameRectangle>(frame_index).getRect();
70
71 if (fitting_rect.getWidth() <= 0 || fitting_rect.getHeight() <= 0) {
72 return PixelRectangle();
73 } else {
74 const auto& frame_info = source.getProperty<MeasurementFrameInfo>(frame_index);
75
76 auto min = fitting_rect.getTopLeft();
77 auto max = fitting_rect.getBottomRight();
78
79 // FIXME temporary: enlarge the area by 0.8 * (bottom_right - top_left) plus 2 pixels on each side
// (assumes PixelCoordinate * double scales each component — confirm rounding behaviour)
80 PixelCoordinate border = (max - min) * .8 + PixelCoordinate(2, 2);
81
82 min -= border;
83 max += border;
84
85 // clip to image size: corners stay within [0, width - 1] x [0, height - 1] of the frame
86 min.m_x = std::max(min.m_x, 0);
87 min.m_y = std::max(min.m_y, 0);
88 max.m_x = std::min(max.m_x, frame_info.getWidth() - 1);
89 max.m_y = std::min(max.m_y, frame_info.getHeight() - 1);
90
91 return PixelRectangle(min, max);
92 }
93}
94
// Returns true when the fitting stamp of `source` on `frame_index` has a positive
// width and height, i.e. the frame can contribute pixels to the fit.
// NOTE(review): `stamp_rect` is declared on a line not shown in this listing
// (presumably the fitting rectangle computed for this frame) — confirm.
95bool isFrameValid(SourceInterface& source, int frame_index) {
97 return stamp_rect.getWidth() > 0 && stamp_rect.getHeight() > 0;
98}
99
// Returns a VectorImage holding a copy of the LayerSubtractedImage layer of
// `frame_index`, restricted to the region given by `rect` (top-left corner, width,
// height). The copy is owned by the caller and may be modified freely.
// NOTE(review): `rect` and the start of the copy call are on lines not shown in this
// listing; presumably `rect` is the fitting rectangle for this frame — confirm.
100std::shared_ptr<VectorImage<SeFloat>> createImageCopy(SourceInterface& source, int frame_index) {
101 const auto& frame_images = source.getProperty<MeasurementFrameImages>(frame_index);
103
105 LayerSubtractedImage, rect.getTopLeft().m_x, rect.getTopLeft().m_y, rect.getWidth(), rect.getHeight()));
106
107 return image;
108}
109
// Builds a per-pixel weight image over the fitting region `rect` of `frame_index`.
// Each weight is 1 / sigma, i.e. sqrt(1 / variance), where the variance is:
//   - background variance + pixel_val / gain, when gain > 0 and pixel_val > 0;
//   - background variance alone, otherwise.
// Pixels above a positive saturation level get weight 0 and are thus excluded.
// NOTE(review): `rect` is declared on a line not shown in this listing.
// NOTE(review): in the last branch a zero background variance yields an infinite weight
// and a negative one yields NaN; confirm the variance map cannot contain such values.
110std::shared_ptr<VectorImage<SeFloat>> createWeightImage(SourceInterface& source, int frame_index) {
111 const auto& frame_images = source.getProperty<MeasurementFrameImages>(frame_index);
112 auto frame_image = frame_images.getLockedImage(LayerSubtractedImage);
114 auto variance_map = frame_images.getLockedImage(LayerVarianceMap);
115
116 const auto& frame_info = source.getProperty<MeasurementFrameInfo>(frame_index);
117 SeFloat gain = frame_info.getGain();
118 SeFloat saturation = frame_info.getSaturation();
119
121 auto weight = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
122
123 for (int y = 0; y < rect.getHeight(); y++) {
124 for (int x = 0; x < rect.getWidth(); x++) {
125 auto back_var = variance_map->getValue(rect.getTopLeft().m_x + x, rect.getTopLeft().m_y + y);
126 auto pixel_val = frame_image->getValue(rect.getTopLeft().m_x + x, rect.getTopLeft().m_y + y);
// Saturated pixels are excluded from the fit (a saturation <= 0 disables this check).
127 if (saturation > 0 && pixel_val > saturation) {
128 weight->at(x, y) = 0;
129 }
// Add the Poisson term of the source itself to the background variance.
130 else if (gain > 0.0 && pixel_val > 0.0) {
131 weight->at(x, y) = sqrt(1.0 / (back_var + pixel_val / gain));
132 }
133 else {
134 weight->at(x, y) = sqrt(1.0 / back_var); // no Poisson term: gain <= 0 (treated as infinite) or pixel_val <= 0
135 }
136 }
137 }
138
139
140 return weight;
141}
142
144 SourceInterface& source, double pixel_scale, FlexibleModelFittingParameterManager& manager,
146
// Assembles the frame model of `source` for `frame`: every model attached to the frame
// is added using the frame's coordinate system and the top-left corner of `stamp_rect`.
// The resulting image has the size of `stamp_rect`. The PSF is wrapped in a
// DownSampledImagePsf built from the PSF pixel sampling and `down_scaling`.
// NOTE(review): the start of the signature and the model-creation calls are on lines
// not shown in this listing.
147 int frame_index = frame->getFrameNb();
148
149 auto frame_coordinates = source.getProperty<MeasurementFrameCoordinates>(frame_index).getCoordinateSystem();
150 auto ref_coordinates = source.getProperty<DetectionFrameCoordinates>().getCoordinateSystem();
151
152 auto psf_property = source.getProperty<SourcePsfProperty>(frame_index);
153 auto jacobian = source.getProperty<JacobianSource>(frame_index).asTuple();
154
155 // The model fitting module expects to get a PSF with a pixel scale, but we have the pixel sampling step size.
156 // It will be used to compute the rastering grid size, and after convolving with the PSF the result will be
157 // downscaled before being copied into the frame image.
158 // We can multiply here then, as the unit is pixel/pixel, rather than "/pixel (arcsec/pixel) or similar.
159 auto source_psf = DownSampledImagePsf(psf_property.getPixelSampling(), psf_property.getPsf(), down_scaling);
160
164
165 for (auto model : frame->getModels()) {
167 frame_coordinates, stamp_rect.getTopLeft());
168 }
169
170 // Full frame model with all sources
172 pixel_scale, (size_t) stamp_rect.getWidth(), (size_t) stamp_rect.getHeight(),
174
175 return frame_model;
176}
177
178}
179
180
183
// Fits the sources of the group iteratively: the group is traversed up to
// m_meta_iterations times (the per-source call is on a line not shown — presumably
// fitSource), and the loop may stop early based on the average chi squared. Afterwards,
// parameters that could not be fitted are set to NaN, and a FlexibleModelFitting
// property is attached to every source.
// Step 1: create one state entry per source, seeded with the parameters' initial values.
184 for (auto& source : group) {
187 initial_state.iterations = 0;
188 initial_state.stop_reason = 0;
189 initial_state.reduced_chi_squared = 0.0;
190 initial_state.duration = 0.0;
191
192 for (auto parameter : m_parameters) {
194 if (free_parameter != nullptr) {
195 initial_state.parameters_values[free_parameter->getId()] = free_parameter->getInitialValue(source);
196 } else {
197 initial_state.parameters_values[parameter->getId()] = 0.0;
198 }
199 // Make sure we have a default value for sigmas in case we cannot do the fit
200 initial_state.parameters_sigmas[parameter->getId()] = std::numeric_limits<double>::quiet_NaN();
201 }
202 fitting_state.source_states.emplace_back(std::move(initial_state));
203 }
204
205 // TODO Sort sources by flux to fit brightest sources first?
206
207 // iterate over the whole group, fitting sources one at a time
208
// NOTE(review): presumably a large sentinel so the first meta-iteration never bails out;
// the bail-out condition itself is on a line not shown in this listing.
209 double prev_chi_squared = 999999.9;
210 for (int iteration = 0; iteration < m_meta_iterations; iteration++) {
211 int index = 0;
212 for (auto& source : group) {
214 index++;
215 }
216
217 // evaluate reduced chi squared to bail out of meta iterations if no longer improving the fit
218
219 double chi_squared = 0.0;
220 for (auto& source_state : fitting_state.source_states) {
221 chi_squared += source_state.reduced_chi_squared;
222 }
// Average over the sources of the group.
// NOTE(review): for an empty group this is 0.0 / 0 and yields NaN.
223 chi_squared /= fitting_state.source_states.size();
224
226 break;
227 }
228
230 }
231
232
233 // Remove parameters that couldn't be fit from the output
// parameters_fitted[...] uses operator[]; assuming it is a std::unordered_map, a
// parameter with no recorded entry (e.g. a source whose fit was skipped) reads as
// false and is therefore reported as NaN.
234
235 for (size_t index = 0; index < group.size(); index++){
236 auto& source_state = fitting_state.source_states.at(index);
237
238 for (auto parameter : m_parameters) {
240
241 if (free_parameter != nullptr && !source_state.parameters_fitted[parameter->getId()]) {
242 source_state.parameters_values[parameter->getId()] = std::numeric_limits<double>::quiet_NaN();
243 source_state.parameters_sigmas[parameter->getId()] = std::numeric_limits<double>::quiet_NaN();
244 }
245 }
246 }
247
248 // output a property for every source
249 size_t index = 0;
250 for (auto& source : group) {
251 auto& source_state = fitting_state.source_states.at(index);
252
// Remember how many per-meta-iteration entries were actually recorded, then pad both
// histories to m_meta_iterations entries (fewer are recorded when the loop stops early
// or a fit is skipped).
253 int meta_iterations = source_state.chi_squared_per_meta.size();
254 source_state.chi_squared_per_meta.resize(m_meta_iterations);
255 source_state.iterations_per_meta.resize(m_meta_iterations);
256
257 source.setProperty<FlexibleModelFitting>(source_state.iterations, source_state.stop_reason,
258 source_state.reduced_chi_squared, source_state.duration, source_state.flags,
259 source_state.parameters_values, source_state.parameters_sigmas,
260 source_state.chi_squared_per_meta, source_state.iterations_per_meta,
262
263 index++;
264 }
265
267}
268
269
// Returns an image of the size of `rect` holding the sum of the current models of every
// source of the group except the one at `source_index`, evaluated on `frame`. Callers
// subtract it from the data so that neighbours are accounted for while fitting one
// source. The models use the parameter values stored in `state` by previous fits.
// NOTE(review): the signature and the declarations of `rect` / `parameter_manager` are
// on lines not shown in this listing.
// NOTE(review): the accumulation below assumes VectorImage::create zero-initialises the
// pixels and that every frame model has the size of `rect` — confirm.
273 int frame_index = frame->getFrameNb();
275
276 double pixel_scale = 1.0;
279 int n_free_parameters = 0;
280
// Register the parameters of all the other sources.
281 int index = 0;
282 for (auto& src : group) {
283 if (index != source_index) {
284 for (auto parameter : m_parameters) {
286
287 if (free_parameter != nullptr) {
289
290 // Initialize with the values from the current iteration run
291 parameter_manager.addParameter(src, parameter,
293 state.source_states[index].parameters_values.at(free_parameter->getId())));
294
295 } else {
296 parameter_manager.addParameter(src, parameter,
298 }
299 }
300 }
301 index++;
302 }
303
// Sum the models of all the other sources into the deblend image.
304 auto deblend_image = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
305 index = 0;
306 for (auto& src : group) {
307 if (index != source_index) {
308 auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, rect);
309 auto final_stamp = frame_model.getImage();
310
311 for (int y = 0; y < final_stamp->getHeight(); ++y) {
312 for (int x = 0; x < final_stamp->getWidth(); ++x) {
313 deblend_image->at(x, y) += final_stamp->at(x, y);
314 }
315 }
316 }
317 index++;
318 }
319
320 return deblend_image;
321}
322
326 SourceInterface& source, int index, FittingState& state) const {
// Registers every parameter of `source` in `parameter_manager`. Free parameters are
// initialised from the values recorded in `state` for source `index`.
// Returns `free_parameters_nb` — presumably the number of free parameters, incremented
// on a line not shown in this listing (confirm).
327 int free_parameters_nb = 0;
328 for (auto parameter : m_parameters) {
330
331 if (free_parameter != nullptr) {
333
334 // Initialize with the values from the current iteration run
335 parameter_manager.addParameter(source, parameter,
337 state.source_states[index].parameters_values.at(free_parameter->getId())));
338 } else {
339 parameter_manager.addParameter(source, parameter,
341 }
342
343 }
344
345 // Reset access checks, as a dependent parameter could have triggered it
346 parameter_manager.clearAccessCheck();
347
348 return free_parameters_nb;
349}
350
353 SourceGroupInterface& group, SourceInterface& source, int index, FittingState& state, double down_scaling) const {
// For every frame on which `source` has a valid fitting stamp, this:
//   - copies the data;
//   - subtracts `m_deblend_factor` times the summed model of the other sources;
//   - builds the weight image;
//   - adds the number of pixels with a non-zero weight to `good_pixels`
//     (accumulated, not reset here);
//   - registers the data-vs-model residual block with `res_estimator`.
// Returns the number of valid frames.
354
355 double pixel_scale = 1.0;
356
357 int valid_frames = 0;
358 for (auto frame : m_frames) {
359 int frame_index = frame->getFrameNb();
360 // Validate that each frame covers the model fitting region
361 if (isFrameValid(source, frame_index)) {
362 valid_frames++;
363
366
367 auto image = createImageCopy(source, frame_index);
368
// Remove the contribution of the neighbouring sources from the data.
369 auto deblend_image = createDeblendImage(group, source, index, frame, state);
370 for (int y = 0; y < image->getHeight(); ++y) {
371 for (int x = 0; x < image->getWidth(); ++x) {
372 image->at(x, y) -= m_deblend_factor * deblend_image->at(x, y);
373 }
374 }
375
376 auto weight = createWeightImage(source, frame_index);
377
378 // count number of pixels that can be used for fitting
379 for (int y = 0; y < weight->getHeight(); ++y) {
380 for (int x = 0; x < weight->getWidth(); ++x) {
381 good_pixels += (weight->at(x, y) != 0.);
382 }
383 }
384
385 // Setup residuals
386 auto data_vs_model =
388 //LogChiSquareComparator(m_modified_chi_squared_scale));
390 res_estimator.registerBlockProvider(std::move(data_vs_model));
391 }
392 }
393
394 return valid_frames;
395}
396
417
420 SeFloat avg_reduced_chi_squared, SeFloat duration, unsigned int iterations, unsigned int stop_reason, Flags flags,
422 int index, FittingState& state) const {
// Stores the outcome of one fit of source `index` into `state`. Parameters the engine
// fitted (the exact condition is on lines not shown in this listing) take their value and
// sigma from the solution. The others keep the values already in `state`, are marked as
// not fitted and set PARTIAL_FIT. Chi squared, stop reason and flags are overwritten.
// Duration and iteration counts are accumulated, and the per-meta-iteration histories
// are appended to.
// NOTE(review): the three local maps are copied into `state`; std::move would avoid it.
424 // Collect parameters for output
425 std::unordered_map<int, double> parameters_values, parameters_sigmas;
426 std::unordered_map<int, bool> parameters_fitted;
427
428 for (auto parameter : m_parameters) {
433
435 parameters_values[parameter->getId()] = modelfitting_parameter->getValue();
436 parameters_sigmas[parameter->getId()] = parameter->getSigma(parameter_manager, source, solution.parameter_sigmas);
437 parameters_fitted[parameter->getId()] = true;
438 } else {
439 parameters_values[parameter->getId()] = state.source_states[index].parameters_values[parameter->getId()];
440 parameters_sigmas[parameter->getId()] = state.source_states[index].parameters_sigmas[parameter->getId()];
441 parameters_fitted[parameter->getId()] = false;
442
443 // Need to cascade the NaN to any potential dependent parameter
445 if (engine_parameter) {
447 }
448
449 flags |= Flags::PARTIAL_FIT;
450 }
451 }
452
453 state.source_states[index].parameters_values = parameters_values;
454 state.source_states[index].parameters_sigmas = parameters_sigmas;
455 state.source_states[index].parameters_fitted = parameters_fitted;
456 state.source_states[index].reduced_chi_squared = avg_reduced_chi_squared;
457 state.source_states[index].chi_squared_per_meta.emplace_back(avg_reduced_chi_squared);
458 state.source_states[index].duration += duration;
459 state.source_states[index].iterations += iterations;
460 state.source_states[index].iterations_per_meta.emplace_back(iterations);
461 state.source_states[index].stop_reason = stop_reason;
462 state.source_states[index].flags = flags;
463}
464
466
// Fits one source (`index` in the group) while the other sources are kept at their
// current estimates. The fit is skipped when no frame is valid, or when there are fewer
// good pixels than free parameters.
// NOTE(review): on that skip path the early return below discards `flags` (OUTSIDE, or
// the flag set on a line not shown for insufficient data). fitSourceUpdateState — the
// only visible writer of source_states[index].flags — is not called there. Verify
// whether the flags are stored elsewhere; otherwise they never reach the output property.
468 // Determine size of fitted area and if needed downsize factor
469
470 double fit_size = 0;
471 for (auto frame : m_frames) {
472 int frame_index = frame->getFrameNb();
473 // Validate that each frame covers the model fitting region
474 if (isFrameValid(source, frame_index)) {
// Area of the stamp expressed in PSF samples (pixel area / sampling^2).
477 fit_size = std::max(fit_size, stamp_rect.getWidth() * stamp_rect.getHeight() /
478 (psf_property.getPixelSampling() * psf_property.getPixelSampling()));
479 }
480 }
481
// Warn when the largest stamp exceeds twice m_max_fit_size (down_scaling itself is
// computed on lines not shown in this listing).
483 if (fit_size > m_max_fit_size * 2.0) {
485 logger.warn() << "Exceeding max fit size: " << fit_size << " / " << m_max_fit_size
486 << " scaling factor: " << down_scaling;
487 }
488
490 // Prepare parameters
491
496
498 // Add models for all frames
500 int n_good_pixels = 0;
503
505 // Check that we had enough data for the fit
506
507 Flags flags = Flags::NONE;
508
509 if (valid_frames == 0) {
510 flags = Flags::OUTSIDE;
511 }
512 else if (n_good_pixels < n_free_parameters) {
514 }
515
516 // Do not run the model fitting for the flags above
517 if (flags != Flags::NONE) {
518 return;
519 }
520
521 if (down_scaling < 1.0) {
522 flags |= Flags::DOWNSAMPLED;
523 }
524
525
527 // Add priors
528 for (auto prior : m_priors) {
530 }
531
533 // Model fitting
534
537
538 auto iterations = solution.iteration_no;
539 auto stop_reason = solution.engine_stop_reason;
540 if (solution.status_flag == LeastSquareSummary::ERROR) {
541 flags |= Flags::ERROR;
542 }
543 auto duration = solution.duration;
544
546 // compute chi squared
547
549
551 // update state with results
552 fitSourceUpdateState(parameter_manager, source, avg_reduced_chi_squared, duration, iterations, stop_reason, flags, solution,
553 index, state);
554}
555
557 double pixel_scale, FittingState& state) const {
// Adds the model of every source of the group, built from the parameter values stored
// in `state`, to each frame's model-fitting check image (only when that check image
// exists). Each model is written over the source's fitting stamp.
// NOTE(review): `parameter_manager`, `stamp_rect` and `debugAccessor` are declared on
// lines not shown in this listing.
558
559 // recreate parameters
560
563
564 int index = 0;
565 for (auto& src : group) {
566 for (auto parameter : m_parameters) {
568
569 if (free_parameter != nullptr) {
570 // Initialize with the values from the current iteration run
571 parameter_manager.addParameter(src, parameter,
573 state.source_states[index].parameters_values.at(free_parameter->getId())));
574 } else {
575 parameter_manager.addParameter(src, parameter,
577 }
578 }
579 index++;
580 }
581
582 for (auto& src : group) {
583 for (auto frame : m_frames) {
584 int frame_index = frame->getFrameNb();
585
586 if (isFrameValid(src, frame_index)) {
588
589 auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, stamp_rect);
590 auto final_stamp = frame_model.getImage();
591
592 auto debug_image = CheckImages::getInstance().getModelFittingImage(frame_index);
593 if (debug_image) {
// Accumulate (read + add + write) so overlapping sources sum up in the check image.
595 for (int x = 0; x < final_stamp->getWidth(); x++) {
596 for (int y = 0; y < final_stamp->getHeight(); y++) {
597 auto x_coord = stamp_rect.getTopLeft().m_x + x;
598 auto y_coord = stamp_rect.getTopLeft().m_y + y;
599 debug_image->setValue(x_coord, y_coord,
600 debugAccessor.getValue(x_coord, y_coord) + final_stamp->getValue(x, y));
601 }
602 }
603 }
604
605 }
606 }
607 }
608}
609
// Returns the sum over all pixels of ((image - model) * weight)^2, and sets
// `data_points` to the number of pixels with a strictly positive weight.
// NOTE(review): despite the local name `reduced_chi_squared`, no normalisation by the
// number of data points or degrees of freedom happens here; callers must do it.
// NOTE(review): the accessor declarations are on lines not shown in this listing.
612 double reduced_chi_squared = 0.0;
613 data_points = 0;
614
617
618 for (int y=0; y < image->getHeight(); y++) {
619 for (int x=0; x < image->getWidth(); x++) {
620 double tmp = imageAccessor.getValue(x, y) - modelAccessor.getValue(x, y);
621 reduced_chi_squared += tmp * tmp * weightAccessor.getValue(x, y) * weightAccessor.getValue(x, y);
622 if (weightAccessor.getValue(x, y) > 0) {
623 data_points++;
624 }
625 }
626 }
627 return reduced_chi_squared;
628}
629
// Computes the chi squared of `source` over all valid frames, after subtracting the
// models of the other sources of the group from the data. The per-frame accumulation
// into `total_chi_squared` / `total_data_points` is on lines not shown in this listing
// (presumably via computeChiSquaredForFrame).
// NOTE(review): fitSourcePrepareModels subtracts `m_deblend_factor * deblend_image`,
// but the subtraction below ignores m_deblend_factor. When m_deblend_factor != 1, the
// chi squared is evaluated on different data than the fit minimised; confirm this is
// intended.
634 int valid_frames = 0;
635 for (auto frame : m_frames) {
636 int frame_index = frame->getFrameNb();
637 // Validate that each frame covers the model fitting region
638 if (isFrameValid(source, frame_index)) {
639 valid_frames++;
641 auto frame_model = createFrameModel(source, pixel_scale, manager, frame, stamp_rect);
642 auto final_stamp = frame_model.getImage();
643
644 auto image = createImageCopy(source, frame_index);
645 auto deblend_image = createDeblendImage(group, source, index, frame, state);
646 for (int y = 0; y < image->getHeight(); ++y) {
647 for (int x = 0; x < image->getWidth(); ++x) {
648 image->at(x, y) -= deblend_image->at(x, y);
649 }
650 }
651
652 auto weight = createWeightImage(source, frame_index);
653
654 int data_points = 0;
656
659 }
660 }
661
662 return total_chi_squared;
663}
664
665}
666
667
668
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
const double pixel_scale
Definition TestImage.cpp:74
static Logging getLogger(const std::string &name="")
void warn(const std::string &logMessage)
Data vs model comparator which computes a modified residual, using asinh.
Class responsible for managing the parameters the least square engine minimizes.
static std::shared_ptr< LeastSquareEngine > create(const std::string &name, unsigned max_iterations=1000)
Provides to the LeastSquareEngine the residual values.
static CheckImages & getInstance()
void fitSourceUpdateState(FlexibleModelFittingParameterManager &parameter_manager, SourceInterface &source, SeFloat avg_reduced_chi_squared, SeFloat duration, unsigned int iterations, unsigned int stop_reason, Flags flags, ModelFitting::LeastSquareSummary solution, int index, FittingState &state) const
void fitSource(SourceGroupInterface &group, SourceInterface &source, int index, FittingState &state) const
int fitSourcePrepareParameters(FlexibleModelFittingParameterManager &parameter_manager, ModelFitting::EngineParameterManager &engine_parameter_manager, SourceInterface &source, int index, FittingState &state) const
void computeProperties(SourceGroupInterface &group) const override
Computes one or more properties for the SourceGroup and/or the Sources it contains.
SeFloat fitSourceComputeChiSquared(FlexibleModelFittingParameterManager &parameter_manager, SourceGroupInterface &group, SourceInterface &source, int index, FittingState &state) const
std::vector< std::shared_ptr< FlexibleModelFittingFrame > > m_frames
std::shared_ptr< VectorImage< SeFloat > > createDeblendImage(SourceGroupInterface &group, SourceInterface &source, int source_index, std::shared_ptr< FlexibleModelFittingFrame > frame, FittingState &state) const
int fitSourcePrepareModels(FlexibleModelFittingParameterManager &parameter_manager, ModelFitting::ResidualEstimator &res_estimator, int &good_pixels, SourceGroupInterface &group, SourceInterface &source, int index, FittingState &state, double downscaling) const
FlexibleModelFittingIterativeTask(const std::string &least_squares_engine, unsigned int max_iterations, double modified_chi_squared_scale, std::vector< std::shared_ptr< FlexibleModelFittingParameter > > parameters, std::vector< std::shared_ptr< FlexibleModelFittingFrame > > frames, std::vector< std::shared_ptr< FlexibleModelFittingPrior > > priors, double scale_factor=1.0, int meta_iterations=3, double deblend_factor=1.0, double meta_iteration_stop=0.0001, size_t max_fit_size=100)
std::vector< std::shared_ptr< FlexibleModelFittingParameter > > m_parameters
SeFloat computeChiSquared(SourceGroupInterface &group, SourceInterface &source, int index, double pixel_scale, FlexibleModelFittingParameterManager &manager, int &total_data_points, FittingState &state) const
SeFloat computeChiSquaredForFrame(std::shared_ptr< const Image< SeFloat > > image, std::shared_ptr< const Image< SeFloat > > model, std::shared_ptr< const Image< SeFloat > > weights, int &data_points) const
void updateCheckImages(SourceGroupInterface &group, double pixel_scale, FittingState &state) const
std::vector< std::shared_ptr< FlexibleModelFittingPrior > > m_priors
Defines the interface used to group sources.
The SourceInterface is an abstract "source" that has properties attached to it.
static std::shared_ptr< VectorImage< T > > create(Args &&... args)
T fabs(T... args)
T max(T... args)
T min(T... args)
T move(T... args)
static Elements::Logging logger
std::unique_ptr< DataVsModelResiduals< typename std::remove_reference< DataType >::type, typename std::remove_reference< ModelType >::type, typename std::remove_reference< WeightType >::type, typename std::remove_reference< Comparator >::type > > createDataVsModelResiduals(DataType &&data, ModelType &&model, WeightType &&weight, Comparator &&comparator)
Flags
Flagging of bad sources.
Definition SourceFlags.h:37
@ DOWNSAMPLED
The fit was done on a downsampled image due to exceeding max size.
@ OUTSIDE
The object is completely outside of the measurement frame.
@ NONE
No flag is set.
@ ERROR
Error flag: something bad happened during the measurement, model fitting, etc.
@ INSUFFICIENT_DATA
There are not enough good pixels to fit the parameters.
@ PARTIAL_FIT
Some/all of the model parameters could not be fitted.
SeFloat32 SeFloat
Definition Types.h:32
@ LayerVarianceMap
Definition Frame.h:45
@ LayerThresholdedImage
Definition Frame.h:41
@ LayerSubtractedImage
Definition Frame.h:39
T quiet_NaN(T... args)
T sqrt(T... args)
Class containing the summary information of solving a least square minimization problem.