Skip to content

Commit 1765d6f

Browse files
committed
Improved error messages for matrix dimensionality mismatch
1 parent 84568f9 commit 1765d6f

File tree

7 files changed

+128
-112
lines changed

7 files changed

+128
-112
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
- Added transpose to Tensor API
33
- Reduced memory footprint of matmul operation
44
- Removed magic getters
5+
- Added shape string method to Tensor API
6+
- Improved error messages for matrix dimensionality mismatch
57

68
- 1.0.3
79
- Added clip upper and lower bounds

src/Matrix.php

Lines changed: 86 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ public static function gaussian(int $m, int $n) : self
255255
for ($i = 0; $i < $m; $i++) {
256256
$row = [];
257257

258-
if (!empty($extras)) {
258+
if ($extras) {
259259
$row[] = array_pop($extras);
260260
}
261261

@@ -265,8 +265,10 @@ public static function gaussian(int $m, int $n) : self
265265

266266
$r = sqrt(-2. * log($r1));
267267

268-
$row[] = $r * sin($r2 * TWO_PI);
269-
$row[] = $r * cos($r2 * TWO_PI);
268+
$phi = $r2 * TWO_PI;
269+
270+
$row[] = $r * sin($phi);
271+
$row[] = $r * cos($phi);
270272
}
271273

272274
if (count($row) > $n) {
@@ -443,6 +445,16 @@ public function shape() : array
443445
return [$this->m, $this->n];
444446
}
445447

448+
/**
449+
* Return the shape of the tensor as a string.
450+
*
451+
* @return string
452+
*/
453+
public function shapeString() : string
454+
{
455+
return (string) $this->m . ' x ' . (string) $this->n;
456+
}
457+
446458
/**
447459
* Is this a square matrix?
448460
*
@@ -674,21 +686,20 @@ public function reduce(callable $fn, $initial = 0)
674686
}
675687

676688
/**
677-
* Transpose the matrix i.e row become columns and columns
678-
* become rows.
689+
* Transpose the matrix i.e row become columns and columns become rows.
679690
*
680691
* @return self
681692
*/
682693
public function transpose() : self
683694
{
684695
if ($this->m > 1) {
685-
$b = array_map(null, ...$this->a);
686-
} else {
687-
$b = [];
696+
return self::quick(array_map(null, ...$this->a));
697+
}
698+
699+
$b = [];
688700

689-
for ($i = 0; $i < $this->n; $i++) {
690-
$b[] = array_column($this->a, $i);
691-
}
701+
for ($i = 0; $i < $this->n; $i++) {
702+
$b[] = array_column($this->a, $i);
692703
}
693704

694705
return self::quick($b);
@@ -788,8 +799,8 @@ public function matmul(Matrix $b) : self
788799
. " $this->n rows but Matrix B has {$b->m()}.");
789800
}
790801

791-
$p = $b->n();
792802
$bHat = $b->asArray();
803+
$p = $b->n();
793804

794805
$c = [];
795806

