001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.numbers.N1; 011import org.ejml.simple.SimpleMatrix; 012 013/** 014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the 015 * UnscentedKalmanFilter class. 016 * 017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in 018 * most publications. Unless you know better, this should be your default choice. 019 * 020 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated. 021 * 022 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilitic Inference in Dynamic 023 * State-Space Models" (Doctoral dissertation) 024 * 025 * @param <S> The dimensionality of the state. 2 * States + 1 weights will be generated. 026 */ 027public class MerweScaledSigmaPoints<S extends Num> { 028 private final double m_alpha; 029 private final int m_kappa; 030 private final Nat<S> m_states; 031 private Matrix<?, N1> m_wm; 032 private Matrix<?, N1> m_wc; 033 034 /** 035 * Constructs a generator for Van der Merwe scaled sigma points. 036 * 037 * @param states an instance of Num that represents the number of states. 038 * @param alpha Determines the spread of the sigma points around the mean. Usually a small 039 * positive value (1e-3). 040 * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian 041 * distributions, beta = 2 is optimal. 042 * @param kappa Secondary scaling parameter usually set to 0 or 3 - States. 043 */ 044 public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) { 045 this.m_states = states; 046 this.m_alpha = alpha; 047 this.m_kappa = kappa; 048 049 computeWeights(beta); 050 } 051 052 /** 053 * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha, 054 * beta, and kappa. 055 * 056 * @param states an instance of Num that represents the number of states. 057 */ 058 public MerweScaledSigmaPoints(Nat<S> states) { 059 this(states, 1e-3, 2, 3 - states.getNum()); 060 } 061 062 /** 063 * Returns number of sigma points for each variable in the state x. 064 * 065 * @return The number of sigma points for each variable in the state x. 066 */ 067 public int getNumSigmas() { 068 return 2 * m_states.getNum() + 1; 069 } 070 071 /** 072 * Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P) 073 * of the filter. 074 * 075 * @param x An array of the means. 076 * @param s Square-root covariance of the filter. 077 * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one 078 * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}. 079 */ 080 public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) { 081 double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum(); 082 double eta = Math.sqrt(lambda + m_states.getNum()); 083 084 Matrix<S, S> U = s.times(eta); 085 086 // 2 * states + 1 by states 087 Matrix<S, ?> sigmas = 088 new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); 089 sigmas.setColumn(0, x); 090 for (int k = 0; k < m_states.getNum(); k++) { 091 var xPlusU = x.plus(U.extractColumnVector(k)); 092 var xMinusU = x.minus(U.extractColumnVector(k)); 093 sigmas.setColumn(k + 1, xPlusU); 094 sigmas.setColumn(m_states.getNum() + k + 1, xMinusU); 095 } 096 097 return new Matrix<>(sigmas); 098 } 099 100 /** 101 * Computes the weights for the scaled unscented Kalman filter. 102 * 103 * @param beta Incorporates prior knowledge of the distribution of the mean. 104 */ 105 private void computeWeights(double beta) { 106 double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum(); 107 double c = 0.5 / (m_states.getNum() + lambda); 108 109 Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1)); 110 Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1)); 111 wM.fill(c); 112 wC.fill(c); 113 114 wM.set(0, 0, lambda / (m_states.getNum() + lambda)); 115 wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta)); 116 117 this.m_wm = wM; 118 this.m_wc = wC; 119 } 120 121 /** 122 * Returns the weight for each sigma point for the mean. 123 * 124 * @return the weight for each sigma point for the mean. 125 */ 126 public Matrix<?, N1> getWm() { 127 return m_wm; 128 } 129 130 /** 131 * Returns an element of the weight for each sigma point for the mean. 132 * 133 * @param element Element of vector to return. 134 * @return the element i's weight for the mean. 135 */ 136 public double getWm(int element) { 137 return m_wm.get(element, 0); 138 } 139 140 /** 141 * Returns the weight for each sigma point for the covariance. 142 * 143 * @return the weight for each sigma point for the covariance. 144 */ 145 public Matrix<?, N1> getWc() { 146 return m_wc; 147 } 148 149 /** 150 * Returns an element of the weight for each sigma point for the covariance. 151 * 152 * @param element Element of vector to return. 153 * @return The element I's weight for the covariance. 154 */ 155 public double getWc(int element) { 156 return m_wc.get(element, 0); 157 } 158}