@@ -17,34 +17,31 @@ class IConsoleDumper {
1717 virtual ~IConsoleDumper () {}
1818 void Disable () { is_enabled_ = false ; }
1919 bool IsEnabled () const { return is_enabled_; }
20- virtual void Print (const char * name, const float * tensor, int dim0, int dim1) const = 0;
21- virtual void Print (const char * name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
22- virtual void Print (const char * name, const size_t * tensor, int dim0, int dim1) const = 0;
23- virtual void Print (const char * name, const int64_t * tensor, int dim0, int dim1) const = 0;
24- virtual void Print (const char * name, const int32_t * tensor, int dim0, int dim1) const = 0;
25-
26- virtual void Print (const char * name, const float * tensor, int dim0, int dim1, int dim2) const = 0;
27- virtual void Print (const char * name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const = 0;
28- virtual void Print (const char * name, const int64_t * tensor, int dim0, int dim1, int dim2) const = 0;
29- virtual void Print (const char * name, const int32_t * tensor, int dim0, int dim1, int dim2) const = 0;
30-
31- virtual void Print (const char * name, const float * tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
32- virtual void Print (const char * name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
33- virtual void Print (const char * name, const int64_t * tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
34- virtual void Print (const char * name, const int32_t * tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
35-
36- virtual void Print (const char * name, const int32_t * tensor, gsl::span<const int64_t >& dims) const = 0;
37- virtual void Print (const char * name, const int64_t * tensor, gsl::span<const int64_t >& dims) const = 0;
38- virtual void Print (const char * name, const float * tensor, gsl::span<const int64_t >& dims) const = 0;
39- virtual void Print (const char * name, const MLFloat16* tensor, gsl::span<const int64_t >& dims) const = 0;
4020
21+ virtual void Print (const char * name, const size_t * tensor, int dim0, int dim1) const = 0;
4122 virtual void Print (const char * name, const Tensor& value) const = 0;
4223 virtual void Print (const char * name, const OrtValue& value) const = 0;
4324 virtual void Print (const char * name, int index, bool end_line) const = 0;
4425 virtual void Print (const char * name, const std::string& value, bool end_line) const = 0;
45-
4626 virtual void Print (const std::string& value) const = 0;
4727
28+ #define TENSOR_DUMPER_PRINT_TYPE (dtype ) \
29+ virtual void Print (const char * name, const dtype* tensor, int dim0, int dim1) const = 0; \
30+ virtual void Print (const char * name, const dtype* tensor, int dim0, int dim1, int dim2) const = 0; \
31+ virtual void Print (const char * name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; \
32+ virtual void Print (const char * name, const dtype* tensor, gsl::span<const int64_t >& dims) const = 0;
33+
34+ TENSOR_DUMPER_PRINT_TYPE (int8_t )
35+ TENSOR_DUMPER_PRINT_TYPE (uint8_t )
36+ TENSOR_DUMPER_PRINT_TYPE (int32_t )
37+ TENSOR_DUMPER_PRINT_TYPE (int64_t )
38+ TENSOR_DUMPER_PRINT_TYPE (float )
39+ TENSOR_DUMPER_PRINT_TYPE (MLFloat16)
40+ TENSOR_DUMPER_PRINT_TYPE (BFloat16)
41+ TENSOR_DUMPER_PRINT_TYPE (UInt4x2)
42+ TENSOR_DUMPER_PRINT_TYPE (Int4x2)
43+ #undef TENSOR_DUMPER_PRINT_TYPE
44+
4845 protected:
4946 bool is_enabled_;
5047};
0 commit comments