001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.math.estimator; 006 007import edu.wpi.first.math.Matrix; 008import edu.wpi.first.math.Nat; 009import edu.wpi.first.math.Num; 010import edu.wpi.first.math.numbers.N1; 011import org.ejml.simple.SimpleMatrix; 012 013/** 014 * Generates sigma points and weights according to Van der Merwe's 2004 dissertation[1] for the 015 * UnscentedKalmanFilter class. 016 * 017 * <p>It parametrizes the sigma points using alpha, beta, kappa terms, and is the version seen in 018 * most publications. S3SigmaPoints is generally preferred due to its greater performance with 019 * nearly identical accuracy. 020 * 021 * <p>States is the dimensionality of the state. 2*States+1 weights will be generated. 022 * 023 * <p>[1] R. Van der Merwe "Sigma-Point Kalman Filters for Probabilistic Inference in Dynamic 024 * State-Space Models" (Doctoral dissertation) 025 * 026 * @param <S> The dimensionality of the state. 2 * States + 1 weights will be generated. 027 */ 028public class MerweScaledSigmaPoints<S extends Num> implements SigmaPoints<S> { 029 private final double m_alpha; 030 private final int m_kappa; 031 private final Nat<S> m_states; 032 private Matrix<?, N1> m_wm; 033 private Matrix<?, N1> m_wc; 034 035 /** 036 * Constructs a generator for Van der Merwe scaled sigma points. 037 * 038 * @param states an instance of Num that represents the number of states. 039 * @param alpha Determines the spread of the sigma points around the mean. Usually a small 040 * positive value (1e-3). 041 * @param beta Incorporates prior knowledge of the distribution of the mean. For Gaussian 042 * distributions, beta = 2 is optimal. 043 * @param kappa Secondary scaling parameter usually set to 0 or 3 - States. 044 */ 045 public MerweScaledSigmaPoints(Nat<S> states, double alpha, double beta, int kappa) { 046 this.m_states = states; 047 this.m_alpha = alpha; 048 this.m_kappa = kappa; 049 050 computeWeights(beta); 051 } 052 053 /** 054 * Constructs a generator for Van der Merwe scaled sigma points with default values for alpha, 055 * beta, and kappa. 056 * 057 * @param states an instance of Num that represents the number of states. 058 */ 059 public MerweScaledSigmaPoints(Nat<S> states) { 060 this(states, 1e-3, 2, 3 - states.getNum()); 061 } 062 063 /** 064 * Returns number of sigma points for each variable in the state x. 065 * 066 * @return The number of sigma points for each variable in the state x. 067 */ 068 @Override 069 public int getNumSigmas() { 070 return 2 * m_states.getNum() + 1; 071 } 072 073 /** 074 * Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root 075 * covariance (s) of the filter. 076 * 077 * @param x An array of the means. 078 * @param s Square-root covariance of the filter. 079 * @return Two-dimensional array of sigma points. Each column contains all the sigmas for one 080 * dimension in the problem space. Ordered by Xi_0, Xi_{1..n}, Xi_{n+1..2n}. 081 */ 082 @Override 083 public Matrix<S, ?> squareRootSigmaPoints(Matrix<S, N1> x, Matrix<S, S> s) { 084 double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum(); 085 double eta = Math.sqrt(lambda + m_states.getNum()); 086 087 Matrix<S, S> U = s.times(eta); 088 089 // 2 * states + 1 by states 090 Matrix<S, ?> sigmas = 091 new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1)); 092 093 // equation (17) 094 sigmas.setColumn(0, x); 095 for (int k = 0; k < m_states.getNum(); k++) { 096 var xPlusU = x.plus(U.extractColumnVector(k)); 097 var xMinusU = x.minus(U.extractColumnVector(k)); 098 sigmas.setColumn(k + 1, xPlusU); 099 sigmas.setColumn(m_states.getNum() + k + 1, xMinusU); 100 } 101 102 return new Matrix<>(sigmas); 103 } 104 105 /** 106 * Computes the weights for the scaled unscented Kalman filter. 107 * 108 * @param beta Incorporates prior knowledge of the distribution of the mean. 109 */ 110 private void computeWeights(double beta) { 111 double lambda = Math.pow(m_alpha, 2) * (m_states.getNum() + m_kappa) - m_states.getNum(); 112 double c = 0.5 / (m_states.getNum() + lambda); 113 114 Matrix<?, N1> wM = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1)); 115 Matrix<?, N1> wC = new Matrix<>(new SimpleMatrix(2 * m_states.getNum() + 1, 1)); 116 wM.fill(c); 117 wC.fill(c); 118 119 wM.set(0, 0, lambda / (m_states.getNum() + lambda)); 120 wC.set(0, 0, lambda / (m_states.getNum() + lambda) + (1 - Math.pow(m_alpha, 2) + beta)); 121 122 this.m_wm = wM; 123 this.m_wc = wC; 124 } 125 126 /** 127 * Returns the weight for each sigma point for the mean. 128 * 129 * @return the weight for each sigma point for the mean. 130 */ 131 @Override 132 public Matrix<?, N1> getWm() { 133 return m_wm; 134 } 135 136 /** 137 * Returns an element of the weight for each sigma point for the mean. 138 * 139 * @param element Element of vector to return. 140 * @return the element i's weight for the mean. 141 */ 142 @Override 143 public double getWm(int element) { 144 return m_wm.get(element, 0); 145 } 146 147 /** 148 * Returns the weight for each sigma point for the covariance. 149 * 150 * @return the weight for each sigma point for the covariance. 151 */ 152 @Override 153 public Matrix<?, N1> getWc() { 154 return m_wc; 155 } 156 157 /** 158 * Returns an element of the weight for each sigma point for the covariance. 159 * 160 * @param element Element of vector to return. 161 * @return The element I's weight for the covariance. 162 */ 163 @Override 164 public double getWc(int element) { 165 return m_wc.get(element, 0); 166 } 167}