blob: 966df972660d6f19e4d548e59d56018a448dee76 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ui/events/blink/prediction/kalman_filter.h"
namespace {
// Number of Update() calls after which Stable() starts returning true.
constexpr uint32_t kStableIterNum{4};
// Scale factor applied to the process noise outer product in
// GetProcessNoise().
constexpr double kSigmaProcess{0.01};
// Value used to initialize the filter's measurement noise variance (R).
constexpr double kSigmaMeasurement{1.0};
// Returns the state transition matrix F for a time step of |kDt|.
//
// The state is (position, velocity, acceleration) and each row of F encodes
// one line of basic kinematics:
//   new_pos = pos + v * dt + 1/2 * a * dt^2
//   new_v   = v + a * dt
//   new_a   = a
gfx::Matrix3F GetStateTransition(double kDt) {
  const double half_dt_squared = 0.5 * kDt * kDt;

  gfx::Matrix3F transition = gfx::Matrix3F::Zeros();
  transition.set(1.0, kDt, half_dt_squared,  // position row
                 0.0, 1.0, kDt,              // velocity row
                 0.0, 0.0, 1.0);             // acceleration row
  return transition;
}
// Returns the process noise covariance matrix Q for a time step of |kDt|.
//
// The system noise is modeled as a noisy force on the pointer. |noise_gain|
// describes how that noise affects each state component, and Q is the outer
// product of that vector with itself, scaled by kSigmaProcess.
gfx::Matrix3F GetProcessNoise(double kDt) {
  const gfx::Vector3dF noise_gain(0.5 * kDt * kDt, kDt, 1.0);
  const gfx::Vector3dF scaled_noise_gain =
      gfx::ScaleVector3d(noise_gain, kSigmaProcess);
  return gfx::Matrix3F::FromOuterProduct(noise_gain, scaled_noise_gain);
}
} // namespace
namespace ui {
// Builds a filter with a default-constructed state estimate, an identity
// error covariance, and model matrices evaluated for a time step of 1.0.
// The measurement vector H = (1, 0, 0) selects only the first state
// component, and the measurement noise variance starts at kSigmaMeasurement.
KalmanFilter::KalmanFilter()
    : state_estimation_(),
      error_covariance_matrix_(gfx::Matrix3F::Identity()),
      state_transition_matrix_(GetStateTransition(1.0)),
      process_noise_covariance_matrix_(GetProcessNoise(1.0)),
      measurement_vector_(1.0, 0.0, 0.0),
      measurement_noise_variance_(kSigmaMeasurement),
      iteration_count_(0) {}
// Defaulted out of line; the filter performs no explicit cleanup here.
KalmanFilter::~KalmanFilter() = default;
// Returns a reference to the current state estimate. Its x/y/z components
// are exposed as position, velocity and acceleration by GetPosition(),
// GetVelocity() and GetAcceleration() respectively.
const gfx::Vector3dF& KalmanFilter::GetStateEstimation() const {
return state_estimation_;
}
// Reports whether the filter has received at least kStableIterNum
// observations since construction or the last Reset().
bool KalmanFilter::Stable() const {
  const bool has_enough_observations = iteration_count_ >= kStableIterNum;
  return has_enough_observations;
}
void KalmanFilter::Update(double observation, double dt) {
if (iteration_count_++ == 0) {
// We only update the state estimation in the first iteration.
state_estimation_ = gfx::Vector3dF(observation, 0, 0);
return;
}
Predict(dt);
// Y = z - H * X
double y =
observation - gfx::DotProduct(measurement_vector_, state_estimation_);
// S = H * P * H' + R
double S = gfx::DotProduct(measurement_vector_,
gfx::MatrixProduct(error_covariance_matrix_,
measurement_vector_)) +
measurement_noise_variance_;
// K = P * H' * inv(S)
gfx::Vector3dF kalman_gain = gfx::ScaleVector3d(
gfx::MatrixProduct(error_covariance_matrix_, measurement_vector_),
1.0 / S);
// X = X + K * Y
state_estimation_ += gfx::ScaleVector3d(kalman_gain, y);
// I_HK = eye(P) - K * H
gfx::Matrix3F I_KH =
gfx::Matrix3F::Identity() -
gfx::Matrix3F::FromOuterProduct(kalman_gain, measurement_vector_);
// P = I_KH * P * I_KH' + K * R * K'
error_covariance_matrix_ =
gfx::MatrixProduct(gfx::MatrixProduct(I_KH, error_covariance_matrix_),
I_KH.Transpose()) +
gfx::Matrix3F::FromOuterProduct(
gfx::ScaleVector3d(kalman_gain, measurement_noise_variance_),
kalman_gain);
}
// Returns the filter to its unseeded state: the iteration counter goes back
// to zero (so the next Update() re-seeds the position), the error covariance
// becomes the identity and the state estimate is default-constructed.
// The transition matrix is not touched here; Predict() recomputes it on every
// call.
void KalmanFilter::Reset() {
  iteration_count_ = 0;
  error_covariance_matrix_ = gfx::Matrix3F::Identity();
  state_estimation_ = gfx::Vector3dF();
}
double KalmanFilter::GetPosition() const {
return GetStateEstimation().x();
}
double KalmanFilter::GetVelocity() const {
return GetStateEstimation().y();
}
double KalmanFilter::GetAcceleration() const {
return GetStateEstimation().z();
}
void KalmanFilter::Predict(double dt) {
DCHECK_GT(iteration_count_, 0u);
state_transition_matrix_ = GetStateTransition(dt);
// X = F * X
state_estimation_ =
gfx::MatrixProduct(state_transition_matrix_, state_estimation_);
// P = F * P * F' + Q
error_covariance_matrix_ =
gfx::MatrixProduct(gfx::MatrixProduct(state_transition_matrix_,
error_covariance_matrix_),
state_transition_matrix_.Transpose()) +
GetProcessNoise(dt);
}
} // namespace ui