@@ -961,46 +961,6 @@ def test_add_layernorm(self):
961961 node = "ipex::add_layernorm"
962962 self .assertTrue (any (n .kind () == node for n in trace_graph .nodes ()))
963963
964- def _test_concat_bn_relu (self , a1 , a2 , a3 , enable_3d = True , use_channels_last = True ):
965- if enable_3d :
966- if use_channels_last :
967- model = ConcatBnRelu3d ().eval ().to (memory_format = torch .channels_last_3d )
968- model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
969- with torch .no_grad ():
970- jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
971- jit_model = torch .jit .freeze (jit_model )
972- jit_res = jit_model (a1 , a2 , a3 )
973- ori_res = model (a1 , a2 , a3 )
974- self .assertEqual (jit_res , ori_res )
975- else :
976- model = ConcatBnRelu3d ().eval ()
977- model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
978- with torch .no_grad ():
979- jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
980- jit_model = torch .jit .freeze (jit_model )
981- jit_res = jit_model (a1 , a2 , a3 )
982- ori_res = model (a1 , a2 , a3 )
983- self .assertEqual (jit_res , ori_res )
984- else :
985- if use_channels_last :
986- model = ConcatBnRelu2d ().eval ().to (memory_format = torch .channels_last )
987- model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
988- with torch .no_grad ():
989- jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
990- jit_model = torch .jit .freeze (jit_model )
991- jit_res = jit_model (a1 , a2 , a3 )
992- ori_res = model (a1 , a2 , a3 )
993- self .assertEqual (jit_res , ori_res )
994- else :
995- model = ConcatBnRelu2d ().eval ()
996- model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
997- with torch .no_grad ():
998- jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
999- jit_model = torch .jit .freeze (jit_model )
1000- jit_res = jit_model (a1 , a2 , a3 )
1001- ori_res = model (a1 , a2 , a3 )
1002- self .assertEqual (jit_res , ori_res )
1003-
1004964 def test_concat_bn_relu (self ):
1005965 a1 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .bfloat16 ).contiguous (memory_format = torch .channels_last )
1006966 a2 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .bfloat16 ).contiguous (memory_format = torch .channels_last )
@@ -1010,8 +970,10 @@ def test_concat_bn_relu(self):
1010970 with torch .no_grad ():
1011971 jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
1012972 jit_model = torch .jit .freeze (jit_model )
1013- jit_res = jit_model (a1 , a2 , a3 )
1014- ori_res = model (a1 , a2 , a3 )
973+ #warmup run
974+ for _ in range (2 ):
975+ jit_res = jit_model (a1 , a2 , a3 )
976+ ori_res = model (a1 , a2 , a3 )
1015977 self .assertEqual (jit_res , ori_res )
1016978
1017979 a1 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
@@ -1022,46 +984,92 @@ def test_concat_bn_relu(self):
1022984 with torch .no_grad ():
1023985 jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
1024986 jit_model = torch .jit .freeze (jit_model )
1025- jit_res = jit_model (a1 , a2 , a3 )
1026- ori_res = model (a1 , a2 , a3 )
987+ #warmup run
988+ for _ in range (2 ):
989+ jit_res = jit_model (a1 , a2 , a3 )
990+ ori_res = model (a1 , a2 , a3 )
1027991 self .assertEqual (jit_res , ori_res )
1028992
1029- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = False , use_channels_last = True )
993+ model = ConcatBnRelu2d ().eval ().to (memory_format = torch .channels_last )
994+ model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
995+ with torch .no_grad ():
996+ jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
997+ jit_model = torch .jit .freeze (jit_model )
998+ #warmup run
999+ for _ in range (2 ):
1000+ jit_res = jit_model (a1 , a2 , a3 )
1001+ ori_res = model (a1 , a2 , a3 )
1002+ self .assertEqual (jit_res , ori_res )
10301003
1031- a1 = torch .randn (1 , 16 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1032- a2 = torch .randn (1 , 48 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1033- a3 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1034- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = False , use_channels_last = True )
1004+ a1 = torch .randn (1 , 32 , 18 , 53 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1005+ a2 = torch .randn (1 , 32 , 18 , 53 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1006+ a3 = torch .randn (1 , 32 , 18 , 53 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1007+ with torch .no_grad ():
1008+ jit_res = jit_model (a1 , a2 , a3 )
1009+ ori_res = model (a1 , a2 , a3 )
1010+ self .assertEqual (jit_res , ori_res )
10351011
1036- a1 = torch .randn (1 , 17 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1037- a2 = torch .randn (1 , 47 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1038- a3 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1039- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = False , use_channels_last = True )
1012+ a1 = torch .randn (1 , 16 , 24 , 116 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1013+ a2 = torch .randn (1 , 48 , 24 , 116 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1014+ a3 = torch .randn (1 , 32 , 24 , 116 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1015+ with torch .no_grad ():
1016+ jit_res = jit_model (a1 , a2 , a3 )
1017+ ori_res = model (a1 , a2 , a3 )
1018+ self .assertEqual (jit_res , ori_res )
1019+
1020+ a1 = torch .randn (1 , 17 , 15 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1021+ a2 = torch .randn (1 , 47 , 15 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1022+ a3 = torch .randn (1 , 32 , 15 , 24 , dtype = torch .float ).contiguous (memory_format = torch .channels_last )
1023+ with torch .no_grad ():
1024+ jit_res = jit_model (a1 , a2 , a3 )
1025+ ori_res = model (a1 , a2 , a3 )
1026+ self .assertEqual (jit_res , ori_res )
10401027
10411028 a1 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float )
10421029 a2 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float )
10431030 a3 = torch .randn (1 , 32 , 13 , 24 , dtype = torch .float )
1044- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = False , use_channels_last = False )
1031+ with torch .no_grad ():
1032+ jit_res = jit_model (a1 , a2 , a3 )
1033+ ori_res = model (a1 , a2 , a3 )
1034+ self .assertEqual (jit_res , ori_res )
10451035
10461036 a1 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
10471037 a2 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
10481038 a3 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1049- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = True , use_channels_last = True )
1039+ model = ConcatBnRelu3d ().eval ().to (memory_format = torch .channels_last_3d )
1040+ model = ipex .optimize (model , dtype = torch .float32 , level = 'O0' )
1041+ with torch .no_grad ():
1042+ jit_model = torch .jit .trace (model , (a1 , a2 , a3 )).eval ()
1043+ jit_model = torch .jit .freeze (jit_model )
1044+ #warmup run
1045+ for _ in range (2 ):
1046+ jit_res = jit_model (a1 , a2 , a3 )
1047+ ori_res = model (a1 , a2 , a3 )
1048+ self .assertEqual (jit_res , ori_res )
10501049
1051- a1 = torch .randn (1 , 16 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1052- a2 = torch .randn (1 , 48 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1053- a3 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1054- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = True , use_channels_last = True )
1050+ a1 = torch .randn (1 , 16 , 17 , 14 , 31 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1051+ a2 = torch .randn (1 , 48 , 17 , 14 , 31 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1052+ a3 = torch .randn (1 , 32 , 17 , 14 , 31 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1053+ with torch .no_grad ():
1054+ jit_res = jit_model (a1 , a2 , a3 )
1055+ ori_res = model (a1 , a2 , a3 )
1056+ self .assertEqual (jit_res , ori_res )
10551057
10561058 a1 = torch .randn (1 , 17 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
10571059 a2 = torch .randn (1 , 47 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
10581060 a3 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float ).contiguous (memory_format = torch .channels_last_3d )
1059- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = True , use_channels_last = True )
1061+ with torch .no_grad ():
1062+ jit_res = jit_model (a1 , a2 , a3 )
1063+ ori_res = model (a1 , a2 , a3 )
1064+ self .assertEqual (jit_res , ori_res )
10601065
10611066 a1 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float )
10621067 a2 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float )
10631068 a3 = torch .randn (1 , 32 , 13 , 24 , 33 , dtype = torch .float )
1064- self ._test_concat_bn_relu (a1 , a2 , a3 , enable_3d = True , use_channels_last = False )
1069+ with torch .no_grad ():
1070+ jit_res = jit_model (a1 , a2 , a3 )
1071+ ori_res = model (a1 , a2 , a3 )
1072+ self .assertEqual (jit_res , ori_res )
10651073
10661074 def test_mha_scores_calculation (self ):
10671075 def _check_match_mha (trace_model , mat1 , mat2 , bias , node = "ipex::mha_scores_calc" ):
0 commit comments