Skip to content

Commit 9bb58b2

Browse files
committed
Added universal comparison methods to vectors
1 parent 1ff561f commit 9bb58b2

File tree

6 files changed

+1036
-52
lines changed

6 files changed

+1036
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
- Unreleased
22
- Added ref using row elimination method
3+
- Added universal comparison methods to tensor API
34

45
- 1.0.1
56
- Added Column Vector

src/ColumnVector.php

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,4 +243,164 @@ protected function modMatrix(Matrix $b) : Matrix
243243

244244
return Matrix::quick($c);
245245
}
246+
247+
/**
248+
* Return the element-wise equality comparison of this column vector
249+
* and a matrix.
250+
*
251+
* @param \Rubix\Tensor\Matrix $b
252+
* @throws \Rubix\Tensor\Exceptions\DimensionalityMismatchException
253+
* @return \Rubix\Tensor\Matrix
254+
*/
255+
protected function equalMatrix(Matrix $b) : Matrix
256+
{
257+
if ($this->n !== $b->m()) {
258+
throw new DimensionalityMismatchException("Vector A requires"
259+
. " $this->n rows but Matrix B has {$b->m()}.");
260+
}
261+
262+
$c = [];
263+
264+
foreach ($b as $i => $row) {
265+
$valueA = $this->a[$i];
266+
267+
$temp = [];
268+
269+
foreach ($row as $valueB) {
270+
$temp[] = $valueA == $valueB ? 1 : 0;
271+
}
272+
273+
$c[] = $temp;
274+
}
275+
276+
return Matrix::quick($c);
277+
}
278+
279+
/**
280+
* Return the element-wise greater than comparison of this column
281+
* vector and a matrix.
282+
*
283+
* @param \Rubix\Tensor\Matrix $b
284+
* @throws \Rubix\Tensor\Exceptions\DimensionalityMismatchException
285+
* @return \Rubix\Tensor\Matrix
286+
*/
287+
protected function greaterMatrix(Matrix $b) : Matrix
288+
{
289+
if ($this->n !== $b->m()) {
290+
throw new DimensionalityMismatchException("Vector A requires"
291+
. " $this->n rows but Matrix B has {$b->m()}.");
292+
}
293+
294+
$c = [];
295+
296+
foreach ($b as $i => $row) {
297+
$valueA = $this->a[$i];
298+
299+
$temp = [];
300+
301+
foreach ($row as $valueB) {
302+
$temp[] = $valueA > $valueB ? 1 : 0;
303+
}
304+
305+
$c[] = $temp;
306+
}
307+
308+
return Matrix::quick($c);
309+
}
310+
311+
/**
312+
* Return the element-wise greater than or equal to comparison of
313+
* this column vector and a matrix.
314+
*
315+
* @param \Rubix\Tensor\Matrix $b
316+
* @throws \Rubix\Tensor\Exceptions\DimensionalityMismatchException
317+
* @return \Rubix\Tensor\Matrix
318+
*/
319+
protected function greaterEqualMatrix(Matrix $b) : Matrix
320+
{
321+
if ($this->n !== $b->m()) {
322+
throw new DimensionalityMismatchException("Vector A requires"
323+
. " $this->n rows but Matrix B has {$b->m()}.");
324+
}
325+
326+
$c = [];
327+
328+
foreach ($b as $i => $row) {
329+
$valueA = $this->a[$i];
330+
331+
$temp = [];
332+
333+
foreach ($row as $valueB) {
334+
$temp[] = $valueA >= $valueB ? 1 : 0;
335+
}
336+
337+
$c[] = $temp;
338+
}
339+
340+
return Matrix::quick($c);
341+
}
342+
343+
/**
344+
* Return the element-wise less than comparison of this column
345+
* vector and a matrix.
346+
*
347+
* @param \Rubix\Tensor\Matrix $b
348+
* @throws \Rubix\Tensor\Exceptions\DimensionalityMismatchException
349+
* @return \Rubix\Tensor\Matrix
350+
*/
351+
protected function lessMatrix(Matrix $b) : Matrix
352+
{
353+
if ($this->n !== $b->m()) {
354+
throw new DimensionalityMismatchException("Vector A requires"
355+
. " $this->n rows but Matrix B has {$b->m()}.");
356+
}
357+
358+
$c = [];
359+
360+
foreach ($b as $i => $row) {
361+
$valueA = $this->a[$i];
362+
363+
$temp = [];
364+
365+
foreach ($row as $valueB) {
366+
$temp[] = $valueA < $valueB ? 1 : 0;
367+
}
368+
369+
$c[] = $temp;
370+
}
371+
372+
return Matrix::quick($c);
373+
}
374+
375+
/**
376+
* Return the element-wise less than or equal to comparison of
377+
* this column vector and a matrix.
378+
*
379+
* @param \Rubix\Tensor\Matrix $b
380+
* @throws \Rubix\Tensor\Exceptions\DimensionalityMismatchException
381+
* @return \Rubix\Tensor\Matrix
382+
*/
383+
protected function lessEqualMatrix(Matrix $b) : Matrix
384+
{
385+
if ($this->n !== $b->m()) {
386+
throw new DimensionalityMismatchException("Vector A requires"
387+
. " $this->n rows but Matrix B has {$b->m()}.");
388+
}
389+
390+
$c = [];
391+
392+
foreach ($b as $i => $row) {
393+
$valueA = $this->a[$i];
394+
395+
$temp = [];
396+
397+
foreach ($row as $valueB) {
398+
$temp[] = $valueA <= $valueB ? 1 : 0;
399+
}
400+
401+
$c[] = $temp;
402+
}
403+
404+
return Matrix::quick($c);
405+
}
246406
}

