Skip to content

Commit b678cd5

Browse files
committed
Merge branch 'master' of https://github.com/RubixML/Tensor
2 parents 69183a7 + 471381e commit b678cd5

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

src/Matrix.php

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -813,10 +813,11 @@ public function dot(Vector $b) : ColumnVector
813813
* Convolve this matrix with another matrix.
814814
*
815815
* @param \Rubix\Tensor\Matrix $b
816+
* @param int $stride
816817
* @throws \InvalidArgumentException
817818
* @return self
818819
*/
819-
public function convolve(Matrix $b) : self
820+
public function convolve(Matrix $b, int $stride = 1) : self
820821
{
821822
list($m, $n) = $b->shape();
822823

@@ -825,21 +826,26 @@ public function convolve(Matrix $b) : self
825826
. ' larger than Matrix A.');
826827
}
827828

829+
if ($stride < 1) {
830+
throw new InvalidArgumentException('Stride cannot be'
831+
. " less than 1, $stride given.");
832+
}
833+
828834
$p = intdiv($m, 2);
829835
$q = intdiv($n, 2);
830836

831837
$c = [];
832838

833-
for ($i = 0; $i < $this->m; $i++) {
839+
for ($i = 0; $i < $this->m; $i += $stride) {
834840
$temp = [];
835841

836-
for ($j = 0; $j < $this->n; $j++) {
842+
for ($j = 0; $j < $this->n; $j += $stride) {
837843
$sigma = 0;
838844

839845
foreach ($b as $k => $rowB) {
840846
foreach ($rowB as $l => $valueB) {
841-
$x = $i + ($p - (int) $k);
842-
$y = $j + ($q - (int) $l);
847+
$x = $i + $p - (int) $k;
848+
$y = $j + $q - (int) $l;
843849

844850
if ($x < 0 or $x >= $this->n or $y < 0 or $y >= $this->m) {
845851
continue 1;

src/Vector.php

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,19 +560,25 @@ public function cross(Vector $b) : self
560560
* Convolve this vector with another vector.
561561
*
562562
* @param \Rubix\Tensor\Vector $b
563+
* @param int $stride
563564
* @throws \InvalidArgumentException
564565
* @return self
565566
*/
566-
public function convolve(Vector $b) : self
567+
public function convolve(Vector $b, int $stride = 1) : self
567568
{
568569
if ($b->size() > $this->n) {
569570
throw new InvalidArgumentException('Vector B cannot be'
570571
. ' longer than Vector A.');
571572
}
572573

574+
if ($stride < 1) {
575+
throw new InvalidArgumentException('Stride cannot be'
576+
. " less than 1, $stride given.");
577+
}
578+
573579
$c = [];
574580

575-
for ($i = 0; $i < $this->n; $i++) {
581+
for ($i = 0; $i < $this->n; $i += $stride) {
576582
$sigma = 0;
577583

578584
foreach ($b as $j => $valueB) {

tests/MatrixTest.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ public function test_convolve()
575575
[6, 9, 69, 5, 2, 33, 35],
576576
]);
577577

578-
$z = $input->convolve($this->a);
578+
$z = $input->convolve($this->a, 1);
579579

580580
$y = [
581581
[254, 792, 1565, 811, 499, 195],

tests/VectorTest.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ public function test_cross()
333333

334334
public function test_convolve()
335335
{
336-
$z = $this->a->convolve($this->c);
336+
$z = $this->a->convolve($this->c, 1);
337337

338338
$y = [-60.0, 2.5, 259.0, -144., 40.5, 370.1, 462.20000000000005, 9.999999999999886];
339339

0 commit comments

Comments
 (0)