001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package org.wpilib.math.autodiff;
006
007import java.util.Iterator;
008import java.util.NoSuchElementException;
009import java.util.function.BinaryOperator;
010import java.util.function.UnaryOperator;
011import java.util.stream.Stream;
012import java.util.stream.StreamSupport;
013import org.ejml.simple.SimpleMatrix;
014
015/** A matrix of autodiff variables. */
016public class VariableMatrix implements AutoCloseable, Iterable<Variable> {
017  private final Variable[] m_storage;
018  private int m_rows;
019  private int m_cols;
020
021  /**
022   * Constructs a VariableMatrix from Variable internal handles.
023   *
024   * <p>This constructor is for internal use only.
025   *
026   * @param rows The number of matrix rows.
027   * @param cols The number of matrix columns.
028   * @param handles Variable handles.
029   */
030  public VariableMatrix(int rows, int cols, long[] handles) {
031    assert handles.length == rows * cols;
032
033    m_rows = rows;
034    m_cols = cols;
035    m_storage = new Variable[rows * cols];
036    for (int index = 0; index < m_storage.length; ++index) {
037      m_storage[index] = new Variable(Variable.HANDLE, handles[index]);
038    }
039  }
040
041  /**
042   * Constructs a zero-initialized VariableMatrix column vector with the given rows.
043   *
044   * @param rows The number of matrix rows.
045   */
046  public VariableMatrix(int rows) {
047    this(rows, 1);
048  }
049
050  /**
051   * Constructs a zero-initialized VariableMatrix with the given dimensions.
052   *
053   * @param rows The number of matrix rows.
054   * @param cols The number of matrix columns.
055   */
056  public VariableMatrix(int rows, int cols) {
057    m_rows = rows;
058    m_cols = cols;
059    m_storage = new Variable[rows * cols];
060    for (int index = 0; index < m_storage.length; ++index) {
061      m_storage[index] = new Variable();
062    }
063  }
064
065  /**
066   * Constructs a scalar VariableMatrix from a nested list of doubles.
067   *
068   * @param list The nested list of Variables.
069   */
070  public VariableMatrix(double[][] list) {
071    // Get row and column counts for destination matrix
072    m_rows = list.length;
073    m_cols = 0;
074    if (list.length > 0) {
075      m_cols = list[0].length;
076    }
077
078    // Assert all column counts are the same
079    for (var row : list) {
080      assert row.length == m_cols;
081    }
082
083    m_storage = new Variable[m_rows * m_cols];
084    int index = 0;
085    for (var row : list) {
086      for (var elem : row) {
087        m_storage[index] = new Variable(elem);
088        ++index;
089      }
090    }
091  }
092
093  /**
094   * Constructs a scalar VariableMatrix from a nested list of Variables.
095   *
096   * @param list The nested list of Variables.
097   */
098  public VariableMatrix(Variable[][] list) {
099    // Get row and column counts for destination matrix
100    m_rows = list.length;
101    m_cols = 0;
102    if (list.length > 0) {
103      m_cols = list[0].length;
104    }
105
106    // Assert all column counts are the same
107    for (var row : list) {
108      assert row.length == m_cols;
109    }
110
111    m_storage = new Variable[m_rows * m_cols];
112    int index = 0;
113    for (var row : list) {
114      for (var elem : row) {
115        m_storage[index] = elem;
116        ++index;
117      }
118    }
119  }
120
121  /**
122   * Constructs a VariableMatrix from an EJML matrix.
123   *
124   * @param values EJML matrix of values.
125   */
126  public VariableMatrix(SimpleMatrix values) {
127    m_rows = values.getNumRows();
128    m_cols = values.getNumCols();
129    m_storage = new Variable[m_rows * m_cols];
130    for (int row = 0; row < values.getNumRows(); ++row) {
131      for (int col = 0; col < values.getNumCols(); ++col) {
132        m_storage[row * m_cols + col] = new Variable(values.get(row, col));
133      }
134    }
135  }
136
137  /**
138   * Constructs a scalar VariableMatrix from a Variable.
139   *
140   * @param variable Variable.
141   */
142  public VariableMatrix(Variable variable) {
143    m_rows = 1;
144    m_cols = 1;
145    m_storage = new Variable[] {variable};
146  }
147
148  /**
149   * Constructs a VariableMatrix from a VariableBlock.
150   *
151   * @param values VariableBlock of values.
152   */
153  public VariableMatrix(VariableBlock values) {
154    m_rows = values.rows();
155    m_cols = values.cols();
156    m_storage = new Variable[m_rows * m_cols];
157    for (int row = 0; row < m_rows; ++row) {
158      for (int col = 0; col < m_cols; ++col) {
159        m_storage[row * m_cols + col] = values.get(row, col);
160      }
161    }
162  }
163
164  @Override
165  public void close() {
166    for (int index = 0; index < rows() * cols(); ++index) {
167      m_storage[index].close();
168    }
169  }
170
171  /**
172   * Assigns a double array to a VariableMatrix.
173   *
174   * @param values Double array of values.
175   * @return This VariableMatrix.
176   */
177  public VariableMatrix set(double[][] values) {
178    assert rows() == values.length;
179
180    // Assert all column counts are the same
181    for (var row : values) {
182      assert row.length == cols();
183    }
184
185    for (int row = 0; row < values.length; ++row) {
186      for (int col = 0; col < values[0].length; ++col) {
187        set(row, col, values[row][col]);
188      }
189    }
190
191    return this;
192  }
193
194  /**
195   * Assigns an EJML matrix to a VariableMatrix.
196   *
197   * @param values EJML matrix of values.
198   * @return This VariableMatrix.
199   */
200  public VariableMatrix set(SimpleMatrix values) {
201    assert rows() == values.getNumRows() && cols() == values.getNumCols();
202
203    for (int row = 0; row < values.getNumRows(); ++row) {
204      for (int col = 0; col < values.getNumCols(); ++col) {
205        set(row, col, values.get(row, col));
206      }
207    }
208
209    return this;
210  }
211
212  /**
213   * Assigns a VariableMatrix to a VariableMatrix.
214   *
215   * @param values VariableMatrix of values.
216   * @return This VariableMatrix.
217   */
218  public VariableMatrix set(VariableMatrix values) {
219    assert rows() == values.rows() && cols() == values.cols();
220
221    for (int row = 0; row < values.rows(); ++row) {
222      for (int col = 0; col < values.cols(); ++col) {
223        set(row, col, values.get(row, col));
224      }
225    }
226
227    return this;
228  }
229
230  /**
231   * Assigns a VariableBlock to a VariableMatrix.
232   *
233   * @param values VariableBlock of values.
234   * @return This VariableMatrix.
235   */
236  public VariableMatrix set(VariableBlock values) {
237    assert rows() == values.rows() && cols() == values.cols();
238
239    for (int row = 0; row < values.rows(); ++row) {
240      for (int col = 0; col < values.cols(); ++col) {
241        set(row, col, values.get(row, col));
242      }
243    }
244
245    return this;
246  }
247
248  /**
249   * Assigns a double to the matrix.
250   *
251   * <p>This only works for matrices with one row and one column.
252   *
253   * @param value Value to assign.
254   * @return This VariableMatrix.
255   */
256  public VariableMatrix set(double value) {
257    return set(new Variable(value));
258  }
259
260  /**
261   * Assigns a Variable to the matrix.
262   *
263   * <p>This only works for matrices with one row and one column.
264   *
265   * @param value Value to assign.
266   * @return This VariableMatrix.
267   */
268  public VariableMatrix set(Variable value) {
269    assert rows() == 1 && cols() == 1;
270
271    m_storage[0] = value;
272
273    return this;
274  }
275
276  /**
277   * Sets an element to the given value.
278   *
279   * @param row The row.
280   * @param col The column.
281   * @param value The value.
282   */
283  public void set(int row, int col, Variable value) {
284    assert row >= 0 && row < rows();
285    assert col >= 0 && col < cols();
286    m_storage[row * cols() + col] = value;
287  }
288
289  /**
290   * Sets an element to the given value.
291   *
292   * @param row The row.
293   * @param col The column.
294   * @param value The value.
295   */
296  public void set(int row, int col, double value) {
297    assert row >= 0 && row < rows();
298    assert col >= 0 && col < cols();
299    m_storage[row * cols() + col] = new Variable(value);
300  }
301
302  /**
303   * Sets an element to the given value.
304   *
305   * @param index The index of the element.
306   * @param value The value.
307   */
308  public void set(int index, double value) {
309    set(index, new Variable(value));
310  }
311
312  /**
313   * Sets an element to the given value.
314   *
315   * @param index The index of the element.
316   * @param value The value.
317   */
318  public void set(int index, Variable value) {
319    assert index >= 0 && index < rows() * cols();
320    m_storage[index] = value;
321  }
322
323  /**
324   * Sets the VariableMatrix's internal values.
325   *
326   * @param values Double array of values.
327   */
328  public void setValue(double[][] values) {
329    assert rows() == values.length;
330
331    // Assert all column counts are the same
332    for (var row : values) {
333      assert row.length == cols();
334    }
335
336    for (int row = 0; row < rows(); ++row) {
337      for (int col = 0; col < cols(); ++col) {
338        get(row, col).setValue(values[row][col]);
339      }
340    }
341  }
342
343  /**
344   * Sets the VariableMatrix's internal values.
345   *
346   * @param values EJML matrix of values.
347   */
348  public void setValue(SimpleMatrix values) {
349    assert rows() == values.getNumRows() && cols() == values.getNumCols();
350
351    for (int row = 0; row < values.getNumRows(); ++row) {
352      for (int col = 0; col < values.getNumCols(); ++col) {
353        get(row, col).setValue(values.get(row, col));
354      }
355    }
356  }
357
358  /**
359   * Returns the element at the given row and column.
360   *
361   * @param row The row.
362   * @param col The column.
363   * @return The element at the given row and column.
364   */
365  public Variable get(int row, int col) {
366    assert row >= 0 && row < rows();
367    assert col >= 0 && col < cols();
368    return m_storage[row * cols() + col];
369  }
370
371  /**
372   * Returns the element at the given index.
373   *
374   * @param index The index.
375   * @return The element at the given index.
376   */
377  public Variable get(int index) {
378    assert index >= 0 && index < rows() * cols();
379    return m_storage[index];
380  }
381
382  /**
383   * Returns a slice of the variable matrix.
384   *
385   * @param row The row.
386   * @param colSlice The column slice.
387   * @return A slice of the variable matrix.
388   */
389  public VariableBlock get(int row, Slice.None colSlice) {
390    return get(new Slice(row), new Slice(colSlice));
391  }
392
393  /**
394   * Returns a slice of the variable matrix.
395   *
396   * @param row The row.
397   * @param colSlice The column slice.
398   * @return A slice of the variable matrix.
399   */
400  public VariableBlock get(int row, Slice colSlice) {
401    return get(new Slice(row), colSlice);
402  }
403
404  /**
405   * Returns a slice of the variable matrix.
406   *
407   * @param rowSlice The row slice.
408   * @param col The column.
409   * @return A slice of the variable matrix.
410   */
411  public VariableBlock get(Slice.None rowSlice, int col) {
412    return get(new Slice(rowSlice), new Slice(col));
413  }
414
415  /**
416   * Returns a slice of the variable matrix.
417   *
418   * @param rowSlice The row slice.
419   * @param col The column.
420   * @return A slice of the variable matrix.
421   */
422  public VariableBlock get(Slice rowSlice, int col) {
423    return get(rowSlice, new Slice(col));
424  }
425
426  /**
427   * Returns a slice of the variable matrix.
428   *
429   * @param rowSlice The row slice.
430   * @param colSlice The column slice.
431   * @return A slice of the variable matrix.
432   */
433  public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) {
434    return get(new Slice(rowSlice), new Slice(colSlice));
435  }
436
437  /**
438   * Returns a slice of the variable matrix.
439   *
440   * @param rowSlice The row slice.
441   * @param colSlice The column slice.
442   * @return A slice of the variable matrix.
443   */
444  public VariableBlock get(Slice.None rowSlice, Slice colSlice) {
445    return get(new Slice(rowSlice), colSlice);
446  }
447
448  /**
449   * Returns a slice of the variable matrix.
450   *
451   * @param rowSlice The row slice.
452   * @param colSlice The column slice.
453   * @return A slice of the variable matrix.
454   */
455  public VariableBlock get(Slice rowSlice, Slice.None colSlice) {
456    return get(rowSlice, new Slice(colSlice));
457  }
458
459  /**
460   * Returns a slice of the variable matrix.
461   *
462   * @param rowSlice The row slice.
463   * @param colSlice The column slice.
464   * @return A slice of the variable matrix.
465   */
466  public VariableBlock get(Slice rowSlice, Slice colSlice) {
467    int rowSliceLength = rowSlice.adjust(rows());
468    int colSliceLength = colSlice.adjust(cols());
469    return new VariableBlock(this, rowSlice, rowSliceLength, colSlice, colSliceLength);
470  }
471
472  /**
473   * Returns a block of the variable matrix.
474   *
475   * @param rowOffset The row offset of the block selection.
476   * @param colOffset The column offset of the block selection.
477   * @param blockRows The number of rows in the block selection.
478   * @param blockCols The number of columns in the block selection.
479   * @return A block of the variable matrix.
480   */
481  public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) {
482    assert rowOffset >= 0 && rowOffset <= rows();
483    assert colOffset >= 0 && colOffset <= cols();
484    assert blockRows >= 0 && blockRows <= rows() - rowOffset;
485    assert blockCols >= 0 && blockCols <= cols() - colOffset;
486    return new VariableBlock(this, rowOffset, colOffset, blockRows, blockCols);
487  }
488
489  /**
490   * Returns a segment of the variable vector.
491   *
492   * @param offset The offset of the segment.
493   * @param length The length of the segment.
494   * @return A segment of the variable vector.
495   */
496  public VariableBlock segment(int offset, int length) {
497    assert cols() == 1;
498    assert offset >= 0 && offset < rows();
499    assert length >= 0 && length <= rows() - offset;
500    return block(offset, 0, length, 1);
501  }
502
503  /**
504   * Returns a row slice of the variable matrix.
505   *
506   * @param row The row to slice.
507   * @return A row slice of the variable matrix.
508   */
509  public VariableBlock row(int row) {
510    assert row >= 0 && row < rows();
511    return block(row, 0, 1, cols());
512  }
513
514  /**
515   * Returns a column slice of the variable matrix.
516   *
517   * @param col The column to slice.
518   * @return A column slice of the variable matrix.
519   */
520  public VariableBlock col(int col) {
521    assert col >= 0 && col < cols();
522    return block(0, col, rows(), 1);
523  }
524
525  /**
526   * Matrix multiplication operator.
527   *
528   * @param rhs Operator right-hand side.
529   * @return Result of matrix multiplication.
530   */
531  public VariableMatrix times(VariableMatrix rhs) {
532    assert cols() == rhs.rows();
533
534    var result = new VariableMatrix(rows(), rhs.cols());
535
536    for (int i = 0; i < rows(); ++i) {
537      for (int j = 0; j < rhs.cols(); ++j) {
538        var sum = new Variable(0.0);
539        for (int k = 0; k < cols(); ++k) {
540          sum = sum.plus(get(i, k).times(rhs.get(k, j)));
541        }
542        result.set(i, j, sum);
543      }
544    }
545
546    return result;
547  }
548
549  /**
550   * Matrix multiplication operator.
551   *
552   * @param rhs Operator right-hand side.
553   * @return Result of matrix multiplication.
554   */
555  public VariableMatrix times(VariableBlock rhs) {
556    assert cols() == rhs.rows();
557
558    var result = new VariableMatrix(rows(), rhs.cols());
559
560    for (int i = 0; i < rows(); ++i) {
561      for (int j = 0; j < rhs.cols(); ++j) {
562        var sum = new Variable(0.0);
563        for (int k = 0; k < cols(); ++k) {
564          sum = sum.plus(get(i, k).times(rhs.get(k, j)));
565        }
566        result.set(i, j, sum);
567      }
568    }
569
570    return result;
571  }
572
573  /**
574   * Matrix multiplication operator.
575   *
576   * @param rhs Operator right-hand side.
577   * @return Result of matrix multiplication.
578   */
579  public VariableMatrix times(SimpleMatrix rhs) {
580    return times(new VariableMatrix(rhs));
581  }
582
583  /**
584   * Matrix-scalar multiplication operator.
585   *
586   * @param rhs Operator right-hand side.
587   * @return Result of matrix-scalar multiplication.
588   */
589  public VariableMatrix times(double rhs) {
590    return times(new Variable(rhs));
591  }
592
593  /**
594   * Matrix-scalar multiplication operator.
595   *
596   * @param rhs Operator right-hand side.
597   * @return Result of matrix-scalar multiplication.
598   */
599  public VariableMatrix times(Variable rhs) {
600    var result = new VariableMatrix(rows(), cols());
601
602    for (int row = 0; row < result.rows(); ++row) {
603      for (int col = 0; col < result.cols(); ++col) {
604        result.set(row, col, get(row, col).times(rhs));
605      }
606    }
607
608    return result;
609  }
610
611  /**
612   * Binary division operator.
613   *
614   * @param rhs Operator right-hand side.
615   * @return Result of division.
616   */
617  public VariableMatrix div(double rhs) {
618    return div(new Variable(rhs));
619  }
620
621  /**
622   * Binary division operator.
623   *
624   * @param rhs Operator right-hand side.
625   * @return Result of division.
626   */
627  public VariableMatrix div(Variable rhs) {
628    var result = new VariableMatrix(rows(), cols());
629
630    for (int row = 0; row < result.rows(); ++row) {
631      for (int col = 0; col < result.cols(); ++col) {
632        result.set(row, col, get(row, col).div(rhs));
633      }
634    }
635
636    return result;
637  }
638
639  /**
640   * Binary addition operator.
641   *
642   * @param rhs Operator right-hand side.
643   * @return Result of addition.
644   */
645  public VariableMatrix plus(VariableMatrix rhs) {
646    assert rows() == rhs.rows() && cols() == rhs.cols();
647
648    var result = new VariableMatrix(rows(), cols());
649
650    for (int row = 0; row < result.rows(); ++row) {
651      for (int col = 0; col < result.cols(); ++col) {
652        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
653      }
654    }
655
656    return result;
657  }
658
659  /**
660   * Binary addition operator.
661   *
662   * @param rhs Operator right-hand side.
663   * @return Result of addition.
664   */
665  public VariableMatrix plus(VariableBlock rhs) {
666    assert rows() == rhs.rows() && cols() == rhs.cols();
667
668    var result = new VariableMatrix(rows(), cols());
669
670    for (int row = 0; row < result.rows(); ++row) {
671      for (int col = 0; col < result.cols(); ++col) {
672        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
673      }
674    }
675
676    return result;
677  }
678
679  /**
680   * Binary addition operator.
681   *
682   * @param rhs Operator right-hand side.
683   * @return Result of addition.
684   */
685  public VariableMatrix plus(SimpleMatrix rhs) {
686    assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
687
688    var result = new VariableMatrix(rows(), cols());
689
690    for (int row = 0; row < result.rows(); ++row) {
691      for (int col = 0; col < result.cols(); ++col) {
692        result.set(row, col, get(row, col).plus(rhs.get(row, col)));
693      }
694    }
695
696    return result;
697  }
698
699  /**
700   * Binary subtraction operator.
701   *
702   * @param rhs Operator right-hand side.
703   * @return Result of subtraction.
704   */
705  public VariableMatrix minus(VariableMatrix rhs) {
706    assert rows() == rhs.rows() && cols() == rhs.cols();
707
708    var result = new VariableMatrix(rows(), cols());
709
710    for (int row = 0; row < result.rows(); ++row) {
711      for (int col = 0; col < result.cols(); ++col) {
712        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
713      }
714    }
715
716    return result;
717  }
718
719  /**
720   * Binary subtraction operator.
721   *
722   * @param rhs Operator right-hand side.
723   * @return Result of subtraction.
724   */
725  public VariableMatrix minus(VariableBlock rhs) {
726    assert rows() == rhs.rows() && cols() == rhs.cols();
727
728    var result = new VariableMatrix(rows(), cols());
729
730    for (int row = 0; row < result.rows(); ++row) {
731      for (int col = 0; col < result.cols(); ++col) {
732        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
733      }
734    }
735
736    return result;
737  }
738
739  /**
740   * Binary subtraction operator.
741   *
742   * @param rhs Operator right-hand side.
743   * @return Result of subtraction.
744   */
745  public VariableMatrix minus(SimpleMatrix rhs) {
746    assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols();
747
748    var result = new VariableMatrix(rows(), cols());
749
750    for (int row = 0; row < result.rows(); ++row) {
751      for (int col = 0; col < result.cols(); ++col) {
752        result.set(row, col, get(row, col).minus(rhs.get(row, col)));
753      }
754    }
755
756    return result;
757  }
758
759  /**
760   * Unary minus operator.
761   *
762   * @return Result of unary minus.
763   */
764  public VariableMatrix unaryMinus() {
765    var result = new VariableMatrix(rows(), cols());
766
767    for (int row = 0; row < result.rows(); ++row) {
768      for (int col = 0; col < result.cols(); ++col) {
769        result.set(row, col, get(row, col).unaryMinus());
770      }
771    }
772
773    return result;
774  }
775
776  /**
777   * Returns the transpose of the variable matrix.
778   *
779   * @return The transpose of the variable matrix.
780   */
781  public VariableMatrix T() {
782    var result = new VariableMatrix(cols(), rows());
783
784    for (int row = 0; row < rows(); ++row) {
785      for (int col = 0; col < cols(); ++col) {
786        result.set(col, row, get(row, col));
787      }
788    }
789
790    return result;
791  }
792
793  /**
794   * Returns the number of rows in the matrix.
795   *
796   * @return The number of rows in the matrix.
797   */
798  public int rows() {
799    return m_rows;
800  }
801
802  /**
803   * Returns the number of columns in the matrix.
804   *
805   * @return The number of columns in the matrix.
806   */
807  public int cols() {
808    return m_cols;
809  }
810
811  /**
812   * Returns an element of the variable matrix.
813   *
814   * @param row The row of the element to return.
815   * @param col The column of the element to return.
816   * @return An element of the variable matrix.
817   */
818  public double value(int row, int col) {
819    return get(row, col).value();
820  }
821
822  /**
823   * Returns an element of the variable matrix.
824   *
825   * @param index The index of the element to return.
826   * @return An element of the variable matrix.
827   */
828  public double value(int index) {
829    return get(index).value();
830  }
831
832  /**
833   * Returns the contents of the variable matrix.
834   *
835   * @return The contents of the variable matrix.
836   */
837  public SimpleMatrix value() {
838    var result = new SimpleMatrix(rows(), cols());
839
840    for (int row = 0; row < rows(); ++row) {
841      for (int col = 0; col < cols(); ++col) {
842        result.set(row, col, value(row, col));
843      }
844    }
845
846    return result;
847  }
848
849  /**
850   * Maps the matrix coefficient-wise with an unary operator.
851   *
852   * @param unaryOp The unary operator to use for the map operation.
853   * @return Result of the unary operator.
854   */
855  public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) {
856    var result = new VariableMatrix(rows(), cols());
857
858    for (int row = 0; row < rows(); ++row) {
859      for (int col = 0; col < cols(); ++col) {
860        result.set(row, col, unaryOp.apply(get(row, col)));
861      }
862    }
863
864    return result;
865  }
866
867  /**
868   * Returns number of elements in matrix.
869   *
870   * @return Number of elements in matrix.
871   */
872  public int size() {
873    return m_storage.length;
874  }
875
876  @Override
877  public Iterator<Variable> iterator() {
878    return new Iterator<>() {
879      private int m_index = 0;
880
881      @Override
882      public boolean hasNext() {
883        return m_index < VariableMatrix.this.size();
884      }
885
886      @Override
887      public Variable next() {
888        if (!hasNext()) {
889          throw new NoSuchElementException();
890        }
891
892        return VariableMatrix.this.get(m_index++);
893      }
894    };
895  }
896
897  /**
898   * Creates a Stream of VariableMatrix elements.
899   *
900   * @return A Stream of VariableMatrix elements.
901   */
902  public Stream<Variable> stream() {
903    return StreamSupport.stream(spliterator(), false);
904  }
905
906  /**
907   * Returns a variable matrix filled with zeroes.
908   *
909   * @param rows The number of matrix rows.
910   * @param cols The number of matrix columns.
911   * @return A variable matrix filled with zeroes.
912   */
913  public static VariableMatrix zero(int rows, int cols) {
914    return new VariableMatrix(new SimpleMatrix(rows, cols));
915  }
916
917  /**
918   * Returns a variable matrix filled with ones.
919   *
920   * @param rows The number of matrix rows.
921   * @param cols The number of matrix columns.
922   * @return A variable matrix filled with ones.
923   */
924  public static VariableMatrix one(int rows, int cols) {
925    return new VariableMatrix(SimpleMatrix.ones(rows, cols));
926  }
927
928  /**
929   * Returns a variable matrix filled with a constant.
930   *
931   * @param rows The number of matrix rows.
932   * @param cols The number of matrix columns.
933   * @param constant The constant.
934   * @return A variable matrix filled with a constant.
935   */
936  public static VariableMatrix constant(int rows, int cols, double constant) {
937    return new VariableMatrix(SimpleMatrix.filled(rows, cols, constant));
938  }
939
940  /**
941   * Applies a coefficient-wise reduce operation to two matrices.
942   *
943   * @param lhs The left-hand side of the binary operator.
944   * @param rhs The right-hand side of the binary operator.
945   * @param binaryOp The binary operator to use for the reduce operation.
946   * @return Result of binary operator.
947   */
948  public static VariableMatrix cwiseReduce(
949      VariableMatrix lhs, VariableMatrix rhs, BinaryOperator<Variable> binaryOp) {
950    assert lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols();
951
952    var result = new VariableMatrix(lhs.rows(), lhs.cols());
953
954    for (int row = 0; row < lhs.rows(); ++row) {
955      for (int col = 0; col < lhs.cols(); ++col) {
956        result.set(row, col, binaryOp.apply(lhs.get(row, col), rhs.get(row, col)));
957      }
958    }
959
960    return result;
961  }
962
963  /**
964   * Assembles a VariableMatrix from a nested list of blocks.
965   *
966   * <p>Each row's blocks must have the same height, and the assembled block rows must have the same
967   * width. For example, for the block matrix [[A, B], [C]] to be constructible, the number of rows
968   * in A and B must match, and the number of columns in [A, B] and [C] must match.
969   *
970   * @param list The nested list of blocks.
971   * @return Block matrix.
972   */
973  @SuppressWarnings("OverloadMethodsDeclarationOrder")
974  public static VariableMatrix block(VariableMatrix[][] list) {
975    // Get row and column counts for destination matrix
976    int rows = 0;
977    int cols = -1;
978    for (var row : list) {
979      if (row.length > 0) {
980        rows += row[0].rows();
981      }
982
983      // Get number of columns in this row
984      int latestCols = 0;
985      for (var elem : row) {
986        // Assert the first and latest row have the same height
987        assert row[0].rows() == elem.rows();
988
989        latestCols += elem.cols();
990      }
991
992      // If this is the first row, record the column count. Otherwise, assert the
993      // first and latest column counts are the same.
994      if (cols == -1) {
995        cols = latestCols;
996      } else {
997        assert cols == latestCols;
998      }
999    }
1000
1001    var result = new VariableMatrix(rows, cols);
1002
1003    int rowOffset = 0;
1004    for (var row : list) {
1005      int colOffset = 0;
1006      for (var elem : row) {
1007        result.block(rowOffset, colOffset, elem.rows(), elem.cols()).set(elem);
1008        colOffset += elem.cols();
1009      }
1010      if (row.length > 0) {
1011        rowOffset += row[0].rows();
1012      }
1013    }
1014
1015    return result;
1016  }
1017
1018  /**
1019   * Solves the VariableMatrix equation AX = B for X.
1020   *
1021   * @param A The left-hand side.
1022   * @param B The right-hand side.
1023   * @return The solution X.
1024   */
1025  public static VariableMatrix solve(VariableMatrix A, VariableMatrix B) {
1026    // m x n * n x p = m x p
1027    assert A.rows() == B.rows();
1028
1029    return new VariableMatrix(
1030        A.cols(),
1031        B.cols(),
1032        VariableMatrixJNI.solve(A.getHandles(), A.cols(), B.getHandles(), B.cols()));
1033  }
1034
1035  /**
1036   * Returns an array of VariableMatrix internal handles in row-major order.
1037   *
1038   * @return Array of VariableMatrix internal handles in row-major order.
1039   */
1040  long[] getHandles() {
1041    var handles = new long[size()];
1042    for (int index = 0; index < size(); ++index) {
1043      handles[index] = m_storage[index].getHandle();
1044    }
1045    return handles;
1046  }
1047}