@@ -1985,6 +1985,7 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
   DEBUG("AtenIpexCPUDev::dil_slice\n");
   CHECK_DNNL_OP_PRE_COND(self);
 
+  // TODO use weight TAG to decide whether to reorder or not
   dbl::comm::reorder_to_bf16_for_mix_prec(self, true);
 
   // Port from aten/src/ATen/native/TensorShape.cpp
@@ -2023,6 +2024,22 @@ at::Tensor AtenIpexCPUDev::dil_slice(const at::Tensor & self, int64_t dim, int64
   return result;
 }
 
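+// unbind: one output per index along `dim`; each entry comes from
+// dil_select(self, dim, i), so `dim` is dropped from every result.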
+std::vector<at::Tensor> AtenIpexCPUDev::dil_unbind(const at::Tensor &self, int64_t dim) {
+  DEBUG("AtenIpexCPUDev::dil_unbind\n");
+
+  dim = at::maybe_wrap_dim(dim, self.dim());
+  int64_t size = dil_size(self, dim);
+  std::vector<at::Tensor> tensors(size);
+  for (int i = 0; i < size; i++) {
+    tensors[i] = dil_select(self, dim, i);
+  }
+  return tensors;
+}
+
+std::vector<at::Tensor> AtenIpexCPUDev::dil_unbind(const at::Tensor& self, at::Dimname dim) {
+  return dil_unbind(self, at::dimname_to_position(self, dim));
+}
+
 at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, int64_t dim, int64_t index) {
   DEBUG("AtenIpexCPUDev::dil_select\n");
   CHECK_DNNL_OP_PRE_COND(self);
@@ -2119,19 +2136,43 @@ at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, at::Dimname dim,
 
 std::vector<at::Tensor> AtenIpexCPUDev::dil_split(const at::Tensor& self, int64_t split_size, int64_t dim) {
   DEBUG("AtenIpexCPUDev::dil_split\n");
+  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
+  TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
+
   CHECK_DNNL_OP_PRE_COND(self);
   dim = at::maybe_wrap_dim(dim, self.dim());
   int64_t dim_size = dil_size(self, dim);
+  TORCH_CHECK(split_size > 0 || self.size(dim) == 0,
+              "split_size can only be 0 if dimension size is 0, "
+              "but got dimension size of ", dim_size);
+  // if split_size is 0 and dimension size is 0, there is 1 split.
   int64_t num_splits = 1;
   if (split_size != 0) {
     // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
     // (returns a single split). We might want to error here, but keep it for BC.
     num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
   }
-  std::vector<int64_t> split_sizes(num_splits, split_size);
+  std::vector<at::Tensor> splits(num_splits);
   int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
-  split_sizes[num_splits-1] = last_split_size;
-  return dil_split_with_sizes(self, split_sizes, dim);
+
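+  // the last chunk absorbs any remainder, so it can be shorter than split_size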
+  for (int64_t i = 0; i < num_splits; ++i) {
+    auto length = i < num_splits - 1 ? split_size : last_split_size;
+    splits[i] = _dil_narrow(self, dim, i * split_size, length);
+  }
+  return splits;
+}
+
+// TODO only used for dil_split
+at::Tensor AtenIpexCPUDev::_dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
+  // Port from aten/src/ATen/native/TensorShape.cpp
+  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+  auto cur_size = self.size(dim);
+  if (start != cur_size) {  // start being the end is valid, but not a valid dim specification.
+    start = at::maybe_wrap_dim(start, cur_size);
+  }
+  TORCH_CHECK(length >= 0 && start <= cur_size - length,
+              "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
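+  // narrow is expressed as a stride-1 slice over [start, start + length)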
+  return dil_slice(self, dim, start, start + length, 1);
 }
 
 at::Tensor AtenIpexCPUDev::dil_gelu(const at::Tensor& input) {
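For reference, a rough sketch of the behavior the new dil_split / dil_unbind paths are meant to mirror, written against the public ATen API rather than the internal dil_* entry points added above (the tensor shape and values are illustrative only, not taken from the patch):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 5x2 tensor; dimension 0 has size 5.
  at::Tensor t = at::arange(10).reshape({5, 2});

  // split(split_size=2, dim=0): ceil(5 / 2) = 3 chunks; the last chunk
  // absorbs the remainder, so the chunk sizes along dim 0 are 2, 2, 1.
  std::vector<at::Tensor> chunks = at::split(t, /*split_size=*/2, /*dim=*/0);
  std::cout << chunks.size() << " " << chunks[2].size(0) << "\n";  // 3 1

  // unbind(dim=0): one tensor per index along dim 0, each with that
  // dimension removed, so 5 tensors of shape {2}.
  std::vector<at::Tensor> rows = at::unbind(t, /*dim=*/0);
  std::cout << rows.size() << " " << rows[0].dim() << "\n";        // 5 1
  return 0;
}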