@@ -28,27 +28,21 @@ class NdMatrixProxy {
2828 // idim: The dimension associated with this proxy
2929 // dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension)
3030 // start: Pointer to the start of the sub-matrix this proxy represents
31- NdMatrixProxy<T, N>(const size_t * dim_sizes, size_t idim, size_t dim_stride , T* start)
31+ NdMatrixProxy<T, N>(const size_t * dim_sizes, const size_t * dim_strides , T* start)
3232 : dim_sizes_(dim_sizes)
33- , idim_(idim)
34- , dim_stride_(dim_stride)
33+ , dim_strides_(dim_strides)
3534 , start_(start) {}
3635
3736 const NdMatrixProxy<T, N - 1 > operator [](size_t index) const {
3837 VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
39- VTR_ASSERT_SAFE_MSG (index < dim_sizes_[idim_], " Index out of range (above dimension maximum)" );
40-
41- size_t next_dim_size = dim_sizes_[idim_ + 1 ];
42- VTR_ASSERT_SAFE_MSG (next_dim_size > 0 , " Can not index into zero-sized dimension" );
43-
44- // Determine the stride of the next dimension
45- size_t next_dim_stride = dim_stride_ / next_dim_size;
38+ VTR_ASSERT_SAFE_MSG (index < dim_sizes_[0 ], " Index out of range (above dimension maximum)" );
39+ VTR_ASSERT_SAFE_MSG (dim_sizes_[1 ] > 0 , " Can not index into zero-sized dimension" );
4640
4741 // Strip off one dimension
48- return NdMatrixProxy<T, N - 1 >(dim_sizes_, // Pass the dimension information
49- idim_ + 1 , // Pass the next dimension
50- next_dim_stride, // Pass the stride for the next dimension
51- start_ + dim_stride_ * index); // Advance to index in this dimension
42+ return NdMatrixProxy<T, N - 1 >(
43+ dim_sizes_ + 1 , // Pass the dimension information
44+ dim_strides_ + 1 , // Pass the stride for the next dimension
45+ start_ + dim_strides_[ 0 ] * index); // Advance to index in this dimension
5246 }
5347
5448 NdMatrixProxy<T, N - 1 > operator [](size_t index) {
@@ -58,25 +52,23 @@ class NdMatrixProxy {
5852
5953 private:
6054 const size_t * dim_sizes_;
61- const size_t idim_;
62- const size_t dim_stride_;
55+ const size_t * dim_strides_;
6356 T* start_;
6457};
6558
6659// Base case: 1-dimensional array
6760template <typename T>
6861class NdMatrixProxy <T, 1 > {
6962 public:
70- NdMatrixProxy<T, 1 >(const size_t * dim_sizes, size_t idim, size_t dim_stride, T* start)
63+ NdMatrixProxy<T, 1 >(const size_t * dim_sizes, const size_t * dim_stride, T* start)
7164 : dim_sizes_(dim_sizes)
72- , idim_(idim)
73- , dim_stride_(dim_stride)
65+ , dim_strides_(dim_stride)
7466 , start_(start) {}
7567
7668 const T& operator [](size_t index) const {
77- VTR_ASSERT_SAFE_MSG (dim_stride_ == 1 , " Final dimension must have stride 1" );
69+ VTR_ASSERT_SAFE_MSG (dim_strides_[ 0 ] == 1 , " Final dimension must have stride 1" );
7870 VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
79- VTR_ASSERT_SAFE_MSG (index < dim_sizes_[idim_ ], " Index out of range (above dimension maximum)" );
71+ VTR_ASSERT_SAFE_MSG (index < dim_sizes_[0 ], " Index out of range (above dimension maximum)" );
8072
8173 // Base case
8274 return start_[index];
@@ -103,8 +95,7 @@ class NdMatrixProxy<T, 1> {
10395
10496 private:
10597 const size_t * dim_sizes_;
106- const size_t idim_;
107- const size_t dim_stride_;
98+ const size_t * dim_strides_;
10899 T* start_;
109100};
110101
@@ -207,12 +198,21 @@ class NdMatrixBase {
207198 size_ = calc_size ();
208199 alloc ();
209200 fill (value);
201+ if (size_ > 0 ) {
202+ dim_strides_[0 ] = size_ / dim_sizes_[0 ];
203+ for (size_t dim = 1 ; dim < N; ++dim) {
204+ dim_strides_[dim] = dim_strides_[dim - 1 ] / dim_sizes_[dim];
205+ }
206+ } else {
207+ dim_strides_.fill (0 );
208+ }
210209 }
211210
212211 // Reset the matrix to size zero
213212 void clear () {
214213 data_.reset (nullptr );
215214 dim_sizes_.fill (0 );
215+ dim_strides_.fill (0 );
216216 size_ = 0 ;
217217 }
218218
@@ -242,6 +242,7 @@ class NdMatrixBase {
242242 using std::swap;
243243 swap (m1.size_ , m2.size_ );
244244 swap (m1.dim_sizes_ , m2.dim_sizes_ );
245+ swap (m1.dim_strides_ , m2.dim_strides_ );
245246 swap (m1.data_ , m2.data_ );
246247 }
247248
@@ -265,6 +266,7 @@ class NdMatrixBase {
265266 protected:
266267 size_t size_ = 0 ;
267268 std::array<size_t , N> dim_sizes_;
269+ std::array<size_t , N> dim_strides_;
268270 std::unique_ptr<T[]> data_ = nullptr ;
269271};
270272
@@ -316,17 +318,11 @@ class NdMatrix : public NdMatrixBase<T, N> {
316318 VTR_ASSERT_SAFE_MSG (index >= 0 , " Index out of range (below dimension minimum)" );
317319 VTR_ASSERT_SAFE_MSG (index < this ->dim_sizes_ [0 ], " Index out of range (above dimension maximum)" );
318320
319- // Calculate the stride for the current dimension
320- size_t dim_stride = this ->size () / this ->dim_size (0 );
321-
322- // Calculate the stride for the next dimension
323- size_t next_dim_stride = dim_stride / this ->dim_size (1 );
324-
325321 // Peel off the first dimension
326- return NdMatrixProxy<T, N - 1 >(this -> dim_sizes_ . data (), // Pass the dimension information
327- 1 , // Pass the next dimension
328- next_dim_stride, // Pass the stride for the next dimension
329- this ->data_ .get () + dim_stride * index); // Advance to index in this dimension
322+ return NdMatrixProxy<T, N - 1 >(
323+ this -> dim_sizes_ . data () + 1 , // Pass the dimension information
324+ this -> dim_strides_ . data () + 1 , // Pass the stride for the next dimension
325+ this ->data_ .get () + this -> dim_strides_ [ 0 ] * index); // Advance to index in this dimension
330326 }
331327
332328 // Access an element
0 commit comments