Skip to content

Commit ac0195b

Browse files
authored
support more data types in tensor dumper (#24813)
### Description support of dumping tensors of int8, uin t8, BFloat16, UInt4x2, and Int4x2 data types in the tensor dumper. ### Motivation and Context Help debugging of operators using these data types.
1 parent 8c1156e commit ac0195b

File tree

5 files changed

+313
-407
lines changed

5 files changed

+313
-407
lines changed

onnxruntime/contrib_ops/cpu/utils/console_dumper.h

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,31 @@ class IConsoleDumper {
1717
virtual ~IConsoleDumper() {}
1818
void Disable() { is_enabled_ = false; }
1919
bool IsEnabled() const { return is_enabled_; }
20-
virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0;
21-
virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
22-
virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
23-
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0;
24-
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0;
25-
26-
virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0;
27-
virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const = 0;
28-
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const = 0;
29-
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const = 0;
30-
31-
virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
32-
virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
33-
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
34-
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
35-
36-
virtual void Print(const char* name, const int32_t* tensor, gsl::span<const int64_t>& dims) const = 0;
37-
virtual void Print(const char* name, const int64_t* tensor, gsl::span<const int64_t>& dims) const = 0;
38-
virtual void Print(const char* name, const float* tensor, gsl::span<const int64_t>& dims) const = 0;
39-
virtual void Print(const char* name, const MLFloat16* tensor, gsl::span<const int64_t>& dims) const = 0;
4020

21+
virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
4122
virtual void Print(const char* name, const Tensor& value) const = 0;
4223
virtual void Print(const char* name, const OrtValue& value) const = 0;
4324
virtual void Print(const char* name, int index, bool end_line) const = 0;
4425
virtual void Print(const char* name, const std::string& value, bool end_line) const = 0;
45-
4626
virtual void Print(const std::string& value) const = 0;
4727

28+
#define TENSOR_DUMPER_PRINT_TYPE(dtype) \
29+
virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1) const = 0; \
30+
virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const = 0; \
31+
virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; \
32+
virtual void Print(const char* name, const dtype* tensor, gsl::span<const int64_t>& dims) const = 0;
33+
34+
TENSOR_DUMPER_PRINT_TYPE(int8_t)
35+
TENSOR_DUMPER_PRINT_TYPE(uint8_t)
36+
TENSOR_DUMPER_PRINT_TYPE(int32_t)
37+
TENSOR_DUMPER_PRINT_TYPE(int64_t)
38+
TENSOR_DUMPER_PRINT_TYPE(float)
39+
TENSOR_DUMPER_PRINT_TYPE(MLFloat16)
40+
TENSOR_DUMPER_PRINT_TYPE(BFloat16)
41+
TENSOR_DUMPER_PRINT_TYPE(UInt4x2)
42+
TENSOR_DUMPER_PRINT_TYPE(Int4x2)
43+
#undef TENSOR_DUMPER_PRINT_TYPE
44+
4845
protected:
4946
bool is_enabled_;
5047
};

0 commit comments

Comments
 (0)