
Commit 3964503

chunyuan-w authored and EikanWang committed
add check on not dil and not own whole storage tensor
1 parent cde7f43 commit 3964503

File tree

5 files changed: +92 -15 lines changed


tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 59 additions & 0 deletions
@@ -812,6 +812,29 @@ def test_sliced_inplace_eltwise(self):
             self._check_tensor_shape(x_cpu_slice, x_dpcpp_slice)
             self.assertEqual(x_cpu_slice, x_dpcpp_slice, 0.01)
 
+    def test_sliced_eltwise_backward(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+
+        input = torch.rand(10, 10, 10)
+        with AutoDNNL(True), AutoMixPrecision(True, train=True):
+            x_cpu = input.clone().requires_grad_()
+            x_cpu_slice = x_cpu[3:7, 3:7, 5]
+
+            x_dpcpp = input.clone().to(device=device).requires_grad_()
+            x_dpcpp_slice = x_dpcpp[3:7, 3:7, 5]
+
+            y_cpu = F.relu(x_cpu_slice)
+            y_dpcpp = F.relu(x_dpcpp_slice)
+
+            y_cpu.sum().backward()
+            y_dpcpp.sum().backward()
+
+            self._check_tensor_shape(y_cpu, y_dpcpp)
+            self.assertEqual(y_cpu, y_dpcpp)
+            self.assertEqual(x_cpu.grad, x_dpcpp.grad)
+
     def test_linear_with_sliced_bias(self):
         bias = torch.rand(30)
         x_cpu = torch.rand(20, 30)
@@ -827,6 +850,42 @@ def test_linear_with_sliced_bias(self):
 
         self.assertEqual(y_cpu, y_dpcpp, 0.1)
 
+    def test_chunk_version_counter(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+
+        x_dpcpp = torch.randn(32, 4096).to(device).requires_grad_()
+
+        with AutoDNNL(True), AutoMixPrecision(True, train=True):
+            x_chunked = x_dpcpp.chunk(4, 1)
+
+            output = x_chunked[0].sigmoid_()
+            version_counter = output._version
+
+            output_other = x_chunked[1].sigmoid_()
+            self.assertTrue(output._version == version_counter)
+
+    def test_unbind(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        x_cpu = torch.rand(2, 8, 2)
+        x_dpcpp = copy.deepcopy(x_cpu).to(device=device)
+
+        x_cpu_unbind = torch.unbind(x_cpu)
+        with AutoDNNL(True), AutoMixPrecision(True):
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_dpcpp))
+            x_dpcpp_unbind = torch.unbind(x_dpcpp)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_dpcpp))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_dpcpp_unbind[0]))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_dpcpp_unbind[1]))
+
+            self._check_tensor_shape(x_cpu_unbind[0], x_dpcpp_unbind[0])
+            self._check_tensor_shape(x_cpu_unbind[1], x_dpcpp_unbind[1])
+            self.assertEqual(x_cpu_unbind[0], x_dpcpp_unbind[0], 0.01)
+            self.assertEqual(x_cpu_unbind[1], x_dpcpp_unbind[1], 0.01)
+
 class TestBinOPs(TestCase):
     def _gen_shapes(self):
         dims = torch.randint(1, 10, (1,))
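The version counter that test_chunk_version_counter inspects is standard PyTorch bookkeeping: every in-place op bumps a tensor's _version, and a view produced by slicing shares the counter with its base. A minimal plain-PyTorch illustration of that behavior (not part of the commit):

import torch

x = torch.rand(4)
print(x._version)   # 0
x.add_(1)           # an in-place op bumps the version counter
print(x._version)   # 1
x[1:3].zero_()      # an in-place op through a slice bumps the shared counter too
print(x._version)   # 2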

tests/cpu/test_lazy_reorder.py

Lines changed: 12 additions & 0 deletions
@@ -1422,6 +1422,18 @@ def test_split_backward(self):
         y2.backward()
         self.assertEqual(x1.grad, x2.grad)
 
+    def test_split_share_memory(self):
+        with AutoDNNL(True):
+            x_dpcpp = torch.FloatTensor([1, 1, 1, 1, -1, -1, -1, -1]).to(device=device)
+            other = torch.FloatTensor([-1, -1, -1, -1]).to(device=device)
+
+            x_target = torch.FloatTensor([0, 0, 0, 0, -1, -1, -1, -1]).to(device=device)
+
+            splited_x = torch.split(x_dpcpp, 4)
+            splited_x[0].add_(other)
+
+            self.assertEqual(x_dpcpp, x_target)
+
 class ConvRelu(nn.Module):
     def __init__(self):
         super(ConvRelu, self).__init__()
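test_split_share_memory pins down the aliasing the DNNL path must preserve: torch.split returns views of the input, so an in-place update through one split has to remain visible in the base tensor. The same invariant in plain CPU PyTorch, as a reference sketch (not part of the commit):

import torch

x = torch.tensor([1., 1., 1., 1., -1., -1., -1., -1.])
first, second = torch.split(x, 4)                 # both halves are views into x's storage
first.add_(torch.tensor([-1., -1., -1., -1.]))
print(x)   # tensor([ 0.,  0.,  0.,  0., -1., -1., -1., -1.])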

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 18 additions & 14 deletions
@@ -1929,7 +1929,11 @@ at::Tensor dil_as_strided(
 
   auto* _tensor_impl = (IPEXTensorImpl *)result.unsafeGetTensorImpl();
   _tensor_impl->copy_meta_info(self.unsafeGetTensorImpl());
-  // reset version counter for chunk
+  // When a tensor is chunked, the obtained chunked tensors do not share the version counter.
+  // We have copied the version counter in copy_meta_info and it is a workaround to reset the
+  // version counter here.
+  // Note that when a tensor is sliced, PyTorch will call as_view which will copy the version
+  // counter to the sliced tensor. We do not need to handle it here.
   _tensor_impl->set_version_counter(0);
   _tensor_impl->copy_auto_grad(self.unsafeGetTensorImpl());

@@ -2134,8 +2138,21 @@ at::Tensor AtenIpexCPUDev::dil_select(const at::Tensor & self, at::Dimname dim,
   return dil_select(self, at::dimname_to_position(self, dim), index);
 }
 
+at::Tensor _dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
+  // Port from aten/src/ATen/native/TensorShape.cpp
+  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+  auto cur_size = self.size(dim);
+  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
+    start = at::maybe_wrap_dim(start, cur_size);
+  }
+  TORCH_CHECK(length >= 0 && start <= cur_size - length,
+    "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
+  return AtenIpexCPUDev::dil_slice(self, dim, start, start + length, 1);
+}
+
 std::vector<at::Tensor> AtenIpexCPUDev::dil_split(const at::Tensor& self, int64_t split_size, int64_t dim) {
   DEBUG("AtenIpexCPUDev::dil_split\n");
+  // Port from aten/src/ATen/native/TensorShape.cpp
   TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
   TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);

@@ -2162,19 +2179,6 @@ std::vector<at::Tensor> AtenIpexCPUDev::dil_split(const at::Tensor& self, int64_t split_size, int64_t dim) {
   return splits;
 }
 
-// TODO only used for dil_split
-at::Tensor AtenIpexCPUDev::_dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
-  // Port from aten/src/ATen/native/TensorShape.cpp
-  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-  auto cur_size = self.size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = at::maybe_wrap_dim(start, cur_size);
-  }
-  TORCH_CHECK(length >= 0 && start <= cur_size - length,
-    "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
-  return dil_slice(self, dim, start, start + length, 1);
-}
-
 at::Tensor AtenIpexCPUDev::dil_gelu(const at::Tensor& input) {
   DEBUG("AtenIpexCPUDev::dil_gelu\n");
   CHECK_DNNL_OP_PRE_COND(input);
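For reference, the boundary rules that the relocated _dil_narrow helper ports from at::native::narrow behave as follows in stock PyTorch; the values below are illustrative only:

import torch

x = torch.arange(10)
print(torch.narrow(x, 0, 2, 3))    # tensor([2, 3, 4]): start=2, length=3
print(torch.narrow(x, 0, 10, 0))   # empty result: start == size is valid when length == 0
# torch.narrow(x, 0, 8, 5)         # raises: start (8) + length (5) exceeds dimension size (10)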

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 0 additions & 1 deletion
@@ -71,7 +71,6 @@ class AtenIpexCPUDev {
   static at::Tensor dil_cat(at::TensorList tensors, int64_t dim);
   static std::vector<at::Tensor> dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim);
   static std::vector<at::Tensor> dil_split(const at::Tensor& self, int64_t split_size, int64_t dim);
-  static at::Tensor _dil_narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
   static at::Tensor dil_gelu(const at::Tensor& input);
   static at::Tensor dil_gelu_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_native_layer_norm(const at::Tensor& X, const at::Tensor& gamma, const at::Tensor& beta, int64_t M, int64_t N, double eps);

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ void reorder_to_dtype(const at::Tensor& tensor, at::ScalarType dst_scalar_type,
     // The data type of DIL tensor is same as the dst data type. DO NOTHING
     return;
   }
+  // should fallback if not dil tensor and not own whole storage
+  IPEX_CHECK(cpu::ShadeDataContext::isDilTensor(tensor) || check_tensor_own_whole_storage(tensor), "Reorder only works while tensor owns the whole storage or tensor is a dil tensor");
+
   auto dst_desc = src.get_desc().to_type(get_dil_data_type(dst_scalar_type));
   reorder_to_desc(tensor, dst_desc, scales);
 }