001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.autodiff;
006
007import java.util.Iterator;
008import java.util.NoSuchElementException;
009import java.util.function.UnaryOperator;
010import java.util.stream.Stream;
011import java.util.stream.StreamSupport;
012import org.ejml.simple.SimpleMatrix;
013
014/** A submatrix of autodiff variables with reference semantics. */
015public class VariableBlock implements Iterable<Variable> {
016  private final VariableMatrix m_mat;
017
018  private final Slice m_rowSlice;
019  private final int m_rowSliceLength;
020
021  private final Slice m_colSlice;
022  private final int m_colSliceLength;
023
024  /**
025   * Constructs a Variable block pointing to all of the given matrix.
026   *
027   * @param mat The matrix to which to point.
028   */
029  public VariableBlock(VariableMatrix mat) {
030    this(mat, 0, 0, mat.rows(), mat.cols());
031  }
032
033  /**
034   * Constructs a Variable block pointing to a subset of the given matrix.
035   *
036   * @param mat The matrix to which to point.
037   * @param rowOffset The block's row offset.
038   * @param colOffset The block's column offset.
039   * @param blockRows The number of rows in the block.
040   * @param blockCols The number of columns in the block.
041   */
042  public VariableBlock(
043      VariableMatrix mat, int rowOffset, int colOffset, int blockRows, int blockCols) {
044    m_mat = mat;
045    m_rowSlice = new Slice(rowOffset, rowOffset + blockRows, 1);
046    m_rowSliceLength = m_rowSlice.adjust(mat.rows());
047    m_colSlice = new Slice(colOffset, colOffset + blockCols, 1);
048    m_colSliceLength = m_colSlice.adjust(mat.cols());
049  }
050
051  /**
052   * Constructs a Variable block pointing to a subset of the given matrix.
053   *
054   * <p>Note that the slices are taken as is rather than adjusted.
055   *
056   * @param mat The matrix to which to point.
057   * @param rowSlice The block's row slice.
058   * @param rowSliceLength The block's row length.
059   * @param colSlice The block's column slice.
060   * @param colSliceLength The block's column length.
061   */
062  public VariableBlock(
063      VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) {
064    m_mat = mat;
065    m_rowSlice = rowSlice;
066    m_rowSliceLength = rowSliceLength;
067    m_colSlice = colSlice;
068    m_colSliceLength = colSliceLength;
069  }
070
071  /**
072   * Assigns a double to the block.
073   *
074   * <p>This only works for blocks with one row and one column.
075   *
076   * @param value Value to assign.
077   * @return This VariableBlock.
078   */
079  public VariableBlock set(double value) {
080    assert rows() == 1 && cols() == 1;
081
082    set(0, 0, new Variable(value));
083
084    return this;
085  }
086
087  /**
088   * Assigns a Variable to the block.
089   *
090   * <p>This only works for blocks with one row and one column.
091   *
092   * @param value Value to assign.
093   * @return This VariableBlock.
094   */
095  public VariableBlock set(Variable value) {
096    assert rows() == 1 && cols() == 1;
097
098    set(0, 0, value);
099
100    return this;
101  }
102
103  /**
104   * Assigns a double array to the block.
105   *
106   * @param values Double array of values to assign.
107   * @return This VariableBlock.
108   */
109  public VariableBlock set(double[][] values) {
110    assert rows() == values.length;
111
112    // Assert all column counts are the same
113    for (var row : values) {
114      assert row.length == cols();
115    }
116
117    for (int row = 0; row < rows(); ++row) {
118      for (int col = 0; col < cols(); ++col) {
119        set(row, col, values[row][col]);
120      }
121    }
122
123    return this;
124  }
125
126  /**
127   * Assigns an EJML matrix to the block.
128   *
129   * @param values EJML matrix of values to assign.
130   * @return This VariableBlock.
131   */
132  public VariableBlock set(SimpleMatrix values) {
133    assert rows() == values.getNumRows() && cols() == values.getNumCols();
134
135    for (int row = 0; row < rows(); ++row) {
136      for (int col = 0; col < cols(); ++col) {
137        set(row, col, values.get(row, col));
138      }
139    }
140
141    return this;
142  }
143
144  /**
145   * Assigns a VariableMatrix to the block.
146   *
147   * @param values VariableMatrix of values.
148   * @return This VariableBlock.
149   */
150  public VariableBlock set(VariableMatrix values) {
151    assert rows() == values.rows() && cols() == values.cols();
152
153    for (int row = 0; row < rows(); ++row) {
154      for (int col = 0; col < cols(); ++col) {
155        set(row, col, values.get(row, col));
156      }
157    }
158    return this;
159  }
160
161  /**
162   * Assigns a VariableBlock to the block.
163   *
164   * @param values VariableBlock of values.
165   * @return This VariableBlock.
166   */
167  public VariableBlock set(VariableBlock values) {
168    assert rows() == values.rows() && cols() == values.cols();
169
170    for (int row = 0; row < rows(); ++row) {
171      for (int col = 0; col < cols(); ++col) {
172        set(row, col, values.get(row, col));
173      }
174    }
175    return this;
176  }
177
178  /**
179   * Sets a scalar subblock at the given row and column.
180   *
181   * @param row The scalar subblock's row.
182   * @param col The scalar subblock's column.
183   * @param value The value.
184   */
185  public void set(int row, int col, Variable value) {
186    assert row >= 0 && row < rows();
187    assert col >= 0 && col < cols();
188    m_mat.set(
189        m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
190  }
191
192  /**
193   * Sets a scalar subblock at the given row and column.
194   *
195   * @param row The scalar subblock's row.
196   * @param col The scalar subblock's column.
197   * @param value The value.
198   */
199  public void set(int row, int col, double value) {
200    assert row >= 0 && row < rows();
201    assert col >= 0 && col < cols();
202    m_mat.set(
203        m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value);
204  }
205
206  /**
207   * Sets a scalar subblock at the given index.
208   *
209   * @param index The scalar subblock's index.
210   * @param value The value.
211   */
212  public void set(int index, double value) {
213    set(index, new Variable(value));
214  }
215
216  /**
217   * Sets a scalar subblock at the given index.
218   *
219   * @param index The scalar subblock's index.
220   * @param value The value.
221   */
222  public void set(int index, Variable value) {
223    assert index >= 0 && index < rows() * cols();
224    set(index / cols(), index % cols(), value);
225  }
226
227  /**
228   * Assigns a double to the block.
229   *
230   * <p>This only works for blocks with one row and one column.
231   *
232   * @param value Value to assign.
233   */
234  public void setValue(double value) {
235    assert rows() == 1 && cols() == 1;
236
237    get(0, 0).setValue(value);
238  }
239
240  /**
241   * Sets block's internal values.
242   *
243   * @param values Double array of values.
244   */
245  public void setValue(double[][] values) {
246    assert rows() == values.length;
247
248    // Assert all column counts are the same
249    for (var row : values) {
250      assert row.length == cols();
251    }
252
253    for (int row = 0; row < rows(); ++row) {
254      for (int col = 0; col < cols(); ++col) {
255        get(row, col).setValue(values[row][col]);
256      }
257    }
258  }
259
260  /**
261   * Sets block's internal values.
262   *
263   * @param values EJML matrix of values.
264   */
265  public void setValue(SimpleMatrix values) {
266    assert rows() == values.getNumRows() && cols() == values.getNumCols();
267
268    for (int row = 0; row < rows(); ++row) {
269      for (int col = 0; col < cols(); ++col) {
270        get(row, col).setValue(values.get(row, col));
271      }
272    }
273  }
274
275  /**
276   * Returns a scalar subblock at the given row and column.
277   *
278   * @param row The scalar subblock's row.
279   * @param col The scalar subblock's column.
280   * @return A scalar subblock at the given row and column.
281   */
282  public Variable get(int row, int col) {
283    assert row >= 0 && row < rows();
284    assert col >= 0 && col < cols();
285    return m_mat.get(
286        m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step);
287  }
288
289  /**
290   * Returns a scalar subblock at the given index.
291   *
292   * @param index The scalar subblock's index.
293   * @return A scalar subblock at the given index.
294   */
295  public Variable get(int index) {
296    assert index >= 0 && index < rows() * cols();
297    return get(index / cols(), index % cols());
298  }
299
300  /**
301   * Returns a slice of the variable matrix.
302   *
303   * @param row The row.
304   * @param colSlice The column slice.
305   * @return A slice of the variable matrix.
306   */
307  public VariableBlock get(int row, Slice.None colSlice) {
308    return get(new Slice(row), new Slice(colSlice));
309  }
310
311  /**
312   * Returns a slice of the variable matrix.
313   *
314   * @param row The row.
315   * @param colSlice The column slice.
316   * @return A slice of the variable matrix.
317   */
318  public VariableBlock get(int row, Slice colSlice) {
319    return get(new Slice(row), colSlice);
320  }
321
322  /**
323   * Returns a slice of the variable matrix.
324   *
325   * @param rowSlice The row slice.
326   * @param col The column.
327   * @return A slice of the variable matrix.
328   */
329  public VariableBlock get(Slice.None rowSlice, int col) {
330    return get(new Slice(rowSlice), new Slice(col));
331  }
332
333  /**
334   * Returns a slice of the variable matrix.
335   *
336   * @param rowSlice The row slice.
337   * @param col The column.
338   * @return A slice of the variable matrix.
339   */
340  public VariableBlock get(Slice rowSlice, int col) {
341    return get(rowSlice, new Slice(col));
342  }
343
344  /**
345   * Returns a slice of the variable matrix.
346   *
347   * @param rowSlice The row slice.
348   * @param colSlice The column slice.
349   * @return A slice of the variable matrix.
350   */
351  public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
352    return get(new Slice(rowSlice), new Slice(colSlice));
353  }
354
355  /**
356   * Returns a slice of the variable matrix.
357   *
358   * @param rowSlice The row slice.
359   * @param colSlice The column slice.
360   * @return A slice of the variable matrix.
361   */
362  public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
363    return get(new Slice(rowSlice), colSlice);
364  }
365
366  /**
367   * Returns a slice of the variable matrix.
368   *
369   * @param rowSlice The row slice.
370   * @param colSlice The column slice.
371   * @return A slice of the variable matrix.
372   */
373  public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
374    return get(rowSlice, new Slice(colSlice));
375  }
376
377  /**
378   * Returns a slice of the variable matrix.
379   *
380   * @param rowSlice The row slice.
381   * @param colSlice The column slice.
382   * @return A slice of the variable matrix.
383   */
384  public VariableBlock get(Slice rowSlice, Slice colSlice) {
385    int rowSliceLength = rowSlice.adjust(m_rowSliceLength);
386    int colSliceLength = colSlice.adjust(m_colSliceLength);
387    return new VariableBlock(
388        m_mat,
389        new Slice(
390            m_rowSlice.start + rowSlice.start * m_rowSlice.step,
391            m_rowSlice.start + rowSlice.stop * m_rowSlice.step,
392            rowSlice.step * m_rowSlice.step),
393        rowSliceLength,
394        new Slice(
395            m_colSlice.start + colSlice.start * m_colSlice.step,
396            m_colSlice.start + colSlice.stop * m_colSlice.step,
397            colSlice.step * m_colSlice.step),
398        colSliceLength);
399  }
400
401  /**
402   * Returns a block of the variable matrix.
403   *
404   * @param rowOffset The row offset of the block selection.
405   * @param colOffset The column offset of the block selection.
406   * @param blockRows The number of rows in the block selection.
407   * @param blockCols The number of columns in the block selection.
408   * @return A block of the variable matrix.
409   */
410  public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
411    assert rowOffset >= 0 && rowOffset <= rows();
412    assert colOffset >= 0 && colOffset <= cols();
413    assert blockRows >= 0 && blockRows <= rows() - rowOffset;
414    assert blockCols >= 0 && blockCols <= cols() - colOffset;
415    return get(
416        new Slice(rowOffset, rowOffset + blockRows, 1),
417        new Slice(colOffset, colOffset + blockCols, 1));
418  }
419
420  /**
421   * Returns a segment of the variable vector.
422   *
423   * @param offset The offset of the segment.
424   * @param length The length of the segment.
425   * @return A segment of the variable vector.
426   */
427  public VariableBlock segment(int offset, int length) {
428    assert cols() == 1;
429    assert offset >= 0 && offset < rows();
430    assert length >= 0 && length <= rows() - offset;
431    return block(offset, 0, length, 1);
432  }
433
434  /**
435   * Returns a row slice of the variable matrix.
436   *
437   * @param row The row to slice.
438   * @return A row slice of the variable matrix.
439   */
440  public VariableBlock row(int row) {
441    assert row >= 0 && row < rows();
442    return block(row, 0, 1, cols());
443  }
444
445  /**
446   * Returns a column slice of the variable matrix.
447   *
448   * @param col The column to slice.
449   * @return A column slice of the variable matrix.
450   */
451  public VariableBlock col(int col) {
452    assert col >= 0 && col < cols();
453    return block(0, col, rows(), 1);
454  }
455
456  /**
457   * Matrix multiplication operator.
458   *
459   * @param rhs Operator right-hand side.
460   * @return Result of matrix multiplication.
461   */
462  public VariableMatrix times(VariableMatrix rhs) {
463    assert cols() == rhs.rows();
464
465    var result = new VariableMatrix(rows(), rhs.cols());
466
467    for (int i = 0; i < rows(); ++i) {
468      for (int j = 0; j < rhs.cols(); ++j) {
469        var sum = new Variable(0.0);
470        for (int k = 0; k < cols(); ++k) {
471          sum = sum.plus(get(i, k).times(rhs.get(k, j)));
472        }
473        result.set(i, j, sum);
474      }
475    }
476
477    return result;
478  }
479
480  /**
481   * Matrix multiplication operator.
482   *
483   * @param rhs Operator right-hand side.
484   * @return Result of matrix multiplication.
485   */
486  public VariableMatrix times(VariableBlock rhs) {
487    assert cols() == rhs.rows();
488
489    var result = new VariableMatrix(rows(), rhs.cols());
490
491    for (int i = 0; i < rows(); ++i) {
492      for (int j = 0; j < rhs.cols(); ++j) {
493        var sum = new Variable(0.0);
494        for (int k = 0; k < cols(); ++k) {
495          sum = sum.plus(get(i, k).times(rhs.get(k, j)));
496        }
497        result.set(i, j, sum);
498      }
499    }
500
501    return result;
502  }
503
504  /**
505   * Matrix-scalar multiplication operator.
506   *
507   * @param rhs Operator right-hand side.
508   * @return Result of matrix-scalar multiplication.
509   */
510  public VariableMatrix times(double rhs) {
511    return times(new Variable(rhs));
512  }
513
514  /**
515   * Matrix-scalar multiplication operator.
516   *
517   * @param rhs Operator right-hand side.
518   * @return Result of matrix-scalar multiplication.
519   */
520  public VariableMatrix times(Variable rhs) {
521    var result = new VariableMatrix(rows(), cols());
522
523    for (int row = 0; row < result.rows(); ++row) {
524      for (int col = 0; col < result.cols(); ++col) {
525        result.set(row, col, get(row, col).times(rhs));
526      }
527    }
528
529    return result;
530  }
531
532  /**
533   * Binary division operator.
534   *
535   * @param rhs Operator right-hand side.
536   * @return Result of division.
537   */
538  public VariableMatrix div(double rhs) {
539    return div(new Variable(rhs));
540  }
541
542  /**
543   * Binary division operator.
544   *
545   * @param rhs Operator right-hand side.
546   * @return Result of division.
547   */
548  public VariableMatrix div(Variable rhs) {
549    var result = new VariableMatrix(rows(), cols());
550
551    for (int row = 0; row < result.rows(); ++row) {
552      for (int col = 0; col < result.cols(); ++col) {
553        result.set(row, col, get(row, col).div(rhs));
554      }
555    }
556
557    return result;
558  }
559
560  /**
561   * Binary addition operator.
562   *
563   * @param rhs Operator right-hand side.
564   * @return Result of addition.
565   */
566  public VariableMatrix plus(VariableMatrix rhs) {
567    assert rows() == rhs.rows() && cols() == rhs.cols();
568
569    var result = new VariableMatrix(rows(), cols());
570
571    for (int row = 0; row < result.rows(); ++row) {
572      for (int col = 0; col < result.cols(); ++col) {
573        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
574      }
575    }
576
577    return result;
578  }
579
580  /**
581   * Binary addition operator.
582   *
583   * @param rhs Operator right-hand side.
584   * @return Result of addition.
585   */
586  public VariableMatrix plus(VariableBlock rhs) {
587    assert rows() == rhs.rows() && cols() == rhs.cols();
588
589    var result = new VariableMatrix(rows(), cols());
590
591    for (int row = 0; row < result.rows(); ++row) {
592      for (int col = 0; col < result.cols(); ++col) {
593        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
594      }
595    }
596
597    return result;
598  }
599
600  /**
601   * Binary addition operator.
602   *
603   * @param rhs Operator right-hand side.
604   * @return Result of addition.
605   */
606  public VariableMatrix plus(SimpleMatrix rhs) {
607    assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
608
609    var result = new VariableMatrix(rows(), cols());
610
611    for (int row = 0; row < result.rows(); ++row) {
612      for (int col = 0; col < result.cols(); ++col) {
613        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
614      }
615    }
616
617    return result;
618  }
619
620  /**
621   * Binary subtraction operator.
622   *
623   * @param rhs Operator right-hand side.
624   * @return Result of subtraction.
625   */
626  public VariableMatrix minus(VariableMatrix rhs) {
627    assert rows() == rhs.rows() && cols() == rhs.cols();
628
629    var result = new VariableMatrix(rows(), cols());
630
631    for (int row = 0; row < result.rows(); ++row) {
632      for (int col = 0; col < result.cols(); ++col) {
633        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
634      }
635    }
636
637    return result;
638  }
639
640  /**
641   * Binary subtraction operator.
642   *
643   * @param rhs Operator right-hand side.
644   * @return Result of subtraction.
645   */
646  public VariableMatrix minus(VariableBlock rhs) {
647    assert rows() == rhs.rows() && cols() == rhs.cols();
648
649    var result = new VariableMatrix(rows(), cols());
650
651    for (int row = 0; row < result.rows(); ++row) {
652      for (int col = 0; col < result.cols(); ++col) {
653        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
654      }
655    }
656
657    return result;
658  }
659
660  /**
661   * Binary subtraction operator.
662   *
663   * @param rhs Operator right-hand side.
664   * @return Result of subtraction.
665   */
666  public VariableMatrix minus(SimpleMatrix rhs) {
667    assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
668
669    var result = new VariableMatrix(rows(), cols());
670
671    for (int row = 0; row < result.rows(); ++row) {
672      for (int col = 0; col < result.cols(); ++col) {
673        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
674      }
675    }
676
677    return result;
678  }
679
680  /**
681   * Unary minus operator.
682   *
683   * @return Result of unary minus.
684   */
685  public VariableMatrix unaryMinus() {
686    var result = new VariableMatrix(rows(), cols());
687
688    for (int row = 0; row < result.rows(); ++row) {
689      for (int col = 0; col < result.cols(); ++col) {
690        result.set(row, col, get(row, col).unaryMinus());
691      }
692    }
693
694    return result;
695  }
696
697  /**
698   * Returns the transpose of the variable matrix.
699   *
700   * @return The transpose of the variable matrix.
701   */
702  public VariableMatrix T() {
703    var result = new VariableMatrix(cols(), rows());
704
705    for (int row = 0; row < rows(); ++row) {
706      for (int col = 0; col < cols(); ++col) {
707        result.set(col, row, get(row, col));
708      }
709    }
710
711    return result;
712  }
713
714  /**
715   * Returns the number of rows in the matrix.
716   *
717   * @return The number of rows in the matrix.
718   */
719  public int rows() {
720    return m_rowSliceLength;
721  }
722
723  /**
724   * Returns the number of columns in the matrix.
725   *
726   * @return The number of columns in the matrix.
727   */
728  public int cols() {
729    return m_colSliceLength;
730  }
731
732  /**
733   * Returns an element of the variable matrix.
734   *
735   * @param row The row of the element to return.
736   * @param col The column of the element to return.
737   * @return An element of the variable matrix.
738   */
739  public double value(int row, int col) {
740    return get(row, col).value();
741  }
742
743  /**
744   * Returns an element of the variable block.
745   *
746   * @param index The index of the element to return.
747   * @return An element of the variable block.
748   */
749  public double value(int index) {
750    return get(index).value();
751  }
752
753  /**
754   * Returns the contents of the variable matrix.
755   *
756   * @return The contents of the variable matrix.
757   */
758  public SimpleMatrix value() {
759    var result = new SimpleMatrix(rows(), cols());
760
761    for (int row = 0; row < rows(); ++row) {
762      for (int col = 0; col < cols(); ++col) {
763        result.set(row, col, value(row, col));
764      }
765    }
766
767    return result;
768  }
769
770  /**
771   * Maps the matrix coefficient-wise with an unary operator.
772   *
773   * @param unaryOp The unary operator to use for the map operation.
774   * @return Result of the unary operator.
775   */
776  public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) {
777    var result = new VariableMatrix(rows(), cols());
778
779    for (int row = 0; row < rows(); ++row) {
780      for (int col = 0; col < cols(); ++col) {
781        result.set(row, col, unaryOp.apply(get(row, col)));
782      }
783    }
784
785    return result;
786  }
787
788  /**
789   * Returns number of elements in matrix.
790   *
791   * @return Number of elements in matrix.
792   */
793  public int size() {
794    return rows() * cols();
795  }
796
797  @Override
798  public Iterator<Variable> iterator() {
799    return new Iterator<>() {
800      private int m_index = 0;
801
802      @Override
803      public boolean hasNext() {
804        return m_index < VariableBlock.this.size();
805      }
806
807      @Override
808      public Variable next() {
809        if (!hasNext()) {
810          throw new NoSuchElementException();
811        }
812
813        return VariableBlock.this.get(m_index++);
814      }
815    };
816  }
817
818  /**
819   * Creates a Stream of VariableBlock elements.
820   *
821   * @return A Stream of VariableBlock elements.
822   */
823  public Stream<Variable> stream() {
824    return StreamSupport.stream(spliterator(), false);
825  }
826}