Skip to content

Commit 4671a09

Browse files
committed
Added convolve operation to matrix and vector
1 parent 644b7d9 commit 4671a09

File tree

5 files changed

+125
-6
lines changed

5 files changed

+125
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
- Unreleased
22
- Added ref using row elimination method
33
- Added universal comparison methods to tensor API
4+
- Added convolve operation to Vector and Matrix
45

56
- 1.0.1
67
- Added Column Vector

src/Matrix.php

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -771,18 +771,22 @@ public function matmul(Matrix $b) : self
771771

772772
$bT = $b->transpose();
773773

774-
$c = [[]];
774+
$c = [];
775+
776+
foreach ($this->a as $row) {
777+
$temp = [];
775778

776-
foreach ($this->a as $i => $row) {
777779
foreach ($bT as $column) {
778780
$sigma = 0;
779781

780-
foreach ($row as $k => $value) {
781-
$sigma += $value * $column[$k];
782+
foreach ($row as $i => $value) {
783+
$sigma += $value * $column[$i];
782784
}
783785

784-
$c[$i][] = $sigma;
786+
$temp[] = $sigma;
785787
}
788+
789+
$c[] = $temp;
786790
}
787791

788792
return self::quick($c);
@@ -798,13 +802,62 @@ public function matmul(Matrix $b) : self
798802
public function dot(Vector $b) : ColumnVector
799803
{
800804
if ($this->n !== $b->size()) {
801-
throw new DimensionalityMismatchException("Matrix A requires"
805+
throw new DimensionalityMismatchException('Matrix A requires'
802806
. " $this->n elements but Vector B has {$b->size()}.");
803807
}
804808

805809
return $this->matmul($b->asColumnMatrix())->columnAsVector(0);
806810
}
807811

812+
/**
813+
* Convolve this matrix with another matrix.
814+
*
815+
* @param \Rubix\Tensor\Matrix $b
816+
* @throws \InvalidArgumentException
817+
* @return self
818+
*/
819+
public function convolve(Matrix $b) : self
820+
{
821+
list($m, $n) = $b->shape();
822+
823+
if ($m > $this->m or $n > $this->n) {
824+
throw new InvalidArgumentException('Matrix B cannot be'
825+
. ' larger than Matrix A.');
826+
}
827+
828+
$p = intdiv($m, 2);
829+
$q = intdiv($n, 2);
830+
831+
$c = [];
832+
833+
for ($i = 0; $i < $this->m; $i++) {
834+
$temp = [];
835+
836+
for ($j = 0; $j < $this->n; $j++) {
837+
$sigma = 0;
838+
839+
foreach ($b as $k => $rowB) {
840+
foreach ($rowB as $l => $valueB) {
841+
$x = $i + ($p - (int) $k);
842+
$y = $j + ($q - (int) $l);
843+
844+
if ($x < 0 or $x >= $this->n or $y < 0 or $y >= $this->m) {
845+
continue 1;
846+
}
847+
848+
$sigma += $this->a[$x][$y] * $valueB;
849+
}
850+
}
851+
852+
$temp[] = $sigma;
853+
}
854+
855+
$c[] = $temp;
856+
}
857+
858+
return self::quick($c);
859+
}
860+
808861
/**
809862
* Calculate the row echelon form (REF) of the matrix. Return the matrix
810863
* and the number of swaps in a tuple.

src/Vector.php

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,35 @@ public function cross(Vector $b) : self
556556
return self::quick($c);
557557
}
558558

559+
/**
560+
* Convolve this vector with another vector.
561+
*
562+
* @param \Rubix\Tensor\Vector $b
563+
* @throws \InvalidArgumentException
564+
* @return self
565+
*/
566+
public function convolve(Vector $b) : self
567+
{
568+
if ($b->size() > $this->n) {
569+
throw new InvalidArgumentException('Vector B cannot be'
570+
. ' longer than Vector A.');
571+
}
572+
573+
$c = [];
574+
575+
for ($i = 0; $i < $this->n; $i++) {
576+
$sigma = 0;
577+
578+
foreach ($b as $j => $valueB) {
579+
$sigma += ($this->a[$i - (int) $j] ?? 0) * $valueB;
580+
}
581+
582+
$c[] = $sigma;
583+
}
584+
585+
return self::quick($c);
586+
}
587+
559588
/**
560589
* Project this vector on another vector.
561590
*

tests/MatrixTest.php

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,32 @@ public function test_matmul()
564564
$this->assertEquals($y, $z->asArray());
565565
}
566566

567+
public function test_convolve()
568+
{
569+
$input = Matrix::quick([
570+
[3, 27, 66, 29, 42, 1],
571+
[5, 9, 15, 42, 45, 16],
572+
[91, 67, 49, 22, 66, 5],
573+
[5, 1, 4, 9, 8, 6, 2],
574+
[22, 16, 18, 19, 21, 25],
575+
[6, 9, 69, 5, 2, 33, 35],
576+
]);
577+
578+
$z = $input->convolve($this->a);
579+
580+
$y = [
581+
[254, 792, 1565, 811, 499, 195],
582+
[540, 2311, 711, 2350, -766, 409],
583+
[1356, 1083, 1304, 992, 478, -584],
584+
[831, 164, -75, 1225, 21, -747],
585+
[392, 1670, -566, 1114, 1036, -412],
586+
[290, 429, 889, 69, 347, 20],
587+
];
588+
589+
$this->assertInstanceOf(Matrix::class, $z);
590+
$this->assertEquals($y, $z->asArray());
591+
}
592+
567593
public function test_multiply_matrix()
568594
{
569595
$z = $this->a->multiply($this->c);

tests/VectorTest.php

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,16 @@ public function test_cross()
331331
$this->assertEquals($y, $z->asArray());
332332
}
333333

334+
public function test_convolve()
335+
{
336+
$z = $this->a->convolve($this->c);
337+
338+
$y = [-60.0, 2.5, 259.0, -144., 40.5, 370.1, 462.20000000000005, 9.999999999999886];
339+
340+
$this->assertInstanceOf(Vector::class, $z);
341+
$this->assertEquals($y, $z->asArray());
342+
}
343+
334344
public function test_project()
335345
{
336346
$z = $this->a->project($this->b);

0 commit comments

Comments
 (0)