@@ -1158,8 +1169,9 @@ public function lu() : array
11581169
public function eig(bool $normalize = true) : array
11591170
{
11601171
if (!$this->isSquare()) {
1161-
throw new RuntimeException('Cannot decompose a non'
1162-
. ' square matrix.');
1172+
throw new RuntimeException('Cannot eigen decompose a non'
1173+
. ' square matrix, ' . implode(' x ', $this->shape())
1174+
. ' matrix given.');
11631175
}
11641176

11651177
$jama = new JAMA($this->a);
@@ -2326,14 +2338,11 @@ public function repeat(int $m = 1, int $n = 1) : self
23262338
*/
23272339
protected function multiplyMatrix(Matrix $b) : self
23282340
{
2329-
if ($b->m() !== $this->m) {
2330-
throw new DimensionalityMismatchException('Matrix A require'
2331-
. " $this->m rows but Matrix B has {$b->m()}.");
2332-
}
2333-
2334-
if ($b->n() !== $this->n) {
2335-
throw new DimensionalityMismatchException('Matrix A requires'
2336-
. " $this->n columns but Matrix B has {$b->n()}.");
2341+
if ($b->shape() !== $this->shape()) {
2342+
throw new DimensionalityMismatchException(
2343+
"{$this->shapeString()}"
2344+
. " matrix needed but {$b->shapeString()} given."
2345+
);
23372346
}
23382347

23392348
$c = [];
@@ -2361,14 +2370,11 @@ protected function multiplyMatrix(Matrix $b) : self
23612370
*/
23622371
protected function divideMatrix(Matrix $b) : self
23632372
{
2364-
if ($b->m() !== $this->m) {
2365-
throw new DimensionalityMismatchException('Matrix A requires'
2366-
. " $this->m rows but Matrix B has {$b->m()}.");
2367-
}
2368-
2369-
if ($b->n() !== $this->n) {
2370-
throw new DimensionalityMismatchException('Matrix A requires'
2371-
. " $this->n columns but Matrix B has {$b->n()}.");
2373+
if ($b->shape() !== $this->shape()) {
2374+
throw new DimensionalityMismatchException(
2375+
"{$this->shapeString()}"
2376+
. " matrix needed but {$b->shapeString()} given."
2377+
);
23722378
}
23732379

23742380
$c = [];
@@ -2396,14 +2402,11 @@ protected function divideMatrix(Matrix $b) : self
23962402
*/
23972403
protected function addMatrix(Matrix $b) : self
23982404
{
2399-
if ($b->m() !== $this->m) {
2400-
throw new DimensionalityMismatchException('Matrix A requires'
2401-
. " $this->m rows but Matrix B has {$b->m()}.");
2402-
}
2403-
2404-
if ($b->n() !== $this->n) {
2405-
throw new DimensionalityMismatchException('Matrix A requires'
2406-
. " $this->n columns but Matrix B has {$b->n()}.");
2405+
if ($b->shape() !== $this->shape()) {
2406+
throw new DimensionalityMismatchException(
2407+
"{$this->shapeString()}"
2408+
. " matrix needed but {$b->shapeString()} given."
2409+
);
24072410
}
24082411

24092412
$c = [];
@@ -2431,14 +2434,11 @@ protected function addMatrix(Matrix $b) : self
24312434
*/
24322435
protected function subtractMatrix(Matrix $b) : self
24332436
{
2434-
if ($b->m() !== $this->m) {
2435-
throw new DimensionalityMismatchException('Matrix A requires'
2436-
. " $this->m rows but Matrix B has {$b->m()}.");
2437-
}
2438-
2439-
if ($b->n() !== $this->n) {
2440-
throw new DimensionalityMismatchException('Matrix A requires'
2441-
. " $this->n columns but Matrix B has {$b->n()}.");
2437+
if ($b->shape() !== $this->shape()) {
2438+
throw new DimensionalityMismatchException(
2439+
"{$this->shapeString()}"
2440+
. " matrix needed but {$b->shapeString()} given."
2441+
);
24422442
}
24432443

24442444
$c = [];
@@ -2467,14 +2467,11 @@ protected function subtractMatrix(Matrix $b) : self
24672467
*/
24682468
protected function powMatrix(Matrix $b) : self
24692469
{
2470-
if ($b->m() !== $this->m) {
2471-
throw new DimensionalityMismatchException('Matrix A requires'
2472-
. " $this->m rows but Matrix B has {$b->m()}.");
2473-
}
2474-
2475-
if ($b->n() !== $this->n) {
2476-
throw new DimensionalityMismatchException('Matrix A requires'
2477-
. " $this->n columns but Matrix B has {$b->n()}.");
2470+
if ($b->shape() !== $this->shape()) {
2471+
throw new DimensionalityMismatchException(
2472+
"{$this->shapeString()}"
2473+
. " matrix needed but {$b->shapeString()} given."
2474+
);
24782475
}
24792476

24802477
$c = [];
@@ -2503,14 +2500,11 @@ protected function powMatrix(Matrix $b) : self
25032500
*/
25042501
protected function modMatrix(Matrix $b) : self
25052502
{
2506-
if ($b->m() !== $this->m) {
2507-
throw new DimensionalityMismatchException('Matrix A requires'
2508-
. " $this->m rows but Matrix B has {$b->m()}.");
2509-
}
2510-
2511-
if ($b->n() !== $this->n) {
2512-
throw new DimensionalityMismatchException('Matrix A requires'
2513-
. " $this->n columns but Matrix B has {$b->n()}.");
2503+
if ($b->shape() !== $this->shape()) {
2504+
throw new DimensionalityMismatchException(
2505+
"{$this->shapeString()}"
2506+
. " matrix needed but {$b->shapeString()} given."
2507+
);
25142508
}
25152509

25162510
$c = [];
@@ -2539,14 +2533,11 @@ protected function modMatrix(Matrix $b) : self
25392533
*/
25402534
protected function equalMatrix(Matrix $b) : self
25412535
{
2542-
if ($b->m() !== $this->m) {
2543-
throw new DimensionalityMismatchException('Matrix A requires'
2544-
. " $this->m rows but Matrix B has {$b->m()}.");
2545-
}
2546-
2547-
if ($b->n() !== $this->n) {
2548-
throw new DimensionalityMismatchException('Matrix A requires'
2549-
. " $this->n columns but Matrix B has {$b->n()}.");
2536+
if ($b->shape() !== $this->shape()) {
2537+
throw new DimensionalityMismatchException(
2538+
"{$this->shapeString()}"
2539+
. " matrix needed but {$b->shapeString()} given."
2540+
);
25502541
}
25512542

25522543
$c = [];
@@ -2575,14 +2566,11 @@ protected function equalMatrix(Matrix $b) : self
25752566
*/
25762567
protected function notEqualMatrix(Matrix $b) : self
25772568
{
2578-
if ($b->m() !== $this->m) {
2579-
throw new DimensionalityMismatchException('Matrix A requires'
2580-
. " $this->m rows but Matrix B has {$b->m()}.");
2581-
}
2582-
2583-
if ($b->n() !== $this->n) {
2584-
throw new DimensionalityMismatchException('Matrix A requires'
2585-
. " $this->n columns but Matrix B has {$b->n()}.");
2569+
if ($b->shape() !== $this->shape()) {
2570+
throw new DimensionalityMismatchException(
2571+
"{$this->shapeString()}"
2572+
. " matrix needed but {$b->shapeString()} given."
2573+
);
25862574
}
25872575

25882576
$c = [];
@@ -2611,14 +2599,11 @@ protected function notEqualMatrix(Matrix $b) : self
26112599
*/
26122600
protected function greaterMatrix(Matrix $b) : self
26132601
{
2614-
if ($b->m() !== $this->m) {
2615-
throw new DimensionalityMismatchException('Matrix A requires'
2616-
. " $this->m rows but Matrix B has {$b->m()}.");
2617-
}
2618-
2619-
if ($b->n() !== $this->n) {
2620-
throw new DimensionalityMismatchException('Matrix A requires'
2621-
. " $this->n columns but Matrix B has {$b->n()}.");
2602+
if ($b->shape() !== $this->shape()) {
2603+
throw new DimensionalityMismatchException(
2604+
"{$this->shapeString()}"
2605+
. " matrix needed but {$b->shapeString()} given."
2606+
);
26222607
}
26232608

26242609
$c = [];
@@ -2647,14 +2632,11 @@ protected function greaterMatrix(Matrix $b) : self
26472632
*/
26482633
protected function greaterEqualMatrix(Matrix $b) : self
26492634
{
2650-
if ($b->m() !== $this->m) {
2651-
throw new DimensionalityMismatchException('Matrix A requires'
2652-
. " $this->m rows but Matrix B has {$b->m()}.");
2653-
}
2654-
2655-
if ($b->n() !== $this->n) {
2656-
throw new DimensionalityMismatchException('Matrix A requires'
2657-
. " $this->n columns but Matrix B has {$b->n()}.");
2635+
if ($b->shape() !== $this->shape()) {
2636+
throw new DimensionalityMismatchException(
2637+
"{$this->shapeString()}"
2638+
. " matrix needed but {$b->shapeString()} given."
2639+
);
26582640
}
26592641

26602642
$c = [];
@@ -2683,14 +2665,11 @@ protected function greaterEqualMatrix(Matrix $b) : self
26832665
*/
26842666
protected function lessMatrix(Matrix $b) : self
26852667
{
2686-
if ($b->m() !== $this->m) {
2687-
throw new DimensionalityMismatchException('Matrix A requires'
2688-
. " $this->m rows but Matrix B has {$b->m()}.");
2689-
}
2690-
2691-
if ($b->n() !== $this->n) {
2692-
throw new DimensionalityMismatchException('Matrix A requires'
2693-
. " $this->n columns but Matrix B has {$b->n()}.");
2668+
if ($b->shape() !== $this->shape()) {
2669+
throw new DimensionalityMismatchException(
2670+
"{$this->shapeString()}"
2671+
. " matrix needed but {$b->shapeString()} given."
2672+
);
26942673
}
26952674

26962675
$c = [];
@@ -2719,14 +2698,11 @@ protected function lessMatrix(Matrix $b) : self
27192698
*/
27202699
protected function lessEqualMatrix(Matrix $b) : self
27212700
{
2722-
if ($b->m() !== $this->m) {
2723-
throw new DimensionalityMismatchException('Matrix A requires'
2724-
. " $this->m rows but Matrix B has {$b->m()}.");
2725-
}
2726-
2727-
if ($b->n() !== $this->n) {
2728-
throw new DimensionalityMismatchException('Matrix A requires'
2729-
. " $this->n columns but Matrix B has {$b->n()}.");
2701+
if ($b->shape() !== $this->shape()) {
2702+
throw new DimensionalityMismatchException(
2703+
"{$this->shapeString()}"
2704+
. " matrix needed but {$b->shapeString()} given."
2705+
);
27302706
}
27312707

27322708
$c = [];

src/Tensor.php

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ interface Tensor extends ArrayAccess, IteratorAggregate, Countable
1515
*/
1616
public function shape() : array;
1717

18+
/**
19+
* Return the shape of the tensor as a string.
20+
*
21+
* @return string
22+
*/
23+
public function shapeString() : string;
24+
1825
/**
1926
* Return the number of elements in the tensor.
2027
*

src/Vector.php

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ public static function gaussian(int $n) : self
163163

164164
$r = sqrt(-2. * log($r1));
165165

166-
$a[] = $r * sin($r2 * TWO_PI);
167-
$a[] = $r * cos($r2 * TWO_PI);
166+
$phi = $r2 * TWO_PI;
167+
168+
$a[] = $r * sin($phi);
169+
$a[] = $r * cos($phi);
168170
}
169171

170172
if (count($a) > $n) {
@@ -312,6 +314,16 @@ public function shape() : array
312314
return [$this->n];
313315
}
314316

317+
/**
318+
* Return the shape of the tensor as a string.
319+
*
320+
* @return string
321+
*/
322+
public function shapeString() : string
323+
{
324+
return (string) $this->n;
325+
}
326+
315327
/**
316328
* Return the number of elements in the vector.
317329
*

tests/ColumnVectorTest.php

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ public function test_shape()
4141
$this->assertEquals([3], $this->a->shape());
4242
}
4343

44+
public function test_shape_string()
45+
{
46+
$this->assertEquals('3', $this->a->shapeString());
47+
}
48+
4449
public function test_size()
4550
{
4651
$this->assertEquals(3, $this->a->size());

0 commit comments

Comments
 (0)