Skip to content

Commit cde7f43

Browse files
chunyuan-w authored and EikanWang committed
enable dil_unbind and fix dil_split
1 parent c529ac0 commit cde7f43

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
'aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)',
7171
'aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)',
7272
'aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)',
73+
'aten::unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[]',
74+
'aten::unbind.Dimname(Tensor(a) self, Dimname dim) -> Tensor(a)[]',
7375
'aten::view(Tensor(a) self, int[] size) -> Tensor(a)',
7476
'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
7577
'aten::_unsafe_view(Tensor self, int[] size) -> Tensor',

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,7 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
19851985
DEBUG("AtenIpexCPUDev::dil_slice\n");
19861986
CHECK_DNNL_OP_PRE_COND(self);
19871987

1988+
// TODO use weight TAG to decide whether to reorder or not
19881989
dbl::comm::reorder_to_bf16_for_mix_prec(self, true);
19891990

19901991
// Port from aten/src/ATen/native/TensorShape.cpp
@@ -2023,6 +2024,22 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
20232024
return result;
20242025
}
20252026

2027+
// Removes dimension `dim` and returns a vector with one slice per index
// along it (DNNL-backed analogue of at::unbind): each slice is produced
// by dil_select, so views/layout handling stays consistent with dil_select.
std::vector<at::Tensor> AtenIpexCPUDev::dil_unbind(const at::Tensor &self, int64_t dim) {
  DEBUG("AtenIpexCPUDev::dil_unbind\n");

  dim = at::maybe_wrap_dim(dim, self.dim());
  int64_t size = dil_size(self, dim);
  std::vector<at::Tensor> tensors(size);
  // int64_t counter: `size` is int64_t, so an `int` index would truncate for
  // very large dimensions and triggers signed-comparison warnings.
  for (int64_t i = 0; i < size; i++) {
    tensors[i] = dil_select(self, dim, i);
  }
  return tensors;
}
2038+
2039+
// Dimname overload: resolve the named dimension to its positional index and
// defer to the positional dil_unbind above.
std::vector<at::Tensor> AtenIpexCPUDev::dil_unbind(const at::Tensor& self, at::Dimname dim) {
  const auto pos = at::dimname_to_position(self, dim);
  return dil_unbind(self, pos);
}
2042+
20262043
at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, int64_t dim, int64_t index) {
20272044
DEBUG("AtenIpexCPUDev::dil_select\n");
20282045
CHECK_DNNL_OP_PRE_COND(self);
@@ -2119,19 +2136,43 @@ at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, at::Dimname dim,
21192136

21202137
// Splits `self` into chunks of `split_size` along `dim` (last chunk may be
// shorter). Mirrors at::split's semantics, materializing each chunk through
// _dil_narrow so slicing goes through the DNNL-aware path.
std::vector<at::Tensor> AtenIpexCPUDev::dil_split(const at::Tensor& self, int64_t split_size, int64_t dim) {
  DEBUG("AtenIpexCPUDev::dil_split\n");
  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);

  CHECK_DNNL_OP_PRE_COND(self);
  dim = at::maybe_wrap_dim(dim, self.dim());
  int64_t dim_size = dil_size(self, dim);
  // Consistency fix: test the same DNNL-aware `dim_size` that the message
  // reports (the original mixed self.size(dim) into this check).
  TORCH_CHECK(split_size > 0 || dim_size == 0,
    "split_size can only be 0 if dimension size is 0, "
    "but got dimension size of ", dim_size);
  // if split_size is 0 and dimension size is 0, there is 1 split.
  int64_t num_splits = 1;
  if (split_size != 0) {
    // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
    // (returns a single split). We might want to error here, but keep it for BC.
    num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  }
  std::vector<at::Tensor> splits(num_splits);
  // Length of the final chunk when dim_size is not a multiple of split_size.
  int64_t last_split_size = split_size - (split_size * num_splits - dim_size);

  for (int64_t i = 0; i < num_splits; ++i) {
    auto length = i < num_splits - 1 ? split_size : last_split_size;
    splits[i] = _dil_narrow(self, dim, i * split_size, length);
  }
  return splits;
}
2164+
2165+
// TODO only used for dil_split
2166+
at::Tensor AtenIpexCPUDev::_dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
2167+
// Port from aten/src/ATen/native/TensorShape.cpp
2168+
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
2169+
auto cur_size = self.size(dim);
2170+
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
2171+
start = at::maybe_wrap_dim(start, cur_size);
2172+
}
2173+
TORCH_CHECK(length >= 0 && start <= cur_size - length,
2174+
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
2175+
return dil_slice(self, dim, start, start + length, 1);
21352176
}
21362177

21372178
at::Tensor AtenIpexCPUDev::dil_gelu(const at::Tensor& input) {

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ class AtenIpexCPUDev {
7171
static at::Tensor dil_cat(at::TensorList tensors, int64_t dim);
7272
static std::vector<at::Tensor> dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim);
7373
static std::vector<at::Tensor> dil_split(const at::Tensor& self, int64_t split_size, int64_t dim);
74+
static at::Tensor _dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
7475
static at::Tensor dil_gelu(const at::Tensor& input);
7576
static at::Tensor dil_gelu_backward(const at::Tensor& grad_output, const at::Tensor& input);
7677
static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_native_layer_norm(const at::Tensor& X, const at::Tensor& gamma, const at::Tensor& beta, int64_t M, int64_t N, double eps);
7778
static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_native_layer_norm_backward(const at::Tensor& dY, const at::Tensor& X, const at::Tensor& mean, const at::Tensor& rstd, const at::Tensor& gamma, int64_t M, int64_t N, std::array<bool, 3> grad_input_mask);
7879
static at::Tensor dil_slice(const at::Tensor & self, int64_t dim, int64_t start, int64_t end, int64_t step);
80+
static std::vector<at::Tensor> dil_unbind(const at::Tensor &self, int64_t dim);
81+
static std::vector<at::Tensor> dil_unbind(const at::Tensor& self, at::Dimname dim);
7982
static at::Tensor dil_select(const at::Tensor & self, int64_t dim, int64_t index);
8083
static at::Tensor dil_select(const at::Tensor & self, at::Dimname dim, int64_t index);
8184
static at::Tensor dil_view(const at::Tensor & self, at::IntArrayRef size);

0 commit comments

Comments
 (0)