Skip to content

Commit f680c44

Browse files
authored
Refactor apply_over_dim_list to plan-then-execute (#9107)
First step toward plan-then-execute for reductions, which will cause preparatory work to be done once rather than once per output element.
1 parent 4d21def commit f680c44

File tree

1 file changed

+89
-37
lines changed

1 file changed

+89
-37
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 89 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ template <typename Fn>
4545
void apply_on_flat_ix_with_dim_mask_and_base(
4646
const Fn& fn,
4747
const Tensor& in,
48-
bool* dim_mask,
48+
const bool* dim_mask,
4949
const size_t base,
5050
const size_t start,
5151
const size_t end) {
@@ -315,6 +315,92 @@ void apply_over_dim(
315315
}
316316
}
317317

318+
/**
319+
* Execution plan for repeated apply_over_dim_list with the same
320+
* function, input tensor, dim list, start, and end but varying
321+
* out_ix, as done (via {map_,}reduce_over_dim_list) in reductions.
322+
*/
323+
class ApplyOverDimListPlan {
324+
public:
325+
ApplyOverDimListPlan(
326+
const executorch::aten::Tensor& in,
327+
// If set, lifetime must last until execute() returns.
328+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
329+
dim_list,
330+
const int64_t start = 0,
331+
const int64_t end = -1)
332+
: in_(in) {
333+
ET_CHECK(check_dim_list_is_valid(in, dim_list));
334+
out_numel_ = get_out_numel(in_, dim_list);
335+
if (in.numel() == 0) {
336+
mode_ = ExecutionMode::NothingToDo;
337+
return;
338+
}
339+
const size_t iter_length = get_reduced_dim_product(in, dim_list);
340+
const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
341+
const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
342+
ustart_ = std::max(normalized_start, size_t(0));
343+
uend_ = std::min(normalized_end, iter_length - 1);
344+
if (!dim_list.has_value() || dim_list.value().size() == 0 ||
345+
in.dim() == 0) {
346+
mode_ = ExecutionMode::NoDimMaskOrZeroDimension;
347+
return;
348+
}
349+
dim_list_ = dim_list.value();
350+
is_in_dim_list_.fill(0);
351+
for (const auto& d : dim_list.value()) {
352+
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
353+
is_in_dim_list_[non_neg_d] = true;
354+
}
355+
356+
mode_ = ExecutionMode::NormalDimMask;
357+
}
358+
359+
template <typename Fn>
360+
void execute(const Fn& fn, const size_t out_ix) const {
361+
ET_CHECK_MSG(out_ix < out_numel_, "Out index %zd is out of bounds", out_ix);
362+
363+
switch (mode_) {
364+
case ExecutionMode::NothingToDo:
365+
return;
366+
case ExecutionMode::NoDimMaskOrZeroDimension:
367+
apply_on_flat_ix_with_stride_and_base(
368+
fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
369+
return;
370+
case ExecutionMode::NormalDimMask:
371+
apply_on_flat_ix_with_dim_mask_and_base(
372+
fn,
373+
in_,
374+
is_in_dim_list_.data(),
375+
get_init_index(in_, dim_list_, out_ix),
376+
ustart_,
377+
uend_);
378+
return;
379+
}
380+
}
381+
382+
private:
383+
// Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
384+
size_t ustart_;
385+
// End argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
386+
size_t uend_;
387+
enum class ExecutionMode {
388+
// Empty input, no work to do.
389+
NothingToDo,
390+
// Iterate over the entire tensor with
391+
// apply_on_flat_ix_with_stride_and_base.
392+
NoDimMaskOrZeroDimension,
393+
// General mode, iterate with
394+
// apply_on_flat_ix_with_dim_mask_and_base.
395+
NormalDimMask
396+
};
397+
ExecutionMode mode_;
398+
size_t out_numel_;
399+
executorch::aten::ArrayRef<int64_t> dim_list_;
400+
std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
401+
const executorch::aten::Tensor& in_;
402+
};
403+
318404
/**
319405
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
320406
* for the output element at index `out_ix` using the reduce function
@@ -331,42 +417,8 @@ void apply_over_dim_list(
331417
const size_t out_ix,
332418
const int64_t start = 0,
333419
const int64_t end = -1) {
334-
ET_CHECK(check_dim_list_is_valid(in, dim_list));
335-
ET_CHECK_MSG(
336-
out_ix < get_out_numel(in, dim_list),
337-
"Out index %zd is out of bounds",
338-
out_ix);
339-
340-
if (in.numel() == 0) {
341-
return;
342-
}
343-
344-
const size_t iter_length = get_reduced_dim_product(in, dim_list);
345-
const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
346-
const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
347-
const size_t ustart = std::max(normalized_start, size_t(0));
348-
const size_t uend = std::min(normalized_end, iter_length - 1);
349-
350-
// If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
351-
if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
352-
apply_on_flat_ix_with_stride_and_base(
353-
fn, /*stride=*/1, /*base=*/0, ustart, uend);
354-
return;
355-
}
356-
357-
// Create is_in_dims to check whether each dimension is in the dim list
358-
bool is_in_dim_list[kTensorDimensionLimit];
359-
memset(is_in_dim_list, false, sizeof(is_in_dim_list));
360-
for (const auto& d : dim_list.value()) {
361-
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
362-
is_in_dim_list[non_neg_d] = true;
363-
}
364-
365-
// Compute the starting base index
366-
const size_t base = get_init_index(in, dim_list, out_ix);
367-
368-
apply_on_flat_ix_with_dim_mask_and_base(
369-
fn, in, is_in_dim_list, base, ustart, uend);
420+
ApplyOverDimListPlan plan(in, dim_list, start, end);
421+
plan.execute(fn, out_ix);
370422
}
371423

372424
//

0 commit comments

Comments (0)