@@ -45,7 +45,7 @@ template <typename Fn>
 void apply_on_flat_ix_with_dim_mask_and_base(
     const Fn& fn,
     const Tensor& in,
-    bool* dim_mask,
+    const bool* dim_mask,
     const size_t base,
     const size_t start,
     const size_t end) {
@@ -315,6 +315,92 @@ void apply_over_dim(
   }
 }

+/**
+ * Execution plan for repeated apply_over_dim_list with the same
+ * function, input tensor, dim list, start, and end but varying
+ * out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
+ */
+class ApplyOverDimListPlan {
+ public:
+  ApplyOverDimListPlan(
+      const executorch::aten::Tensor& in,
+      // If set, lifetime must last until execute() returns.
+      const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
+          dim_list,
+      const int64_t start = 0,
+      const int64_t end = -1)
+      : in_(in) {
+    ET_CHECK(check_dim_list_is_valid(in, dim_list));
+    out_numel_ = get_out_numel(in_, dim_list);
+    if (in.numel() == 0) {
+      mode_ = ExecutionMode::NothingToDo;
+      return;
+    }
+    const size_t iter_length = get_reduced_dim_product(in, dim_list);
+    const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
+    const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
+    ustart_ = std::max(normalized_start, size_t(0));
+    uend_ = std::min(normalized_end, iter_length - 1);
+    if (!dim_list.has_value() || dim_list.value().size() == 0 ||
+        in.dim() == 0) {
+      mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
+      return;
+    }
+    dim_list_ = dim_list.value();
+    is_in_dim_list_.fill(0);
+    for (const auto& d : dim_list.value()) {
+      const size_t non_neg_d = d < 0 ? d + in.dim() : d;
+      is_in_dim_list_[non_neg_d] = true;
+    }
+
+    mode_ = ExecutionMode::NormalDimMask;
+  }
+
+  template <typename Fn>
+  void execute(const Fn& fn, const size_t out_ix) const {
+    ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
+
+    switch (mode_) {
+      case ExecutionMode::NothingToDo:
+        return;
+      case ExecutionMode::NoDimMaskOrZeroDimension:
+        apply_on_flat_ix_with_stride_and_base(
+            fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
+        return;
+      case ExecutionMode::NormalDimMask:
+        apply_on_flat_ix_with_dim_mask_and_base(
+            fn,
+            in_,
+            is_in_dim_list_.data(),
+            get_init_index(in_, dim_list_, out_ix),
+            ustart_,
+            uend_);
+        return;
+    }
+  }
+
+ private:
+  // Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t ustart_;
+  // End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
+  size_t uend_;
+  enum class ExecutionMode {
+    // Empty input, no work to do.
+    NothingToDo,
+    // Iterate over the entire tensor with
+    // apply_on_flat_ix_with_stride_and_base.
+    NoDimMaskOrZeroDimension,
+    // General mode, iterate with
+    // apply_on_flat_ix_with_dim_mask_and_base.
+    NormalDimMask
+  };
+  ExecutionMode mode_;
+  size_t out_numel_;
+  executorch::aten::ArrayRef<int64_t> dim_list_;
+  std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
+  const executorch::aten::Tensor& in_;
+};
+
 /**
  * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
  * for the output element at index `out_ix` using the reduce function
@@ -331,42 +417,8 @@ void apply_over_dim_list(
     const size_t out_ix,
     const int64_t start = 0,
     const int64_t end = -1) {
-  ET_CHECK(check_dim_list_is_valid(in, dim_list));
-  ET_CHECK_MSG(
-      out_ix < get_out_numel(in, dim_list),
-      "Out index %zd is out of bounds",
-      out_ix);
-
-  if (in.numel() == 0) {
-    return;
-  }
-
-  const size_t iter_length = get_reduced_dim_product(in, dim_list);
-  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
-  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
-  const size_t ustart = std::max(normalized_start, size_t(0));
-  const size_t uend = std::min(normalized_end, iter_length - 1);
-
-  // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
-  if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
-    apply_on_flat_ix_with_stride_and_base(
-        fn, /*stride=*/1, /*base=*/0, ustart, uend);
-    return;
-  }
-
-  // Create is_in_dims to check whether each dimension is in the dim list
-  bool is_in_dim_list[kTensorDimensionLimit];
-  memset(is_in_dim_list, false, sizeof(is_in_dim_list));
-  for (const auto& d : dim_list.value()) {
-    const size_t non_neg_d = d < 0 ? d + in.dim() : d;
-    is_in_dim_list[non_neg_d] = true;
-  }
-
-  // Compute the starting base index
-  const size_t base = get_init_index(in, dim_list, out_ix);
-
-  apply_on_flat_ix_with_dim_mask_and_base(
-      fn, in, is_in_dim_list, base, ustart, uend);
+  ApplyOverDimListPlan plan(in, dim_list, start, end);
+  plan.execute(fn, out_ix);
 }

 //
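
The point of the new class is to hoist the validation and dim-mask setup that `apply_over_dim_list` used to redo on every call out of the per-output-element loop, so callers such as the `{map_,}reduce_over_dim_list` helpers can pay that cost once and then call `execute()` per `out_ix`. Below is a minimal standalone C++ analogue of that plan/execute pattern, not ExecuTorch code: the `SumPlan` type and the 2-D row-sum setup are illustrative only, and a flat `std::vector` stands in for `executorch::aten::Tensor`.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Plan for summing a row-major 2-D "tensor" over its last dimension.
// Everything that depends only on the shape is computed once at construction;
// execute() does only the cheap per-output-index work, mirroring
// ApplyOverDimListPlan::execute() above.
struct SumPlan {
  SumPlan(const std::vector<float>& data, size_t rows, size_t cols)
      : data_(data), rows_(rows), cols_(cols) {}

  // Per-output work: sum one row. Called once per out_ix.
  float execute(size_t out_ix) const {
    float acc = 0.0f;
    for (size_t i = 0; i < cols_; ++i) {
      acc += data_[out_ix * cols_ + i];
    }
    return acc;
  }

  size_t out_numel() const { return rows_; }

 private:
  const std::vector<float>& data_;  // must outlive the plan, as in the diff
  size_t rows_;
  size_t cols_;
};

int main() {
  const std::vector<float> data = {1, 2, 3, 4, 5, 6};  // 2 x 3, row-major
  // Build the plan once, then reuse it for every output element, instead of
  // repeating the setup inside the loop the way the removed code effectively did.
  SumPlan plan(data, /*rows=*/2, /*cols=*/3);
  for (size_t out_ix = 0; out_ix < plan.out_numel(); ++out_ix) {
    std::cout << plan.execute(out_ix) << '\n';  // prints 6 then 15
  }
  return 0;
}
```

In the same spirit, a reduction kernel could construct one `ApplyOverDimListPlan` from `in`, `dim_list`, `start`, and `end` and call `plan.execute(fn, out_ix)` for each output element, while the refactored `apply_over_dim_list` keeps its old behavior by building a throwaway plan per call.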