001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.autodiff; 006 007import java.util.Iterator; 008import java.util.NoSuchElementException; 009import java.util.function.BinaryOperator; 010import java.util.function.UnaryOperator; 011import java.util.stream.Stream; 012import java.util.stream.StreamSupport; 013import org.ejml.simple.SimpleMatrix; 014 015/** A matrix of autodiff variables. */ 016public class VariableMatrix implements AutoCloseable, Iterable<Variable> { 017 private final Variable[] m_storage; 018 private int m_rows; 019 private int m_cols; 020 021 /** 022 * Constructs a VariableMatrix from Variable internal handles. 023 * 024 * <p>This constructor is for internal use only. 025 * 026 * @param rows The number of matrix rows. 027 * @param cols The number of matrix columns. 028 * @param handles Variable handles. 029 */ 030 public VariableMatrix(int rows, int cols, long[] handles) { 031 assert handles.length == rows * cols; 032 033 m_rows = rows; 034 m_cols = cols; 035 m_storage = new Variable[rows * cols]; 036 for (int index = 0; index < m_storage.length; ++index) { 037 m_storage[index] = new Variable(Variable.HANDLE, handles[index]); 038 } 039 } 040 041 /** 042 * Constructs a zero-initialized VariableMatrix column vector with the given rows. 043 * 044 * @param rows The number of matrix rows. 045 */ 046 public VariableMatrix(int rows) { 047 this(rows, 1); 048 } 049 050 /** 051 * Constructs a zero-initialized VariableMatrix with the given dimensions. 052 * 053 * @param rows The number of matrix rows. 054 * @param cols The number of matrix columns. 055 */ 056 public VariableMatrix(int rows, int cols) { 057 m_rows = rows; 058 m_cols = cols; 059 m_storage = new Variable[rows * cols]; 060 for (int index = 0; index < m_storage.length; ++index) { 061 m_storage[index] = new Variable(); 062 } 063 } 064 065 /** 066 * Constructs a scalar VariableMatrix from a nested list of doubles. 067 * 068 * @param list The nested list of Variables. 069 */ 070 public VariableMatrix(double[][] list) { 071 // Get row and column counts for destination matrix 072 m_rows = list.length; 073 m_cols = 0; 074 if (list.length > 0) { 075 m_cols = list[0].length; 076 } 077 078 // Assert all column counts are the same 079 for (var row : list) { 080 assert row.length == m_cols; 081 } 082 083 m_storage = new Variable[m_rows * m_cols]; 084 int index = 0; 085 for (var row : list) { 086 for (var elem : row) { 087 m_storage[index] = new Variable(elem); 088 ++index; 089 } 090 } 091 } 092 093 /** 094 * Constructs a scalar VariableMatrix from a nested list of Variables. 095 * 096 * @param list The nested list of Variables. 097 */ 098 public VariableMatrix(Variable[][] list) { 099 // Get row and column counts for destination matrix 100 m_rows = list.length; 101 m_cols = 0; 102 if (list.length > 0) { 103 m_cols = list[0].length; 104 } 105 106 // Assert all column counts are the same 107 for (var row : list) { 108 assert row.length == m_cols; 109 } 110 111 m_storage = new Variable[m_rows * m_cols]; 112 int index = 0; 113 for (var row : list) { 114 for (var elem : row) { 115 m_storage[index] = elem; 116 ++index; 117 } 118 } 119 } 120 121 /** 122 * Constructs a VariableMatrix from an EJML matrix. 123 * 124 * @param values EJML matrix of values. 125 */ 126 public VariableMatrix(SimpleMatrix values) { 127 m_rows = values.getNumRows(); 128 m_cols = values.getNumCols(); 129 m_storage = new Variable[m_rows * m_cols]; 130 for (int row = 0; row < values.getNumRows(); ++row) { 131 for (int col = 0; col < values.getNumCols(); ++col) { 132 m_storage[row * m_cols + col] = new Variable(values.get(row, col)); 133 } 134 } 135 } 136 137 /** 138 * Constructs a scalar VariableMatrix from a Variable. 139 * 140 * @param variable Variable. 141 */ 142 public VariableMatrix(Variable variable) { 143 m_rows = 1; 144 m_cols = 1; 145 m_storage = new Variable[] {variable}; 146 } 147 148 /** 149 * Constructs a VariableMatrix from a VariableBlock. 150 * 151 * @param values VariableBlock of values. 152 */ 153 public VariableMatrix(VariableBlock values) { 154 m_rows = values.rows(); 155 m_cols = values.cols(); 156 m_storage = new Variable[m_rows * m_cols]; 157 for (int row = 0; row < m_rows; ++row) { 158 for (int col = 0; col < m_cols; ++col) { 159 m_storage[row * m_cols + col] = values.get(row, col); 160 } 161 } 162 } 163 164 @Override 165 public void close() { 166 for (int index = 0; index < rows() * cols(); ++index) { 167 m_storage[index].close(); 168 } 169 } 170 171 /** 172 * Assigns a double array to a VariableMatrix. 173 * 174 * @param values Double array of values. 175 * @return This VariableMatrix. 176 */ 177 public VariableMatrix set(double[][] values) { 178 assert rows() == values.length; 179 180 // Assert all column counts are the same 181 for (var row : values) { 182 assert row.length == cols(); 183 } 184 185 for (int row = 0; row < values.length; ++row) { 186 for (int col = 0; col < values[0].length; ++col) { 187 set(row, col, values[row][col]); 188 } 189 } 190 191 return this; 192 } 193 194 /** 195 * Assigns an EJML matrix to a VariableMatrix. 196 * 197 * @param values EJML matrix of values. 198 * @return This VariableMatrix. 199 */ 200 public VariableMatrix set(SimpleMatrix values) { 201 assert rows() == values.getNumRows() && cols() == values.getNumCols(); 202 203 for (int row = 0; row < values.getNumRows(); ++row) { 204 for (int col = 0; col < values.getNumCols(); ++col) { 205 set(row, col, values.get(row, col)); 206 } 207 } 208 209 return this; 210 } 211 212 /** 213 * Assigns a VariableMatrix to a VariableMatrix. 214 * 215 * @param values VariableMatrix of values. 216 * @return This VariableMatrix. 217 */ 218 public VariableMatrix set(VariableMatrix values) { 219 assert rows() == values.rows() && cols() == values.cols(); 220 221 for (int row = 0; row < values.rows(); ++row) { 222 for (int col = 0; col < values.cols(); ++col) { 223 set(row, col, values.get(row, col)); 224 } 225 } 226 227 return this; 228 } 229 230 /** 231 * Assigns a VariableBlock to a VariableMatrix. 232 * 233 * @param values VariableBlock of values. 234 * @return This VariableMatrix. 235 */ 236 public VariableMatrix set(VariableBlock values) { 237 assert rows() == values.rows() && cols() == values.cols(); 238 239 for (int row = 0; row < values.rows(); ++row) { 240 for (int col = 0; col < values.cols(); ++col) { 241 set(row, col, values.get(row, col)); 242 } 243 } 244 245 return this; 246 } 247 248 /** 249 * Assigns a double to the matrix. 250 * 251 * <p>This only works for matrices with one row and one column. 252 * 253 * @param value Value to assign. 254 * @return This VariableMatrix. 255 */ 256 public VariableMatrix set(double value) { 257 return set(new Variable(value)); 258 } 259 260 /** 261 * Assigns a Variable to the matrix. 262 * 263 * <p>This only works for matrices with one row and one column. 264 * 265 * @param value Value to assign. 266 * @return This VariableMatrix. 267 */ 268 public VariableMatrix set(Variable value) { 269 assert rows() == 1 && cols() == 1; 270 271 m_storage[0] = value; 272 273 return this; 274 } 275 276 /** 277 * Sets an element to the given value. 278 * 279 * @param row The row. 280 * @param col The column. 281 * @param value The value. 282 */ 283 public void set(int row, int col, Variable value) { 284 assert row >= 0 && row < rows(); 285 assert col >= 0 && col < cols(); 286 m_storage[row * cols() + col] = value; 287 } 288 289 /** 290 * Sets an element to the given value. 291 * 292 * @param row The row. 293 * @param col The column. 294 * @param value The value. 295 */ 296 public void set(int row, int col, double value) { 297 assert row >= 0 && row < rows(); 298 assert col >= 0 && col < cols(); 299 m_storage[row * cols() + col] = new Variable(value); 300 } 301 302 /** 303 * Sets an element to the given value. 304 * 305 * @param index The index of the element. 306 * @param value The value. 307 */ 308 public void set(int index, double value) { 309 set(index, new Variable(value)); 310 } 311 312 /** 313 * Sets an element to the given value. 314 * 315 * @param index The index of the element. 316 * @param value The value. 317 */ 318 public void set(int index, Variable value) { 319 assert index >= 0 && index < rows() * cols(); 320 m_storage[index] = value; 321 } 322 323 /** 324 * Sets the VariableMatrix's internal values. 325 * 326 * @param values Double array of values. 327 */ 328 public void setValue(double[][] values) { 329 assert rows() == values.length; 330 331 // Assert all column counts are the same 332 for (var row : values) { 333 assert row.length == cols(); 334 } 335 336 for (int row = 0; row < rows(); ++row) { 337 for (int col = 0; col < cols(); ++col) { 338 get(row, col).setValue(values[row][col]); 339 } 340 } 341 } 342 343 /** 344 * Sets the VariableMatrix's internal values. 345 * 346 * @param values EJML matrix of values. 347 */ 348 public void setValue(SimpleMatrix values) { 349 assert rows() == values.getNumRows() && cols() == values.getNumCols(); 350 351 for (int row = 0; row < values.getNumRows(); ++row) { 352 for (int col = 0; col < values.getNumCols(); ++col) { 353 get(row, col).setValue(values.get(row, col)); 354 } 355 } 356 } 357 358 /** 359 * Returns the element at the given row and column. 360 * 361 * @param row The row. 362 * @param col The column. 363 * @return The element at the given row and column. 364 */ 365 public Variable get(int row, int col) { 366 assert row >= 0 && row < rows(); 367 assert col >= 0 && col < cols(); 368 return m_storage[row * cols() + col]; 369 } 370 371 /** 372 * Returns the element at the given index. 373 * 374 * @param index The index. 375 * @return The element at the given index. 376 */ 377 public Variable get(int index) { 378 assert index >= 0 && index < rows() * cols(); 379 return m_storage[index]; 380 } 381 382 /** 383 * Returns a slice of the variable matrix. 384 * 385 * @param row The row. 386 * @param colSlice The column slice. 387 * @return A slice of the variable matrix. 388 */ 389 public VariableBlock get(int row, Slice.None colSlice) { 390 return get(new Slice(row), new Slice(colSlice)); 391 } 392 393 /** 394 * Returns a slice of the variable matrix. 395 * 396 * @param row The row. 397 * @param colSlice The column slice. 398 * @return A slice of the variable matrix. 399 */ 400 public VariableBlock get(int row, Slice colSlice) { 401 return get(new Slice(row), colSlice); 402 } 403 404 /** 405 * Returns a slice of the variable matrix. 406 * 407 * @param rowSlice The row slice. 408 * @param col The column. 409 * @return A slice of the variable matrix. 410 */ 411 public VariableBlock get(Slice.None rowSlice, int col) { 412 return get(new Slice(rowSlice), new Slice(col)); 413 } 414 415 /** 416 * Returns a slice of the variable matrix. 417 * 418 * @param rowSlice The row slice. 419 * @param col The column. 420 * @return A slice of the variable matrix. 421 */ 422 public VariableBlock get(Slice rowSlice, int col) { 423 return get(rowSlice, new Slice(col)); 424 } 425 426 /** 427 * Returns a slice of the variable matrix. 428 * 429 * @param rowSlice The row slice. 430 * @param colSlice The column slice. 431 * @return A slice of the variable matrix. 432 */ 433 public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) { 434 return get(new Slice(rowSlice), new Slice(colSlice)); 435 } 436 437 /** 438 * Returns a slice of the variable matrix. 439 * 440 * @param rowSlice The row slice. 441 * @param colSlice The column slice. 442 * @return A slice of the variable matrix. 443 */ 444 public VariableBlock get(Slice.None rowSlice, Slice colSlice) { 445 return get(new Slice(rowSlice), colSlice); 446 } 447 448 /** 449 * Returns a slice of the variable matrix. 450 * 451 * @param rowSlice The row slice. 452 * @param colSlice The column slice. 453 * @return A slice of the variable matrix. 454 */ 455 public VariableBlock get(Slice rowSlice, Slice.None colSlice) { 456 return get(rowSlice, new Slice(colSlice)); 457 } 458 459 /** 460 * Returns a slice of the variable matrix. 461 * 462 * @param rowSlice The row slice. 463 * @param colSlice The column slice. 464 * @return A slice of the variable matrix. 465 */ 466 public VariableBlock get(Slice rowSlice, Slice colSlice) { 467 int rowSliceLength = rowSlice.adjust(rows()); 468 int colSliceLength = colSlice.adjust(cols()); 469 return new VariableBlock(this, rowSlice, rowSliceLength, colSlice, colSliceLength); 470 } 471 472 /** 473 * Returns a block of the variable matrix. 474 * 475 * @param rowOffset The row offset of the block selection. 476 * @param colOffset The column offset of the block selection. 477 * @param blockRows The number of rows in the block selection. 478 * @param blockCols The number of columns in the block selection. 479 * @return A block of the variable matrix. 480 */ 481 public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) { 482 assert rowOffset >= 0 && rowOffset <= rows(); 483 assert colOffset >= 0 && colOffset <= cols(); 484 assert blockRows >= 0 && blockRows <= rows() - rowOffset; 485 assert blockCols >= 0 && blockCols <= cols() - colOffset; 486 return new VariableBlock(this, rowOffset, colOffset, blockRows, blockCols); 487 } 488 489 /** 490 * Returns a segment of the variable vector. 491 * 492 * @param offset The offset of the segment. 493 * @param length The length of the segment. 494 * @return A segment of the variable vector. 495 */ 496 public VariableBlock segment(int offset, int length) { 497 assert cols() == 1; 498 assert offset >= 0 && offset < rows(); 499 assert length >= 0 && length <= rows() - offset; 500 return block(offset, 0, length, 1); 501 } 502 503 /** 504 * Returns a row slice of the variable matrix. 505 * 506 * @param row The row to slice. 507 * @return A row slice of the variable matrix. 508 */ 509 public VariableBlock row(int row) { 510 assert row >= 0 && row < rows(); 511 return block(row, 0, 1, cols()); 512 } 513 514 /** 515 * Returns a column slice of the variable matrix. 516 * 517 * @param col The column to slice. 518 * @return A column slice of the variable matrix. 519 */ 520 public VariableBlock col(int col) { 521 assert col >= 0 && col < cols(); 522 return block(0, col, rows(), 1); 523 } 524 525 /** 526 * Matrix multiplication operator. 527 * 528 * @param rhs Operator right-hand side. 529 * @return Result of matrix multiplication. 530 */ 531 public VariableMatrix times(VariableMatrix rhs) { 532 assert cols() == rhs.rows(); 533 534 var result = new VariableMatrix(rows(), rhs.cols()); 535 536 for (int i = 0; i < rows(); ++i) { 537 for (int j = 0; j < rhs.cols(); ++j) { 538 var sum = new Variable(0.0); 539 for (int k = 0; k < cols(); ++k) { 540 sum = sum.plus(get(i, k).times(rhs.get(k, j))); 541 } 542 result.set(i, j, sum); 543 } 544 } 545 546 return result; 547 } 548 549 /** 550 * Matrix multiplication operator. 551 * 552 * @param rhs Operator right-hand side. 553 * @return Result of matrix multiplication. 554 */ 555 public VariableMatrix times(VariableBlock rhs) { 556 assert cols() == rhs.rows(); 557 558 var result = new VariableMatrix(rows(), rhs.cols()); 559 560 for (int i = 0; i < rows(); ++i) { 561 for (int j = 0; j < rhs.cols(); ++j) { 562 var sum = new Variable(0.0); 563 for (int k = 0; k < cols(); ++k) { 564 sum = sum.plus(get(i, k).times(rhs.get(k, j))); 565 } 566 result.set(i, j, sum); 567 } 568 } 569 570 return result; 571 } 572 573 /** 574 * Matrix multiplication operator. 575 * 576 * @param rhs Operator right-hand side. 577 * @return Result of matrix multiplication. 578 */ 579 public VariableMatrix times(SimpleMatrix rhs) { 580 return times(new VariableMatrix(rhs)); 581 } 582 583 /** 584 * Matrix-scalar multiplication operator. 585 * 586 * @param rhs Operator right-hand side. 587 * @return Result of matrix-scalar multiplication. 588 */ 589 public VariableMatrix times(double rhs) { 590 return times(new Variable(rhs)); 591 } 592 593 /** 594 * Matrix-scalar multiplication operator. 595 * 596 * @param rhs Operator right-hand side. 597 * @return Result of matrix-scalar multiplication. 598 */ 599 public VariableMatrix times(Variable rhs) { 600 var result = new VariableMatrix(rows(), cols()); 601 602 for (int row = 0; row < result.rows(); ++row) { 603 for (int col = 0; col < result.cols(); ++col) { 604 result.set(row, col, get(row, col).times(rhs)); 605 } 606 } 607 608 return result; 609 } 610 611 /** 612 * Binary division operator. 613 * 614 * @param rhs Operator right-hand side. 615 * @return Result of division. 616 */ 617 public VariableMatrix div(double rhs) { 618 return div(new Variable(rhs)); 619 } 620 621 /** 622 * Binary division operator. 623 * 624 * @param rhs Operator right-hand side. 625 * @return Result of division. 626 */ 627 public VariableMatrix div(Variable rhs) { 628 var result = new VariableMatrix(rows(), cols()); 629 630 for (int row = 0; row < result.rows(); ++row) { 631 for (int col = 0; col < result.cols(); ++col) { 632 result.set(row, col, get(row, col).div(rhs)); 633 } 634 } 635 636 return result; 637 } 638 639 /** 640 * Binary addition operator. 641 * 642 * @param rhs Operator right-hand side. 643 * @return Result of addition. 644 */ 645 public VariableMatrix plus(VariableMatrix rhs) { 646 assert rows() == rhs.rows() && cols() == rhs.cols(); 647 648 var result = new VariableMatrix(rows(), cols()); 649 650 for (int row = 0; row < result.rows(); ++row) { 651 for (int col = 0; col < result.cols(); ++col) { 652 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 653 } 654 } 655 656 return result; 657 } 658 659 /** 660 * Binary addition operator. 661 * 662 * @param rhs Operator right-hand side. 663 * @return Result of addition. 664 */ 665 public VariableMatrix plus(VariableBlock rhs) { 666 assert rows() == rhs.rows() && cols() == rhs.cols(); 667 668 var result = new VariableMatrix(rows(), cols()); 669 670 for (int row = 0; row < result.rows(); ++row) { 671 for (int col = 0; col < result.cols(); ++col) { 672 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 673 } 674 } 675 676 return result; 677 } 678 679 /** 680 * Binary addition operator. 681 * 682 * @param rhs Operator right-hand side. 683 * @return Result of addition. 684 */ 685 public VariableMatrix plus(SimpleMatrix rhs) { 686 assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); 687 688 var result = new VariableMatrix(rows(), cols()); 689 690 for (int row = 0; row < result.rows(); ++row) { 691 for (int col = 0; col < result.cols(); ++col) { 692 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 693 } 694 } 695 696 return result; 697 } 698 699 /** 700 * Binary subtraction operator. 701 * 702 * @param rhs Operator right-hand side. 703 * @return Result of subtraction. 704 */ 705 public VariableMatrix minus(VariableMatrix rhs) { 706 assert rows() == rhs.rows() && cols() == rhs.cols(); 707 708 var result = new VariableMatrix(rows(), cols()); 709 710 for (int row = 0; row < result.rows(); ++row) { 711 for (int col = 0; col < result.cols(); ++col) { 712 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 713 } 714 } 715 716 return result; 717 } 718 719 /** 720 * Binary subtraction operator. 721 * 722 * @param rhs Operator right-hand side. 723 * @return Result of subtraction. 724 */ 725 public VariableMatrix minus(VariableBlock rhs) { 726 assert rows() == rhs.rows() && cols() == rhs.cols(); 727 728 var result = new VariableMatrix(rows(), cols()); 729 730 for (int row = 0; row < result.rows(); ++row) { 731 for (int col = 0; col < result.cols(); ++col) { 732 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 733 } 734 } 735 736 return result; 737 } 738 739 /** 740 * Binary subtraction operator. 741 * 742 * @param rhs Operator right-hand side. 743 * @return Result of subtraction. 744 */ 745 public VariableMatrix minus(SimpleMatrix rhs) { 746 assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); 747 748 var result = new VariableMatrix(rows(), cols()); 749 750 for (int row = 0; row < result.rows(); ++row) { 751 for (int col = 0; col < result.cols(); ++col) { 752 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 753 } 754 } 755 756 return result; 757 } 758 759 /** 760 * Unary minus operator. 761 * 762 * @return Result of unary minus. 763 */ 764 public VariableMatrix unaryMinus() { 765 var result = new VariableMatrix(rows(), cols()); 766 767 for (int row = 0; row < result.rows(); ++row) { 768 for (int col = 0; col < result.cols(); ++col) { 769 result.set(row, col, get(row, col).unaryMinus()); 770 } 771 } 772 773 return result; 774 } 775 776 /** 777 * Returns the transpose of the variable matrix. 778 * 779 * @return The transpose of the variable matrix. 780 */ 781 public VariableMatrix T() { 782 var result = new VariableMatrix(cols(), rows()); 783 784 for (int row = 0; row < rows(); ++row) { 785 for (int col = 0; col < cols(); ++col) { 786 result.set(col, row, get(row, col)); 787 } 788 } 789 790 return result; 791 } 792 793 /** 794 * Returns the number of rows in the matrix. 795 * 796 * @return The number of rows in the matrix. 797 */ 798 public int rows() { 799 return m_rows; 800 } 801 802 /** 803 * Returns the number of columns in the matrix. 804 * 805 * @return The number of columns in the matrix. 806 */ 807 public int cols() { 808 return m_cols; 809 } 810 811 /** 812 * Returns an element of the variable matrix. 813 * 814 * @param row The row of the element to return. 815 * @param col The column of the element to return. 816 * @return An element of the variable matrix. 817 */ 818 public double value(int row, int col) { 819 return get(row, col).value(); 820 } 821 822 /** 823 * Returns an element of the variable matrix. 824 * 825 * @param index The index of the element to return. 826 * @return An element of the variable matrix. 827 */ 828 public double value(int index) { 829 return get(index).value(); 830 } 831 832 /** 833 * Returns the contents of the variable matrix. 834 * 835 * @return The contents of the variable matrix. 836 */ 837 public SimpleMatrix value() { 838 var result = new SimpleMatrix(rows(), cols()); 839 840 for (int row = 0; row < rows(); ++row) { 841 for (int col = 0; col < cols(); ++col) { 842 result.set(row, col, value(row, col)); 843 } 844 } 845 846 return result; 847 } 848 849 /** 850 * Maps the matrix coefficient-wise with an unary operator. 851 * 852 * @param unaryOp The unary operator to use for the map operation. 853 * @return Result of the unary operator. 854 */ 855 public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) { 856 var result = new VariableMatrix(rows(), cols()); 857 858 for (int row = 0; row < rows(); ++row) { 859 for (int col = 0; col < cols(); ++col) { 860 result.set(row, col, unaryOp.apply(get(row, col))); 861 } 862 } 863 864 return result; 865 } 866 867 /** 868 * Returns number of elements in matrix. 869 * 870 * @return Number of elements in matrix. 871 */ 872 public int size() { 873 return m_storage.length; 874 } 875 876 @Override 877 public Iterator<Variable> iterator() { 878 return new Iterator<>() { 879 private int m_index = 0; 880 881 @Override 882 public boolean hasNext() { 883 return m_index < VariableMatrix.this.size(); 884 } 885 886 @Override 887 public Variable next() { 888 if (!hasNext()) { 889 throw new NoSuchElementException(); 890 } 891 892 return VariableMatrix.this.get(m_index++); 893 } 894 }; 895 } 896 897 /** 898 * Creates a Stream of VariableMatrix elements. 899 * 900 * @return A Stream of VariableMatrix elements. 901 */ 902 public Stream<Variable> stream() { 903 return StreamSupport.stream(spliterator(), false); 904 } 905 906 /** 907 * Returns a variable matrix filled with zeroes. 908 * 909 * @param rows The number of matrix rows. 910 * @param cols The number of matrix columns. 911 * @return A variable matrix filled with zeroes. 912 */ 913 public static VariableMatrix zero(int rows, int cols) { 914 return new VariableMatrix(new SimpleMatrix(rows, cols)); 915 } 916 917 /** 918 * Returns a variable matrix filled with ones. 919 * 920 * @param rows The number of matrix rows. 921 * @param cols The number of matrix columns. 922 * @return A variable matrix filled with ones. 923 */ 924 public static VariableMatrix one(int rows, int cols) { 925 return new VariableMatrix(SimpleMatrix.ones(rows, cols)); 926 } 927 928 /** 929 * Returns a variable matrix filled with a constant. 930 * 931 * @param rows The number of matrix rows. 932 * @param cols The number of matrix columns. 933 * @param constant The constant. 934 * @return A variable matrix filled with a constant. 935 */ 936 public static VariableMatrix constant(int rows, int cols, double constant) { 937 return new VariableMatrix(SimpleMatrix.filled(rows, cols, constant)); 938 } 939 940 /** 941 * Applies a coefficient-wise reduce operation to two matrices. 942 * 943 * @param lhs The left-hand side of the binary operator. 944 * @param rhs The right-hand side of the binary operator. 945 * @param binaryOp The binary operator to use for the reduce operation. 946 * @return Result of binary operator. 947 */ 948 public static VariableMatrix cwiseReduce( 949 VariableMatrix lhs, VariableMatrix rhs, BinaryOperator<Variable> binaryOp) { 950 assert lhs.rows() == rhs.rows() && lhs.cols() == rhs.cols(); 951 952 var result = new VariableMatrix(lhs.rows(), lhs.cols()); 953 954 for (int row = 0; row < lhs.rows(); ++row) { 955 for (int col = 0; col < lhs.cols(); ++col) { 956 result.set(row, col, binaryOp.apply(lhs.get(row, col), rhs.get(row, col))); 957 } 958 } 959 960 return result; 961 } 962 963 /** 964 * Assembles a VariableMatrix from a nested list of blocks. 965 * 966 * <p>Each row's blocks must have the same height, and the assembled block rows must have the same 967 * width. For example, for the block matrix [[A, B], [C]] to be constructible, the number of rows 968 * in A and B must match, and the number of columns in [A, B] and [C] must match. 969 * 970 * @param list The nested list of blocks. 971 * @return Block matrix. 972 */ 973 @SuppressWarnings("OverloadMethodsDeclarationOrder") 974 public static VariableMatrix block(VariableMatrix[][] list) { 975 // Get row and column counts for destination matrix 976 int rows = 0; 977 int cols = -1; 978 for (var row : list) { 979 if (row.length > 0) { 980 rows += row[0].rows(); 981 } 982 983 // Get number of columns in this row 984 int latestCols = 0; 985 for (var elem : row) { 986 // Assert the first and latest row have the same height 987 assert row[0].rows() == elem.rows(); 988 989 latestCols += elem.cols(); 990 } 991 992 // If this is the first row, record the column count. Otherwise, assert the 993 // first and latest column counts are the same. 994 if (cols == -1) { 995 cols = latestCols; 996 } else { 997 assert cols == latestCols; 998 } 999 } 1000 1001 var result = new VariableMatrix(rows, cols); 1002 1003 int rowOffset = 0; 1004 for (var row : list) { 1005 int colOffset = 0; 1006 for (var elem : row) { 1007 result.block(rowOffset, colOffset, elem.rows(), elem.cols()).set(elem); 1008 colOffset += elem.cols(); 1009 } 1010 if (row.length > 0) { 1011 rowOffset += row[0].rows(); 1012 } 1013 } 1014 1015 return result; 1016 } 1017 1018 /** 1019 * Solves the VariableMatrix equation AX = B for X. 1020 * 1021 * @param A The left-hand side. 1022 * @param B The right-hand side. 1023 * @return The solution X. 1024 */ 1025 public static VariableMatrix solve(VariableMatrix A, VariableMatrix B) { 1026 // m x n * n x p = m x p 1027 assert A.rows() == B.rows(); 1028 1029 return new VariableMatrix( 1030 A.cols(), 1031 B.cols(), 1032 VariableMatrixJNI.solve(A.getHandles(), A.cols(), B.getHandles(), B.cols())); 1033 } 1034 1035 /** 1036 * Returns an array of VariableMatrix internal handles in row-major order. 1037 * 1038 * @return Array of VariableMatrix internal handles in row-major order. 1039 */ 1040 long[] getHandles() { 1041 var handles = new long[size()]; 1042 for (int index = 0; index < size(); ++index) { 1043 handles[index] = m_storage[index].getHandle(); 1044 } 1045 return handles; 1046 } 1047}