001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package org.wpilib.math.autodiff; 006 007import java.util.Iterator; 008import java.util.NoSuchElementException; 009import java.util.function.UnaryOperator; 010import java.util.stream.Stream; 011import java.util.stream.StreamSupport; 012import org.ejml.simple.SimpleMatrix; 013 014/** A submatrix of autodiff variables with reference semantics. */ 015public class VariableBlock implements Iterable<Variable> { 016 private final VariableMatrix m_mat; 017 018 private final Slice m_rowSlice; 019 private final int m_rowSliceLength; 020 021 private final Slice m_colSlice; 022 private final int m_colSliceLength; 023 024 /** 025 * Constructs a Variable block pointing to all of the given matrix. 026 * 027 * @param mat The matrix to which to point. 028 */ 029 public VariableBlock(VariableMatrix mat) { 030 this(mat, 0, 0, mat.rows(), mat.cols()); 031 } 032 033 /** 034 * Constructs a Variable block pointing to a subset of the given matrix. 035 * 036 * @param mat The matrix to which to point. 037 * @param rowOffset The block's row offset. 038 * @param colOffset The block's column offset. 039 * @param blockRows The number of rows in the block. 040 * @param blockCols The number of columns in the block. 041 */ 042 public VariableBlock( 043 VariableMatrix mat, int rowOffset, int colOffset, int blockRows, int blockCols) { 044 m_mat = mat; 045 m_rowSlice = new Slice(rowOffset, rowOffset + blockRows, 1); 046 m_rowSliceLength = m_rowSlice.adjust(mat.rows()); 047 m_colSlice = new Slice(colOffset, colOffset + blockCols, 1); 048 m_colSliceLength = m_colSlice.adjust(mat.cols()); 049 } 050 051 /** 052 * Constructs a Variable block pointing to a subset of the given matrix. 053 * 054 * <p>Note that the slices are taken as is rather than adjusted. 055 * 056 * @param mat The matrix to which to point. 057 * @param rowSlice The block's row slice. 058 * @param rowSliceLength The block's row length. 059 * @param colSlice The block's column slice. 060 * @param colSliceLength The block's column length. 061 */ 062 public VariableBlock( 063 VariableMatrix mat, Slice rowSlice, int rowSliceLength, Slice colSlice, int colSliceLength) { 064 m_mat = mat; 065 m_rowSlice = rowSlice; 066 m_rowSliceLength = rowSliceLength; 067 m_colSlice = colSlice; 068 m_colSliceLength = colSliceLength; 069 } 070 071 /** 072 * Assigns a double to the block. 073 * 074 * <p>This only works for blocks with one row and one column. 075 * 076 * @param value Value to assign. 077 * @return This VariableBlock. 078 */ 079 public VariableBlock set(double value) { 080 assert rows() == 1 && cols() == 1; 081 082 set(0, 0, new Variable(value)); 083 084 return this; 085 } 086 087 /** 088 * Assigns a Variable to the block. 089 * 090 * <p>This only works for blocks with one row and one column. 091 * 092 * @param value Value to assign. 093 * @return This VariableBlock. 094 */ 095 public VariableBlock set(Variable value) { 096 assert rows() == 1 && cols() == 1; 097 098 set(0, 0, value); 099 100 return this; 101 } 102 103 /** 104 * Assigns a double array to the block. 105 * 106 * @param values Double array of values to assign. 107 * @return This VariableBlock. 108 */ 109 public VariableBlock set(double[][] values) { 110 assert rows() == values.length; 111 112 // Assert all column counts are the same 113 for (var row : values) { 114 assert row.length == cols(); 115 } 116 117 for (int row = 0; row < rows(); ++row) { 118 for (int col = 0; col < cols(); ++col) { 119 set(row, col, values[row][col]); 120 } 121 } 122 123 return this; 124 } 125 126 /** 127 * Assigns an EJML matrix to the block. 128 * 129 * @param values EJML matrix of values to assign. 130 * @return This VariableBlock. 131 */ 132 public VariableBlock set(SimpleMatrix values) { 133 assert rows() == values.getNumRows() && cols() == values.getNumCols(); 134 135 for (int row = 0; row < rows(); ++row) { 136 for (int col = 0; col < cols(); ++col) { 137 set(row, col, values.get(row, col)); 138 } 139 } 140 141 return this; 142 } 143 144 /** 145 * Assigns a VariableMatrix to the block. 146 * 147 * @param values VariableMatrix of values. 148 * @return This VariableBlock. 149 */ 150 public VariableBlock set(VariableMatrix values) { 151 assert rows() == values.rows() && cols() == values.cols(); 152 153 for (int row = 0; row < rows(); ++row) { 154 for (int col = 0; col < cols(); ++col) { 155 set(row, col, values.get(row, col)); 156 } 157 } 158 return this; 159 } 160 161 /** 162 * Assigns a VariableBlock to the block. 163 * 164 * @param values VariableBlock of values. 165 * @return This VariableBlock. 166 */ 167 public VariableBlock set(VariableBlock values) { 168 assert rows() == values.rows() && cols() == values.cols(); 169 170 for (int row = 0; row < rows(); ++row) { 171 for (int col = 0; col < cols(); ++col) { 172 set(row, col, values.get(row, col)); 173 } 174 } 175 return this; 176 } 177 178 /** 179 * Sets a scalar subblock at the given row and column. 180 * 181 * @param row The scalar subblock's row. 182 * @param col The scalar subblock's column. 183 * @param value The value. 184 */ 185 public void set(int row, int col, Variable value) { 186 assert row >= 0 && row < rows(); 187 assert col >= 0 && col < cols(); 188 m_mat.set( 189 m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value); 190 } 191 192 /** 193 * Sets a scalar subblock at the given row and column. 194 * 195 * @param row The scalar subblock's row. 196 * @param col The scalar subblock's column. 197 * @param value The value. 198 */ 199 public void set(int row, int col, double value) { 200 assert row >= 0 && row < rows(); 201 assert col >= 0 && col < cols(); 202 m_mat.set( 203 m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step, value); 204 } 205 206 /** 207 * Sets a scalar subblock at the given index. 208 * 209 * @param index The scalar subblock's index. 210 * @param value The value. 211 */ 212 public void set(int index, double value) { 213 set(index, new Variable(value)); 214 } 215 216 /** 217 * Sets a scalar subblock at the given index. 218 * 219 * @param index The scalar subblock's index. 220 * @param value The value. 221 */ 222 public void set(int index, Variable value) { 223 assert index >= 0 && index < rows() * cols(); 224 set(index / cols(), index % cols(), value); 225 } 226 227 /** 228 * Assigns a double to the block. 229 * 230 * <p>This only works for blocks with one row and one column. 231 * 232 * @param value Value to assign. 233 */ 234 public void setValue(double value) { 235 assert rows() == 1 && cols() == 1; 236 237 get(0, 0).setValue(value); 238 } 239 240 /** 241 * Sets block's internal values. 242 * 243 * @param values Double array of values. 244 */ 245 public void setValue(double[][] values) { 246 assert rows() == values.length; 247 248 // Assert all column counts are the same 249 for (var row : values) { 250 assert row.length == cols(); 251 } 252 253 for (int row = 0; row < rows(); ++row) { 254 for (int col = 0; col < cols(); ++col) { 255 get(row, col).setValue(values[row][col]); 256 } 257 } 258 } 259 260 /** 261 * Sets block's internal values. 262 * 263 * @param values EJML matrix of values. 264 */ 265 public void setValue(SimpleMatrix values) { 266 assert rows() == values.getNumRows() && cols() == values.getNumCols(); 267 268 for (int row = 0; row < rows(); ++row) { 269 for (int col = 0; col < cols(); ++col) { 270 get(row, col).setValue(values.get(row, col)); 271 } 272 } 273 } 274 275 /** 276 * Returns a scalar subblock at the given row and column. 277 * 278 * @param row The scalar subblock's row. 279 * @param col The scalar subblock's column. 280 * @return A scalar subblock at the given row and column. 281 */ 282 public Variable get(int row, int col) { 283 assert row >= 0 && row < rows(); 284 assert col >= 0 && col < cols(); 285 return m_mat.get( 286 m_rowSlice.start + row * m_rowSlice.step, m_colSlice.start + col * m_colSlice.step); 287 } 288 289 /** 290 * Returns a scalar subblock at the given index. 291 * 292 * @param index The scalar subblock's index. 293 * @return A scalar subblock at the given index. 294 */ 295 public Variable get(int index) { 296 assert index >= 0 && index < rows() * cols(); 297 return get(index / cols(), index % cols()); 298 } 299 300 /** 301 * Returns a slice of the variable matrix. 302 * 303 * @param row The row. 304 * @param colSlice The column slice. 305 * @return A slice of the variable matrix. 306 */ 307 public VariableBlock get(int row, Slice.None colSlice) { 308 return get(new Slice(row), new Slice(colSlice)); 309 } 310 311 /** 312 * Returns a slice of the variable matrix. 313 * 314 * @param row The row. 315 * @param colSlice The column slice. 316 * @return A slice of the variable matrix. 317 */ 318 public VariableBlock get(int row, Slice colSlice) { 319 return get(new Slice(row), colSlice); 320 } 321 322 /** 323 * Returns a slice of the variable matrix. 324 * 325 * @param rowSlice The row slice. 326 * @param col The column. 327 * @return A slice of the variable matrix. 328 */ 329 public VariableBlock get(Slice.None rowSlice, int col) { 330 return get(new Slice(rowSlice), new Slice(col)); 331 } 332 333 /** 334 * Returns a slice of the variable matrix. 335 * 336 * @param rowSlice The row slice. 337 * @param col The column. 338 * @return A slice of the variable matrix. 339 */ 340 public VariableBlock get(Slice rowSlice, int col) { 341 return get(rowSlice, new Slice(col)); 342 } 343 344 /** 345 * Returns a slice of the variable matrix. 346 * 347 * @param rowSlice The row slice. 348 * @param colSlice The column slice. 349 * @return A slice of the variable matrix. 350 */ 351 public VariableBlock get(Slice.None rowSlice, Slice.None colSlice) { 352 return get(new Slice(rowSlice), new Slice(colSlice)); 353 } 354 355 /** 356 * Returns a slice of the variable matrix. 357 * 358 * @param rowSlice The row slice. 359 * @param colSlice The column slice. 360 * @return A slice of the variable matrix. 361 */ 362 public VariableBlock get(Slice.None rowSlice, Slice colSlice) { 363 return get(new Slice(rowSlice), colSlice); 364 } 365 366 /** 367 * Returns a slice of the variable matrix. 368 * 369 * @param rowSlice The row slice. 370 * @param colSlice The column slice. 371 * @return A slice of the variable matrix. 372 */ 373 public VariableBlock get(Slice rowSlice, Slice.None colSlice) { 374 return get(rowSlice, new Slice(colSlice)); 375 } 376 377 /** 378 * Returns a slice of the variable matrix. 379 * 380 * @param rowSlice The row slice. 381 * @param colSlice The column slice. 382 * @return A slice of the variable matrix. 383 */ 384 public VariableBlock get(Slice rowSlice, Slice colSlice) { 385 int rowSliceLength = rowSlice.adjust(m_rowSliceLength); 386 int colSliceLength = colSlice.adjust(m_colSliceLength); 387 return new VariableBlock( 388 m_mat, 389 new Slice( 390 m_rowSlice.start + rowSlice.start * m_rowSlice.step, 391 m_rowSlice.start + rowSlice.stop * m_rowSlice.step, 392 rowSlice.step * m_rowSlice.step), 393 rowSliceLength, 394 new Slice( 395 m_colSlice.start + colSlice.start * m_colSlice.step, 396 m_colSlice.start + colSlice.stop * m_colSlice.step, 397 colSlice.step * m_colSlice.step), 398 colSliceLength); 399 } 400 401 /** 402 * Returns a block of the variable matrix. 403 * 404 * @param rowOffset The row offset of the block selection. 405 * @param colOffset The column offset of the block selection. 406 * @param blockRows The number of rows in the block selection. 407 * @param blockCols The number of columns in the block selection. 408 * @return A block of the variable matrix. 409 */ 410 public VariableBlock block(int rowOffset, int colOffset, int blockRows, int blockCols) { 411 assert rowOffset >= 0 && rowOffset <= rows(); 412 assert colOffset >= 0 && colOffset <= cols(); 413 assert blockRows >= 0 && blockRows <= rows() - rowOffset; 414 assert blockCols >= 0 && blockCols <= cols() - colOffset; 415 return get( 416 new Slice(rowOffset, rowOffset + blockRows, 1), 417 new Slice(colOffset, colOffset + blockCols, 1)); 418 } 419 420 /** 421 * Returns a segment of the variable vector. 422 * 423 * @param offset The offset of the segment. 424 * @param length The length of the segment. 425 * @return A segment of the variable vector. 426 */ 427 public VariableBlock segment(int offset, int length) { 428 assert cols() == 1; 429 assert offset >= 0 && offset < rows(); 430 assert length >= 0 && length <= rows() - offset; 431 return block(offset, 0, length, 1); 432 } 433 434 /** 435 * Returns a row slice of the variable matrix. 436 * 437 * @param row The row to slice. 438 * @return A row slice of the variable matrix. 439 */ 440 public VariableBlock row(int row) { 441 assert row >= 0 && row < rows(); 442 return block(row, 0, 1, cols()); 443 } 444 445 /** 446 * Returns a column slice of the variable matrix. 447 * 448 * @param col The column to slice. 449 * @return A column slice of the variable matrix. 450 */ 451 public VariableBlock col(int col) { 452 assert col >= 0 && col < cols(); 453 return block(0, col, rows(), 1); 454 } 455 456 /** 457 * Matrix multiplication operator. 458 * 459 * @param rhs Operator right-hand side. 460 * @return Result of matrix multiplication. 461 */ 462 public VariableMatrix times(VariableMatrix rhs) { 463 assert cols() == rhs.rows(); 464 465 var result = new VariableMatrix(rows(), rhs.cols()); 466 467 for (int i = 0; i < rows(); ++i) { 468 for (int j = 0; j < rhs.cols(); ++j) { 469 var sum = new Variable(0.0); 470 for (int k = 0; k < cols(); ++k) { 471 sum = sum.plus(get(i, k).times(rhs.get(k, j))); 472 } 473 result.set(i, j, sum); 474 } 475 } 476 477 return result; 478 } 479 480 /** 481 * Matrix multiplication operator. 482 * 483 * @param rhs Operator right-hand side. 484 * @return Result of matrix multiplication. 485 */ 486 public VariableMatrix times(VariableBlock rhs) { 487 assert cols() == rhs.rows(); 488 489 var result = new VariableMatrix(rows(), rhs.cols()); 490 491 for (int i = 0; i < rows(); ++i) { 492 for (int j = 0; j < rhs.cols(); ++j) { 493 var sum = new Variable(0.0); 494 for (int k = 0; k < cols(); ++k) { 495 sum = sum.plus(get(i, k).times(rhs.get(k, j))); 496 } 497 result.set(i, j, sum); 498 } 499 } 500 501 return result; 502 } 503 504 /** 505 * Matrix-scalar multiplication operator. 506 * 507 * @param rhs Operator right-hand side. 508 * @return Result of matrix-scalar multiplication. 509 */ 510 public VariableMatrix times(double rhs) { 511 return times(new Variable(rhs)); 512 } 513 514 /** 515 * Matrix-scalar multiplication operator. 516 * 517 * @param rhs Operator right-hand side. 518 * @return Result of matrix-scalar multiplication. 519 */ 520 public VariableMatrix times(Variable rhs) { 521 var result = new VariableMatrix(rows(), cols()); 522 523 for (int row = 0; row < result.rows(); ++row) { 524 for (int col = 0; col < result.cols(); ++col) { 525 result.set(row, col, get(row, col).times(rhs)); 526 } 527 } 528 529 return result; 530 } 531 532 /** 533 * Binary division operator. 534 * 535 * @param rhs Operator right-hand side. 536 * @return Result of division. 537 */ 538 public VariableMatrix div(double rhs) { 539 return div(new Variable(rhs)); 540 } 541 542 /** 543 * Binary division operator. 544 * 545 * @param rhs Operator right-hand side. 546 * @return Result of division. 547 */ 548 public VariableMatrix div(Variable rhs) { 549 var result = new VariableMatrix(rows(), cols()); 550 551 for (int row = 0; row < result.rows(); ++row) { 552 for (int col = 0; col < result.cols(); ++col) { 553 result.set(row, col, get(row, col).div(rhs)); 554 } 555 } 556 557 return result; 558 } 559 560 /** 561 * Binary addition operator. 562 * 563 * @param rhs Operator right-hand side. 564 * @return Result of addition. 565 */ 566 public VariableMatrix plus(VariableMatrix rhs) { 567 assert rows() == rhs.rows() && cols() == rhs.cols(); 568 569 var result = new VariableMatrix(rows(), cols()); 570 571 for (int row = 0; row < result.rows(); ++row) { 572 for (int col = 0; col < result.cols(); ++col) { 573 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 574 } 575 } 576 577 return result; 578 } 579 580 /** 581 * Binary addition operator. 582 * 583 * @param rhs Operator right-hand side. 584 * @return Result of addition. 585 */ 586 public VariableMatrix plus(VariableBlock rhs) { 587 assert rows() == rhs.rows() && cols() == rhs.cols(); 588 589 var result = new VariableMatrix(rows(), cols()); 590 591 for (int row = 0; row < result.rows(); ++row) { 592 for (int col = 0; col < result.cols(); ++col) { 593 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 594 } 595 } 596 597 return result; 598 } 599 600 /** 601 * Binary addition operator. 602 * 603 * @param rhs Operator right-hand side. 604 * @return Result of addition. 605 */ 606 public VariableMatrix plus(SimpleMatrix rhs) { 607 assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); 608 609 var result = new VariableMatrix(rows(), cols()); 610 611 for (int row = 0; row < result.rows(); ++row) { 612 for (int col = 0; col < result.cols(); ++col) { 613 result.set(row, col, get(row, col).plus(rhs.get(row, col))); 614 } 615 } 616 617 return result; 618 } 619 620 /** 621 * Binary subtraction operator. 622 * 623 * @param rhs Operator right-hand side. 624 * @return Result of subtraction. 625 */ 626 public VariableMatrix minus(VariableMatrix rhs) { 627 assert rows() == rhs.rows() && cols() == rhs.cols(); 628 629 var result = new VariableMatrix(rows(), cols()); 630 631 for (int row = 0; row < result.rows(); ++row) { 632 for (int col = 0; col < result.cols(); ++col) { 633 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 634 } 635 } 636 637 return result; 638 } 639 640 /** 641 * Binary subtraction operator. 642 * 643 * @param rhs Operator right-hand side. 644 * @return Result of subtraction. 645 */ 646 public VariableMatrix minus(VariableBlock rhs) { 647 assert rows() == rhs.rows() && cols() == rhs.cols(); 648 649 var result = new VariableMatrix(rows(), cols()); 650 651 for (int row = 0; row < result.rows(); ++row) { 652 for (int col = 0; col < result.cols(); ++col) { 653 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 654 } 655 } 656 657 return result; 658 } 659 660 /** 661 * Binary subtraction operator. 662 * 663 * @param rhs Operator right-hand side. 664 * @return Result of subtraction. 665 */ 666 public VariableMatrix minus(SimpleMatrix rhs) { 667 assert rows() == rhs.getNumRows() && cols() == rhs.getNumCols(); 668 669 var result = new VariableMatrix(rows(), cols()); 670 671 for (int row = 0; row < result.rows(); ++row) { 672 for (int col = 0; col < result.cols(); ++col) { 673 result.set(row, col, get(row, col).minus(rhs.get(row, col))); 674 } 675 } 676 677 return result; 678 } 679 680 /** 681 * Unary minus operator. 682 * 683 * @return Result of unary minus. 684 */ 685 public VariableMatrix unaryMinus() { 686 var result = new VariableMatrix(rows(), cols()); 687 688 for (int row = 0; row < result.rows(); ++row) { 689 for (int col = 0; col < result.cols(); ++col) { 690 result.set(row, col, get(row, col).unaryMinus()); 691 } 692 } 693 694 return result; 695 } 696 697 /** 698 * Returns the transpose of the variable matrix. 699 * 700 * @return The transpose of the variable matrix. 701 */ 702 public VariableMatrix T() { 703 var result = new VariableMatrix(cols(), rows()); 704 705 for (int row = 0; row < rows(); ++row) { 706 for (int col = 0; col < cols(); ++col) { 707 result.set(col, row, get(row, col)); 708 } 709 } 710 711 return result; 712 } 713 714 /** 715 * Returns the number of rows in the matrix. 716 * 717 * @return The number of rows in the matrix. 718 */ 719 public int rows() { 720 return m_rowSliceLength; 721 } 722 723 /** 724 * Returns the number of columns in the matrix. 725 * 726 * @return The number of columns in the matrix. 727 */ 728 public int cols() { 729 return m_colSliceLength; 730 } 731 732 /** 733 * Returns an element of the variable matrix. 734 * 735 * @param row The row of the element to return. 736 * @param col The column of the element to return. 737 * @return An element of the variable matrix. 738 */ 739 public double value(int row, int col) { 740 return get(row, col).value(); 741 } 742 743 /** 744 * Returns an element of the variable block. 745 * 746 * @param index The index of the element to return. 747 * @return An element of the variable block. 748 */ 749 public double value(int index) { 750 return get(index).value(); 751 } 752 753 /** 754 * Returns the contents of the variable matrix. 755 * 756 * @return The contents of the variable matrix. 757 */ 758 public SimpleMatrix value() { 759 var result = new SimpleMatrix(rows(), cols()); 760 761 for (int row = 0; row < rows(); ++row) { 762 for (int col = 0; col < cols(); ++col) { 763 result.set(row, col, value(row, col)); 764 } 765 } 766 767 return result; 768 } 769 770 /** 771 * Maps the matrix coefficient-wise with an unary operator. 772 * 773 * @param unaryOp The unary operator to use for the map operation. 774 * @return Result of the unary operator. 775 */ 776 public VariableMatrix cwiseMap(UnaryOperator<Variable> unaryOp) { 777 var result = new VariableMatrix(rows(), cols()); 778 779 for (int row = 0; row < rows(); ++row) { 780 for (int col = 0; col < cols(); ++col) { 781 result.set(row, col, unaryOp.apply(get(row, col))); 782 } 783 } 784 785 return result; 786 } 787 788 /** 789 * Returns number of elements in matrix. 790 * 791 * @return Number of elements in matrix. 792 */ 793 public int size() { 794 return rows() * cols(); 795 } 796 797 @Override 798 public Iterator<Variable> iterator() { 799 return new Iterator<>() { 800 private int m_index = 0; 801 802 @Override 803 public boolean hasNext() { 804 return m_index < VariableBlock.this.size(); 805 } 806 807 @Override 808 public Variable next() { 809 if (!hasNext()) { 810 throw new NoSuchElementException(); 811 } 812 813 return VariableBlock.this.get(m_index++); 814 } 815 }; 816 } 817 818 /** 819 * Creates a Stream of VariableBlock elements. 820 * 821 * @return A Stream of VariableBlock elements. 822 */ 823 public Stream<Variable> stream() { 824 return StreamSupport.stream(spliterator(), false); 825 } 826}