// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_KALMAN_FiLTER_Hh_ #define DLIB_KALMAN_FiLTER_Hh_ #include "kalman_filter_abstract.h" #include "../matrix.h" #include "../geometry.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < long states, long measurements > class kalman_filter { public: kalman_filter() { H = 0; A = 0; Q = 0; R = 0; x = 0; xb = 0; P = identity_matrix<double>(states); got_first_meas = false; } void set_observation_model ( const matrix<double,measurements,states>& H_) { H = H_; } void set_transition_model ( const matrix<double,states,states>& A_) { A = A_; } void set_process_noise ( const matrix<double,states,states>& Q_) { Q = Q_; } void set_measurement_noise ( const matrix<double,measurements,measurements>& R_) { R = R_; } void set_estimation_error_covariance( const matrix<double,states,states>& P_) { P = P_; } void set_state ( const matrix<double,states,1>& xb_) { xb = xb_; if (!got_first_meas) { x = xb_; got_first_meas = true; } } const matrix<double,measurements,states>& get_observation_model ( ) const { return H; } const matrix<double,states,states>& get_transition_model ( ) const { return A; } const matrix<double,states,states>& get_process_noise ( ) const { return Q; } const matrix<double,measurements,measurements>& get_measurement_noise ( ) const { return R; } void update ( ) { // propagate estimation error covariance forward P = A*P*trans(A) + Q; // propagate state forward x = xb; xb = A*x; } void update (const matrix<double,measurements,1>& z) { // propagate estimation error covariance forward P = A*P*trans(A) + Q; // compute Kalman gain matrix const matrix<double,states,measurements> K = P*trans(H)*pinv(H*P*trans(H) + R); if (got_first_meas) { const matrix<double,measurements,1> res = z - H*xb; // correct the current state estimate x = xb + K*res; } else { // Since we don't have a previous state estimate at the start of filtering, // we will just set the current state to whatever is indicated by the measurement x = pinv(H)*z; got_first_meas = true; } // propagate state forward in time xb = A*x; // update estimation error covariance since we got a measurement. P = (identity_matrix<double,states>() - K*H)*P; } const matrix<double,states,1>& get_current_state( ) const { return x; } const matrix<double,states,1>& get_predicted_next_state( ) const { return xb; } const matrix<double,states,states>& get_current_estimation_error_covariance( ) const { return P; } friend inline void serialize(const kalman_filter& item, std::ostream& out) { int version = 1; serialize(version, out); serialize(item.got_first_meas, out); serialize(item.x, out); serialize(item.xb, out); serialize(item.P, out); serialize(item.H, out); serialize(item.A, out); serialize(item.Q, out); serialize(item.R, out); } friend inline void deserialize(kalman_filter& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 1) throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object."); deserialize(item.got_first_meas, in); deserialize(item.x, in); deserialize(item.xb, in); deserialize(item.P, in); deserialize(item.H, in); deserialize(item.A, in); deserialize(item.Q, in); deserialize(item.R, in); } private: bool got_first_meas; matrix<double,states,1> x, xb; matrix<double,states,states> P; matrix<double,measurements,states> H; matrix<double,states,states> A; matrix<double,states,states> Q; matrix<double,measurements,measurements> R; }; // ---------------------------------------------------------------------------------------- class momentum_filter { public: momentum_filter( double meas_noise, double acc, double max_meas_dev ) : measurement_noise(meas_noise), typical_acceleration(acc), max_measurement_deviation(max_meas_dev) { DLIB_CASSERT(meas_noise >= 0); DLIB_CASSERT(acc >= 0); DLIB_CASSERT(max_meas_dev >= 0); kal.set_observation_model({1, 0}); kal.set_transition_model( {1, 1, 0, 1}); kal.set_process_noise({0, 0, 0, typical_acceleration*typical_acceleration}); kal.set_measurement_noise({measurement_noise*measurement_noise}); } momentum_filter() = default; double get_measurement_noise ( ) const { return measurement_noise; } double get_typical_acceleration ( ) const { return typical_acceleration; } double get_max_measurement_deviation ( ) const { return max_measurement_deviation; } void reset() { *this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation); } double get_predicted_next_position( ) const { return kal.get_predicted_next_state()(0); } double operator()( const double measured_position ) { auto x = kal.get_predicted_next_state(); const auto max_deviation = max_measurement_deviation*measurement_noise; // Check if measured_position has suddenly jumped in value by a whole lot. This // could happen if the velocity term experiences a much larger than normal // acceleration, e.g. because the underlying object is doing a maneuver. If // this happens then we clamp the state so that the predicted next value is no // more than max_deviation away from measured_position at all times. if (x(0) > measured_position + max_deviation) { x(0) = measured_position + max_deviation; kal.set_state(x); } else if (x(0) < measured_position - max_deviation) { x(0) = measured_position - max_deviation; kal.set_state(x); } kal.update({measured_position}); return kal.get_current_state()(0); } friend std::ostream& operator << (std::ostream& out, const momentum_filter& item) { out << "measurement_noise: " << item.measurement_noise << "\n"; out << "typical_acceleration: " << item.typical_acceleration << "\n"; out << "max_measurement_deviation: " << item.max_measurement_deviation; return out; } friend void serialize(const momentum_filter& item, std::ostream& out) { int version = 15; serialize(version, out); serialize(item.measurement_noise, out); serialize(item.typical_acceleration, out); serialize(item.max_measurement_deviation, out); serialize(item.kal, out); } friend void deserialize(momentum_filter& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 15) throw serialization_error("Unexpected version found while deserializing momentum_filter."); deserialize(item.measurement_noise, in); deserialize(item.typical_acceleration, in); deserialize(item.max_measurement_deviation, in); deserialize(item.kal, in); } private: double measurement_noise = 2; double typical_acceleration = 0.1; double max_measurement_deviation = 3; // nominally number of standard deviations kalman_filter<2,1> kal; }; // ---------------------------------------------------------------------------------------- momentum_filter find_optimal_momentum_filter ( const std::vector<std::vector<double>>& sequences, const double smoothness = 1 ); // ---------------------------------------------------------------------------------------- momentum_filter find_optimal_momentum_filter ( const std::vector<double>& sequence, const double smoothness = 1 ); // ---------------------------------------------------------------------------------------- class rect_filter { public: rect_filter() = default; rect_filter( double meas_noise, double acc, double max_meas_dev ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {} rect_filter( const momentum_filter& filt ) : left(filt), top(filt), right(filt), bottom(filt) { } drectangle operator()(const drectangle& r) { return drectangle(left(r.left()), top(r.top()), right(r.right()), bottom(r.bottom())); } drectangle operator()(const rectangle& r) { return drectangle(left(r.left()), top(r.top()), right(r.right()), bottom(r.bottom())); } const momentum_filter& get_left () const { return left; } momentum_filter& get_left () { return left; } const momentum_filter& get_top () const { return top; } momentum_filter& get_top () { return top; } const momentum_filter& get_right () const { return right; } momentum_filter& get_right () { return right; } const momentum_filter& get_bottom () const { return bottom; } momentum_filter& get_bottom () { return bottom; } friend void serialize(const rect_filter& item, std::ostream& out) { int version = 123; serialize(version, out); serialize(item.left, out); serialize(item.top, out); serialize(item.right, out); serialize(item.bottom, out); } friend void deserialize(rect_filter& item, std::istream& in) { int version = 0; deserialize(version, in); if (version != 123) throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object."); deserialize(item.left, in); deserialize(item.top, in); deserialize(item.right, in); deserialize(item.bottom, in); } private: momentum_filter left, top, right, bottom; }; // ---------------------------------------------------------------------------------------- rect_filter find_optimal_rect_filter ( const std::vector<rectangle>& rects, const double smoothness = 1 ); // ---------------------------------------------------------------------------------------- } #endif // DLIB_KALMAN_FiLTER_Hh_