Skip to content

Commit 3e9d22c

Browse files
pytorchbotlucylq
andauthored
[llm_patch] Fix out-of-bounds access in pad2d function (#15865)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15832 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/128/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/128/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/128/orig Differential Revision: [D80831697](https://our.internmc.facebook.com/intern/diff/D80831697/) @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent 0769abc commit 3e9d22c

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

kernels/portable/cpu/util/padding_util.h

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ void pad1d(
5656
size_t out_i_base = i * out_width;
5757
size_t in_i_base = i * in_width;
5858
for (const auto w : c10::irange(out_width)) {
59-
out_data[out_i_base + w] =
60-
in_data[in_i_base + padding_ix(w, in_width, pad_left)];
59+
int64_t in_w_idx = padding_ix(w, in_width, pad_left);
60+
ET_CHECK(in_w_idx >= 0 && in_w_idx < in_width);
61+
out_data[out_i_base + w] = in_data[in_i_base + in_w_idx];
6162
}
6263
}
6364
}
@@ -85,11 +86,13 @@ void pad2d(
8586
size_t in_i_base = i * in_height * in_width;
8687
for (const auto h : c10::irange(out_height)) {
8788
size_t out_h_base = out_i_base + h * out_width;
88-
size_t in_h_base =
89-
in_i_base + padding_ix(h, in_height, pad_top) * in_width;
89+
int64_t in_h_idx = padding_ix(h, in_height, pad_top);
90+
ET_CHECK(in_h_idx >= 0 && in_h_idx < in_height);
91+
size_t in_h_base = in_i_base + in_h_idx * in_width;
9092
for (const auto w : c10::irange(out_width)) {
91-
out_data[out_h_base + w] =
92-
in_data[in_h_base + padding_ix(w, in_width, pad_left)];
93+
int64_t in_w_idx = padding_ix(w, in_width, pad_left);
94+
ET_CHECK(in_w_idx >= 0 && in_w_idx < in_width);
95+
out_data[out_h_base + w] = in_data[in_h_base + in_w_idx];
9396
}
9497
}
9598
}
@@ -121,15 +124,18 @@ void pad3d(
121124
size_t in_i_base = i * in_depth * in_height * in_width;
122125
for (const auto d : c10::irange(out_depth)) {
123126
size_t out_d_base = out_i_base + d * out_height * out_width;
124-
size_t in_d_base =
125-
in_i_base + padding_ix(d, in_depth, pad_front) * in_height * in_width;
127+
int64_t in_d_base_padding = padding_ix(d, in_depth, pad_front);
128+
ET_CHECK(in_d_base_padding >= 0 && in_d_base_padding < in_depth);
129+
size_t in_d_base = in_i_base + in_d_base_padding * in_height * in_width;
126130
for (const auto h : c10::irange(out_height)) {
127131
size_t out_h_base = out_d_base + h * out_width;
128-
size_t in_h_base =
129-
in_d_base + padding_ix(h, in_height, pad_top) * in_width;
132+
int64_t in_h_base_padding = padding_ix(h, in_height, pad_top);
133+
ET_CHECK(in_h_base_padding >= 0 && in_h_base_padding < in_height);
134+
size_t in_h_base = in_d_base + in_h_base_padding * in_width;
130135
for (const auto w : c10::irange(out_width)) {
131-
out_data[out_h_base + w] =
132-
in_data[in_h_base + padding_ix(w, in_width, pad_left)];
136+
int64_t in_w_base_padding = padding_ix(w, in_width, pad_left);
137+
ET_CHECK(in_w_base_padding >= 0 && in_w_base_padding < in_width);
138+
out_data[out_h_base + w] = in_data[in_h_base + in_w_base_padding];
133139
}
134140
}
135141
}

0 commit comments

Comments
 (0)