@@ -149,6 +149,16 @@ def __init__(self, in_channels, out_channels, **kwargs):
149149
150150 def forward (self , x ):
151151 return F .relu (self .linear (x ), inplace = True )
152+
class LinearGelu(nn.Module):
    """A Linear layer followed by GELU.

    Weight initialization is made reproducible by fixing the RNG seed in the
    constructor, so every instance starts from identical parameters.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(LinearGelu, self).__init__()
        # Fixed seed (2018) so test runs see deterministic weights.
        torch.manual_seed(2018)
        self.linear = nn.Linear(in_channels, out_channels, **kwargs)

    def forward(self, x):
        out = self.linear(x)
        return F.gelu(out)
152162
153163class ConvSumInDiffBlock (nn .Module ):
154164 def __init__ (self , dim , in_channels , out_channels , ** kwargs ):
@@ -544,6 +554,27 @@ def test_output_linear_relu(self):
544554 kind_in_graph = "ipex::linear_relu" )
545555
546556
557+ def test_output_linear_gelu (self ):
558+ self ._test_output (
559+ LinearGelu (3 , 32 , bias = True ),
560+ torch .rand (32 , 3 ),
561+ kind_in_graph = "ipex::linear_gelu" )
562+ self ._test_output_bf16 (
563+ LinearGelu (3 , 32 , bias = True ),
564+ torch .rand (32 , 3 ),
565+ kind_in_graph = "ipex::linear_gelu" ,
566+ prec = 5e-3 )
567+ self ._test_output (
568+ LinearGelu (3 , 32 , bias = False ),
569+ torch .rand (32 , 3 ),
570+ kind_in_graph = "ipex::linear_gelu" )
571+ self ._test_output_bf16 (
572+ LinearGelu (3 , 32 , bias = False ),
573+ torch .rand (32 , 3 ),
574+ kind_in_graph = "ipex::linear_gelu" ,
575+ prec = 5e-3 )
576+
577+
547578 def test_channel_shuffle (self ):
548579 self ._test_output (
549580 ChannelShuffle (10 , 16 , 50 , 50 , 4 ),
0 commit comments