@@ -449,7 +449,7 @@ MP_DEFINE_CONST_FUN_OBJ_0(ndarray_get_printoptions_obj, ndarray_get_printoptions
449
449
#endif
450
450
451
451
mp_obj_t ndarray_get_item (ndarray_obj_t * ndarray , void * array ) {
452
- // returns a proper micropython item from an array
452
+ // returns a proper micropython object from an array
453
453
if (!ndarray -> boolean ) {
454
454
return mp_binary_get_val_array (ndarray -> dtype , array , 0 );
455
455
} else {
@@ -565,7 +565,7 @@ void ndarray_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t ki
565
565
}
566
566
567
567
void ndarray_assign_elements (ndarray_obj_t * ndarray , mp_obj_t iterable , uint8_t dtype , size_t * idx ) {
568
- // assigns a single row in the matrix
568
+ // assigns a single row in the tensor
569
569
mp_obj_t item ;
570
570
if (ndarray -> boolean ) {
571
571
uint8_t * array = (uint8_t * )ndarray -> array ;
@@ -588,6 +588,7 @@ void ndarray_assign_elements(ndarray_obj_t *ndarray, mp_obj_t iterable, uint8_t
588
588
bool ndarray_is_dense (ndarray_obj_t * ndarray ) {
589
589
// returns true, if the array is dense, false otherwise
590
590
// the array should be dense, if the very first stride can be calculated from shape
591
+ // TODO: this function could probably be removed
591
592
int32_t stride = ndarray -> itemsize ;
592
593
for (uint8_t i = ULAB_MAX_DIMS ; i > ULAB_MAX_DIMS - ndarray -> ndim ; i -- ) {
593
594
stride *= ndarray -> shape [i ];
@@ -836,61 +837,78 @@ STATIC mp_obj_t ndarray_make_new_core(const mp_obj_type_t *type, size_t n_args,
836
837
return MP_OBJ_FROM_PTR (target );
837
838
}
838
839
839
- mp_obj_t len_in = mp_obj_len_maybe (args [0 ]);
840
- size_t len1 = 0 , len2 = 0 ;
841
- if (len_in == MP_OBJ_NULL ) {
842
- mp_raise_ValueError (translate ("first argument must be an iterable" ));
843
- } else {
844
- // len1 is either the number of rows (for matrices), or the number of elements (row vectors)
845
- len1 = MP_OBJ_SMALL_INT_VALUE (len_in );
846
- }
847
- ndarray_obj_t * self ;
848
-
849
- // We have to figure out, whether the first element of the iterable is an iterable itself
850
- // Perhaps, there is a more elegant way of handling this
851
- mp_obj_iter_buf_t iter_buf1 ;
852
- mp_obj_t iterable1 = mp_getiter (args [0 ], & iter_buf1 );
853
- #if ULAB_MAX_DIMS > 1
854
- mp_obj_t item1 ;
855
- size_t i = 0 ;
856
- while ((item1 = mp_iternext (iterable1 )) != MP_OBJ_STOP_ITERATION ) {
857
- len_in = mp_obj_len_maybe (item1 );
858
- if (len_in != MP_OBJ_NULL ) { // indeed, this seems to be an iterable
859
- // Next, we have to check, whether all elements in the outer loop have the same length
860
- if (i > 0 ) {
861
- if (len2 != (size_t )MP_OBJ_SMALL_INT_VALUE (len_in )) {
862
- mp_raise_ValueError (translate ("iterables are not of the same length" ));
863
- }
864
- }
865
- len2 = MP_OBJ_SMALL_INT_VALUE (len_in );
866
- i ++ ;
840
+ // We have to figure out, whether the elements of the iterable are iterables themself
841
+ uint8_t ndim = 0 ;
842
+ size_t shape [ULAB_MAX_DIMS ];
843
+ mp_obj_iter_buf_t iter_buf [ULAB_MAX_DIMS ];
844
+ mp_obj_t iterable [ULAB_MAX_DIMS ];
845
+ // inspect only the very first element in each dimension; this is fast,
846
+ // but not completely safe, e.g., length compatibility is not checked
847
+ mp_obj_t item = args [0 ];
848
+
849
+ while (1 ) {
850
+ if (mp_obj_len_maybe (item ) == MP_OBJ_NULL ) {
851
+ break ;
867
852
}
853
+ if (ndim == ULAB_MAX_DIMS ) {
854
+ mp_raise_ValueError (translate ("too many dimensions" ));
855
+ }
856
+ shape [ndim ] = MP_OBJ_SMALL_INT_VALUE (mp_obj_len_maybe (item ));
857
+ iterable [ndim ] = mp_getiter (item , & iter_buf [ndim ]);
858
+ item = mp_iternext (iterable [ndim ]);
859
+ ndim ++ ;
868
860
}
869
- #endif
870
- // By this time, it should be established, what the shape is, so we can now create the array
871
- if (len2 == 0 ) {
872
- self = ndarray_new_linear_array (len1 , dtype );
861
+ for (uint8_t i = 0 ; i < ndim ; i ++ ) {
862
+ // align all values to the right
863
+ shape [ULAB_MAX_DIMS - i - 1 ] = shape [ndim - 1 - i ];
873
864
}
874
- #if ULAB_MAX_DIMS > 1
875
- else {
876
- self = ndarray_new_dense_ndarray (2 , ndarray_shape_vector (0 , 0 , len1 , len2 ), dtype );
865
+
866
+ ndarray_obj_t * self = ndarray_new_dense_ndarray (ndim , shape , dtype );
867
+ item = args [0 ];
868
+ for (uint8_t i = 0 ; i < ndim - 1 ; i ++ ) {
869
+ // if ndim > 1, descend into the hierarchy
870
+ iterable [ULAB_MAX_DIMS - ndim + i ] = mp_getiter (item , & iter_buf [ULAB_MAX_DIMS - ndim + i ]);
871
+ item = mp_iternext (iterable [ULAB_MAX_DIMS - ndim + i ]);
877
872
}
878
- #endif
873
+
879
874
size_t idx = 0 ;
880
- iterable1 = mp_getiter (args [0 ], & iter_buf1 );
881
- if (len2 == 0 ) { // the first argument is a single iterable
882
- ndarray_assign_elements (self , iterable1 , self -> dtype , & idx );
883
- }
884
- #if ULAB_MAX_DIMS > 1
885
- else {
886
- mp_obj_iter_buf_t iter_buf2 ;
887
- mp_obj_t iterable2 ;
888
- while ((item1 = mp_iternext (iterable1 )) != MP_OBJ_STOP_ITERATION ) {
889
- iterable2 = mp_getiter (item1 , & iter_buf2 );
890
- ndarray_assign_elements (self , iterable2 , self -> dtype , & idx );
875
+ // TODO: this could surely be done in a more elegant way...
876
+ #if ULAB_MAX_DIMS > 3
877
+ do {
878
+ #endif
879
+ #if ULAB_MAX_DIMS > 2
880
+ do {
881
+ #endif
882
+ #if ULAB_MAX_DIMS > 1
883
+ do {
884
+ #endif
885
+ iterable [ULAB_MAX_DIMS - 1 ] = mp_getiter (item , & iter_buf [ULAB_MAX_DIMS - 1 ]);
886
+ ndarray_assign_elements (self , iterable [ULAB_MAX_DIMS - 1 ], self -> dtype , & idx );
887
+ #if ULAB_MAX_DIMS > 1
888
+ item = ndim > 1 ? mp_iternext (iterable [ULAB_MAX_DIMS - 2 ]) : MP_OBJ_STOP_ITERATION ;
889
+ } while (item != MP_OBJ_STOP_ITERATION );
890
+ #endif
891
+ #if ULAB_MAX_DIMS > 2
892
+ item = ndim > 2 ? mp_iternext (iterable [ULAB_MAX_DIMS - 3 ]) : MP_OBJ_STOP_ITERATION ;
893
+ if (item != MP_OBJ_STOP_ITERATION ) {
894
+ iterable [ULAB_MAX_DIMS - 2 ] = mp_getiter (item , & iter_buf [ULAB_MAX_DIMS - 2 ]);
895
+ item = mp_iternext (iterable [ULAB_MAX_DIMS - 2 ]);
896
+ } else {
897
+ iterable [ULAB_MAX_DIMS - 2 ] = MP_OBJ_STOP_ITERATION ;
898
+ }
899
+ } while (iterable [ULAB_MAX_DIMS - 2 ] != MP_OBJ_STOP_ITERATION );
900
+ #endif
901
+ #if ULAB_MAX_DIMS > 3
902
+ item = ndim > 3 ? mp_iternext (iterable [ULAB_MAX_DIMS - 4 ]) : MP_OBJ_STOP_ITERATION ;
903
+ if (item != MP_OBJ_STOP_ITERATION ) {
904
+ iterable [ULAB_MAX_DIMS - 3 ] = mp_getiter (item , & iter_buf [ULAB_MAX_DIMS - 3 ]);
905
+ item = mp_iternext (iterable [ULAB_MAX_DIMS - 3 ]);
906
+ } else {
907
+ iterable [ULAB_MAX_DIMS - 3 ] = MP_OBJ_STOP_ITERATION ;
891
908
}
892
- }
909
+ } while ( iterable [ ULAB_MAX_DIMS - 3 ] != MP_OBJ_STOP_ITERATION );
893
910
#endif
911
+
894
912
return MP_OBJ_FROM_PTR (self );
895
913
}
896
914
0 commit comments