Skip to content

Commit 7e6a5f9

Browse files
v923zCallumJHays
andauthored
Update snippets/rclass.py
Co-authored-by: Cal Hays <[email protected]>
1 parent 4701338 commit 7e6a5f9

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

snippets/rclass.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
1-
from ulab import numpy as np
21
from typing import List, Tuple, Union # upip.install("pycopy-typing")
2+
from ulab import numpy as np
33

4-
ndarray = np.array
54
_DType = int
6-
RClassKeyType = Union[slice, int, float]
5+
_RClassKeyType = Union[slice, int, float, list, tuple, np.ndarray]
76

87
# this is a stripped down version of RClass (used by np.r_[...etc])
98
# it doesn't include support for string arguments as the first index element
109
class RClass:
1110

12-
def __getitem__(self, key: Union[RClassKeyType, Tuple[RClassKeyType, ...]]):
11+
def __getitem__(self, key: Union[_RClassKeyType, Tuple[_RClassKeyType, ...]]):
1312

1413
if not isinstance(key, tuple):
1514
key = (key,)
1615

17-
objs: List[ndarray] = []
16+
objs: List[np.ndarray] = []
1817
scalars: List[int] = []
1918
arraytypes: List[_DType] = []
2019
scalartypes: List[_DType] = []
@@ -24,50 +23,53 @@ def __getitem__(self, key: Union[RClassKeyType, Tuple[RClassKeyType, ...]]):
2423

2524
for idx, item in enumerate(key):
2625
scalar = False
27-
if isinstance(item, slice):
28-
step = item.step
29-
start = item.start
30-
stop = item.stop
31-
if start is None:
32-
start = 0
33-
if step is None:
34-
nstep = 1
35-
if isinstance(step, complex):
36-
size = int(abs(step))
37-
newobj = cast(ndarray, linspace(start, stop, num=size))
38-
else:
39-
newobj = np.arange(start, stop, step)
40-
41-
# if is number
42-
elif isinstance(item, (int, float)):
43-
newobj = np.array([item])
44-
scalars.append(len(objs))
45-
scalar = True
46-
scalartypes.append(newobj.dtype())
4726

48-
else:
27+
try:
28+
if isinstance(item, np.ndarray):
29+
newobj = item
30+
31+
elif isinstance(item, slice):
32+
step = item.step
33+
start = item.start
34+
stop = item.stop
35+
if start is None:
36+
start = 0
37+
if step is None:
38+
step = 1
39+
if isinstance(step, complex):
40+
size = int(abs(step))
41+
newobj: np.ndarray = np.linspace(start, stop, num=size)
42+
else:
43+
newobj = np.arange(start, stop, step)
44+
45+
# if is number
46+
elif isinstance(item, (int, float, bool)):
47+
newobj = np.array([item])
48+
scalars.append(len(objs))
49+
scalar = True
50+
scalartypes.append(newobj.dtype())
51+
52+
else:
53+
newobj = np.array(item)
54+
55+
except TypeError:
4956
raise Exception("index object %s of type %s is not supported by r_[]" % (
5057
str(item), type(item)))
5158

5259
objs.append(newobj)
53-
if not scalar and isinstance(newobj, ndarray):
60+
if not scalar and isinstance(newobj, np.ndarray):
5461
arraytypes.append(newobj.dtype())
5562

5663
# Ensure that scalars won't up-cast unless warranted
57-
# TODO: ensure that this actually works for dtype coercion
58-
# likelihood is we're going to have to do some funky logic for this
59-
final_dtype = max(arraytypes + scalartypes)
60-
if final_dtype is not None:
61-
for idx in scalars:
64+
final_dtype = min(arraytypes + scalartypes)
65+
for idx, obj in enumerate(objs):
66+
if obj.dtype != final_dtype:
6267
objs[idx] = np.array(objs[idx], dtype=final_dtype)
6368

64-
res = np.concatenate(tuple(objs), axis=axis)
65-
66-
return res
69+
return np.concatenate(tuple(objs), axis=axis)
6770

6871
# this seems weird - not sure what it's for
6972
def __len__(self):
7073
return 0
71-
72-
74+
7375
r_ = RClass()

0 commit comments

Comments
 (0)