src/Tensor.php

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ public function reduce(callable $fn, $initial = 0);
4848
* element-wise.
4949
*
5050
* @param mixed $b
51-
* @throws \InvalidArgumentException
5251
* @return mixed
5352
*/
5453
public function multiply($b);
@@ -58,7 +57,6 @@ public function multiply($b);
5857
* element-wise.
5958
*
6059
* @param mixed $b
61-
* @throws \InvalidArgumentException
6260
* @return mixed
6361
*/
6462
public function divide($b);
@@ -68,7 +66,6 @@ public function divide($b);
6866
* element-wise.
6967
*
7068
* @param mixed $b
71-
* @throws \InvalidArgumentException
7269
* @return mixed
7370
*/
7471
public function add($b);
@@ -78,7 +75,6 @@ public function add($b);
7875
* element-wise.
7976
*
8077
* @param mixed $b
81-
* @throws \InvalidArgumentException
8278
* @return mixed
8379
*/
8480
public function subtract($b);
@@ -88,7 +84,6 @@ public function subtract($b);
8884
* tensor element-wise.
8985
*
9086
* @param mixed $b
91-
* @throws \InvalidArgumentException
9287
* @return mixed
9388
*/
9489
public function pow($b);
@@ -98,11 +93,55 @@ public function pow($b);
9893
* and another tensor element-wise.
9994
*
10095
* @param mixed $b
101-
* @throws \InvalidArgumentException
10296
* @return mixed
10397
*/
10498
public function mod($b);
10599

100+
/**
101+
* A universal function to compute the equality comparison of a tensor
102+
* and another tensor element-wise.
103+
*
104+
* @param mixed $b
105+
* @return mixed
106+
*/
107+
public function equal($b);
108+
109+
/**
110+
* A universal function to compute the greater than comparison of a
111+
* tensor and another tensor element-wise.
112+
*
113+
* @param mixed $b
114+
* @return mixed
115+
*/
116+
public function greater($b);
117+
118+
/**
119+
* A universal function to compute the greater than or equal to
120+
* comparison of a tensor and another tensor element-wise.
121+
*
122+
* @param mixed $b
123+
* @return mixed
124+
*/
125+
public function greaterEqual($b);
126+
127+
/**
128+
* A universal function to compute the less than comparison of a
129+
* tensor and another tensor element-wise.
130+
*
131+
* @param mixed $b
132+
* @return mixed
133+
*/
134+
public function less($b);
135+
136+
/**
137+
* A universal function to compute the less than or equal to
138+
* comparison of a tensor and another tensor element-wise.
139+
*
140+
* @param mixed $b
141+
* @return mixed
142+
*/
143+
public function lessEqual($b);
144+
106145
/**
107146
* Take the absolute value of the tensor.
108147
*

0 commit comments

Comments
 (0)