
Commit 0cc56c9

ooples and claude committed
feat: add comprehensive fusion context operation support in hlir to llir lowering
Extends LowerNodeToOpWithFusionContext to support the full range of operations commonly used in modern neural network fusion patterns:

- Pooling operations (MaxPool2D, AvgPool2D, GlobalAveragePooling)
  - Enables Conv+Pool and Conv+BN+ReLU+Pool fusion patterns
  - Preserves kernel size, stride, and padding parameters
- Memory operations (Reshape, Transpose, Flatten)
  - Enables attention mechanism fusions requiring reshaping
  - Supports zero-copy view operations within fusions
- Reduction operations (ReduceSum, Mean, ReduceMax, ReduceMin)
  - Enables softmax normalization fusions in attention
  - Preserves axes and keepDims parameters
- Dense/FullyConnected operations
  - Commonly fused with activation functions (MatMul+Bias+ReLU)
- Attention operations (Attention, MultiHeadAttention)
  - Enables transformer fusion patterns (LayerNorm+Attention)
  - Preserves numHeads, headDim, scale, and causal parameters
- Dropout as identity during inference
  - Emits an explicit identity op that can be optimized away

This brings fusion operation support to feature parity with the main LowerNode method, enabling industry-leading fusion patterns for CNNs and transformers.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 680b912 commit 0cc56c9

File tree

1 file changed

+297 -0 lines changed


src/InferenceOptimization/IR/Lowering/HLIRToLLIRLowering.cs

Lines changed: 297 additions & 0 deletions
@@ -767,6 +767,31 @@ private void LowerFusedNode(HLIRNode<T> node)
             OperationType.BatchNormalization or OperationType.LayerNormalization =>
                 CreateNormalizationOpForFusion(node, outputShape, outputDataType, inputIds),
 
+            // Pooling operations (for fused Conv+Pool patterns)
+            OperationType.MaxPool2D or OperationType.AvgPool2D or OperationType.GlobalAveragePooling =>
+                CreatePoolingOpForFusion(node, outputShape, outputDataType, inputIds),
+
+            // Memory/reshape operations (for fused attention + reshape patterns)
+            OperationType.Reshape or OperationType.Transpose or OperationType.Flatten =>
+                CreateMemoryOpForFusion(node, outputShape, outputDataType, inputIds),
+
+            // Reduction operations (for fused attention + reduction patterns)
+            OperationType.ReduceSum or OperationType.Mean or OperationType.ReduceMean or
+            OperationType.ReduceMax or OperationType.ReduceMin =>
+                CreateReductionOpForFusion(node, outputShape, outputDataType, inputIds),
+
+            // Dense/fully-connected operations (commonly fused with activations)
+            OperationType.Dense or OperationType.FullyConnected =>
+                CreateMatMulOp(node, outputShape, outputDataType, inputIds),
+
+            // Attention operations (for transformer fusions)
+            OperationType.Attention or OperationType.MultiHeadAttention =>
+                CreateAttentionOpForFusion(node, outputShape, outputDataType, inputIds),
+
+            // Dropout is identity during inference (no-op in fusion context)
+            OperationType.Dropout =>
+                CreateIdentityOpForFusion(node, outputShape, outputDataType, inputIds),
+
             // Unsupported operation in fused context
             _ => throw new InvalidOperationException(
                 $"Operation '{node.Operation}' is not supported within fused operations. " +
@@ -993,6 +1018,278 @@ private FusedOp CreateNormalizationOpForFusion(
         return fusedOp;
     }
 
+    /// <summary>
+    /// Creates a FusedOp representing a pooling operation for use within fused operation contexts.
+    /// </summary>
+    /// <param name="node">The HLIR node representing the pooling operation.</param>
+    /// <param name="outputShape">The output shape for this sub-operation.</param>
+    /// <param name="outputDataType">The output data type for this sub-operation.</param>
+    /// <param name="inputIds">The resolved input buffer IDs within the fusion context.</param>
+    /// <returns>A FusedOp containing pooling-specific parameters for LLIR execution.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method enables pooling operations (MaxPool2D, AvgPool2D, GlobalAveragePooling) to be
+    /// included in fused operation patterns such as Conv+Pool or Conv+BN+ReLU+Pool. The pooling
+    /// parameters (kernel size, stride, padding) are extracted from the HLIR node attributes.
+    /// </para>
+    /// <para>
+    /// Pooling operations have spatial window, stride, and padding parameters that distinguish
+    /// them from simple reductions. This method uses FusedOp to preserve these windowed operation
+    /// semantics for efficient runtime execution.
+    /// </para>
+    /// </remarks>
+    private FusedOp CreatePoolingOpForFusion(
+        HLIRNode<T> node,
+        int[] outputShape,
+        IRDataType outputDataType,
+        int[] inputIds)
+    {
+        var outputId = _llirGraph.AllocateBufferId();
+
+        var pattern = node.Operation switch
+        {
+            OperationType.MaxPool2D => "MaxPool2D",
+            OperationType.AvgPool2D => "AvgPool2D",
+            OperationType.GlobalAveragePooling => "GlobalAvgPool",
+            _ => "MaxPool2D"
+        };
+
+        // Extract pooling parameters from node attributes
+        var kernelH = GetAttributeInt(node, "kernelH", 2);
+        var kernelW = GetAttributeInt(node, "kernelW", 2);
+        var strideH = GetAttributeInt(node, "strideH", 2);
+        var strideW = GetAttributeInt(node, "strideW", 2);
+        var padH = GetAttributeInt(node, "padH", 0);
+        var padW = GetAttributeInt(node, "padW", 0);
+
+        var fusedOp = new FusedOp
+        {
+            OutputId = outputId,
+            Name = node.Name,
+            InputIds = inputIds,
+            OutputShape = outputShape,
+            OutputDataType = outputDataType,
+            Device = GetDeviceForNode(node),
+            FusionPattern = pattern,
+            SourceHLIRNodeId = node.Id
+        };
+
+        // Store pooling parameters in attributes for runtime execution
+        fusedOp.Attributes["kernelH"] = kernelH;
+        fusedOp.Attributes["kernelW"] = kernelW;
+        fusedOp.Attributes["strideH"] = strideH;
+        fusedOp.Attributes["strideW"] = strideW;
+        fusedOp.Attributes["padH"] = padH;
+        fusedOp.Attributes["padW"] = padW;
+
+        return fusedOp;
+    }
+
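
The pooling lowering leans on attribute helpers (GetAttributeInt and friends) defined in the file's Helpers region, which this diff does not show. A plausible minimal sketch of one such helper, assuming node.Attributes is a Dictionary<string, object>; this is a hypothetical reconstruction, not the repository's actual implementation:

    // Hypothetical sketch only -- the real helper lives in the
    // #region Helpers section of HLIRToLLIRLowering.cs, not in this diff.
    private static int GetAttributeInt(HLIRNode<T> node, string key, int defaultValue)
    {
        // Return the attribute when present and integral; otherwise the default.
        if (node.Attributes.TryGetValue(key, out var value))
        {
            return value switch
            {
                int i => i,
                long l => (int)l,
                _ => defaultValue
            };
        }
        return defaultValue;
    }

Note that the defaults chosen in the pooling path (2x2 kernel, stride 2, no padding) match the most common downsampling configuration, so a node carrying no explicit attributes still lowers to a sensible pool.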
+    /// <summary>
+    /// Creates a MemoryOp representing a memory/reshape operation for use within fused operation contexts.
+    /// </summary>
+    /// <param name="node">The HLIR node representing the memory operation.</param>
+    /// <param name="outputShape">The output shape for this sub-operation.</param>
+    /// <param name="outputDataType">The output data type for this sub-operation.</param>
+    /// <param name="inputIds">The resolved input buffer IDs within the fusion context.</param>
+    /// <returns>A MemoryOp containing reshape/transpose parameters for LLIR execution.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method enables memory operations (Reshape, Transpose, Flatten) to be included in
+    /// fused operation patterns such as attention mechanisms that require reshaping between
+    /// matrix multiplications. These operations are typically zero-copy or view operations.
+    /// </para>
+    /// </remarks>
+    private MemoryOp CreateMemoryOpForFusion(
+        HLIRNode<T> node,
+        int[] outputShape,
+        IRDataType outputDataType,
+        int[] inputIds)
+    {
+        var outputId = _llirGraph.AllocateBufferId();
+
+        var memOpType = node.Operation switch
+        {
+            OperationType.Reshape => MemoryOpType.Reshape,
+            OperationType.Transpose => MemoryOpType.Transpose,
+            OperationType.Flatten => MemoryOpType.Reshape,
+            _ => MemoryOpType.Copy
+        };
+
+        var op = new MemoryOp
+        {
+            OutputId = outputId,
+            Name = node.Name,
+            InputIds = inputIds,
+            OutputShape = outputShape,
+            OutputDataType = outputDataType,
+            Device = GetDeviceForNode(node),
+            MemoryOpType = memOpType,
+            SourceHLIRNodeId = node.Id
+        };
+
+        // Handle transpose permutation
+        if (node.Operation == OperationType.Transpose &&
+            node.Attributes.TryGetValue("perm", out var perm))
+        {
+            op.Permutation = GetAttributeIntArray(perm, "perm");
+        }
+
+        // Handle reshape new shape
+        if (node.Operation == OperationType.Reshape || node.Operation == OperationType.Flatten)
+        {
+            op.NewShape = outputShape;
+        }
+
+        return op;
+    }
+
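
To make the recorded transpose permutation concrete, here is a small standalone helper (not part of the diff) showing how a perm array reorders a shape; the example values reflect the usual attention head reshuffle:

    // Standalone illustration (not repository code): a transpose permutation
    // reorders dimensions, so the output shape is the input shape indexed
    // by perm. No data moves for a pure view/stride-based transpose.
    static int[] PermuteShape(int[] shape, int[] perm)
    {
        var result = new int[shape.Length];
        for (int i = 0; i < perm.Length; i++)
            result[i] = shape[perm[i]];
        return result;
    }

    // [batch, seq, heads, headDim] -> [batch, heads, seq, headDim]:
    // PermuteShape(new[] { 2, 128, 8, 64 }, new[] { 0, 2, 1, 3 })
    //   returns { 2, 8, 128, 64 }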
+    /// <summary>
+    /// Creates a ReduceOp representing a reduction operation for use within fused operation contexts.
+    /// </summary>
+    /// <param name="node">The HLIR node representing the reduction operation.</param>
+    /// <param name="outputShape">The output shape for this sub-operation.</param>
+    /// <param name="outputDataType">The output data type for this sub-operation.</param>
+    /// <param name="inputIds">The resolved input buffer IDs within the fusion context.</param>
+    /// <returns>A ReduceOp containing reduction parameters for LLIR execution.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method enables reduction operations (ReduceSum, Mean, ReduceMax, ReduceMin) to be
+    /// included in fused operation patterns such as attention softmax normalization or
+    /// mean pooling operations. The reduction axes and keepDims parameters are preserved.
+    /// </para>
+    /// </remarks>
+    private ReduceOp CreateReductionOpForFusion(
+        HLIRNode<T> node,
+        int[] outputShape,
+        IRDataType outputDataType,
+        int[] inputIds)
+    {
+        var outputId = _llirGraph.AllocateBufferId();
+
+        var reduceType = node.Operation switch
+        {
+            OperationType.ReduceSum => ReduceType.Sum,
+            OperationType.Mean or OperationType.ReduceMean => ReduceType.Mean,
+            OperationType.ReduceMax => ReduceType.Max,
+            OperationType.ReduceMin => ReduceType.Min,
+            _ => ReduceType.Sum
+        };
+
+        var axes = node.Attributes.TryGetValue("axes", out var ax) ? GetAttributeIntArray(ax, "axes") : Array.Empty<int>();
+        var keepDims = node.Attributes.TryGetValue("keepDims", out var kd) && (bool)kd;
+
+        // Get input shape for accurate cost estimation
+        var inputShape = node.InputTypes.Count > 0 && node.InputTypes[0].Shape != null
+            ? node.InputTypes[0].Shape
+            : Array.Empty<int>();
+
+        return new ReduceOp
+        {
+            OutputId = outputId,
+            Name = node.Name,
+            InputIds = inputIds,
+            OutputShape = outputShape,
+            OutputDataType = outputDataType,
+            Device = GetDeviceForNode(node),
+            ReduceType = reduceType,
+            Axes = axes,
+            KeepDims = keepDims,
+            InputShape = inputShape,
+            SourceHLIRNodeId = node.Id
+        };
+    }
+
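
To make the axes/keepDims semantics concrete, a standalone sketch (not repository code) of how they determine the output shape a ReduceOp produces:

    // Standalone illustration: compute the reduced output shape.
    static int[] ReducedShape(int[] inputShape, int[] axes, bool keepDims)
    {
        var axisSet = new HashSet<int>(axes);
        var result = new List<int>();
        for (int i = 0; i < inputShape.Length; i++)
        {
            if (!axisSet.Contains(i))
                result.Add(inputShape[i]);  // untouched dimension
            else if (keepDims)
                result.Add(1);              // reduced axis kept as size 1
            // else: reduced axis dropped entirely
        }
        return result.ToArray();
    }

    // Softmax denominator in attention: sum over the last axis with keepDims,
    // so the result broadcasts back against the scores.
    // ReducedShape(new[] { 2, 8, 128, 128 }, new[] { 3 }, keepDims: true)
    //   returns { 2, 8, 128, 1 }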
+    /// <summary>
+    /// Creates a FusedOp representing an attention operation for use within fused operation contexts.
+    /// </summary>
+    /// <param name="node">The HLIR node representing the attention operation.</param>
+    /// <param name="outputShape">The output shape for this sub-operation.</param>
+    /// <param name="outputDataType">The output data type for this sub-operation.</param>
+    /// <param name="inputIds">The resolved input buffer IDs within the fusion context.</param>
+    /// <returns>A FusedOp containing attention parameters for LLIR execution.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method enables attention operations (Attention, MultiHeadAttention) to be included
+    /// in fused transformer patterns such as LayerNorm+Attention or Attention+FFN fusions.
+    /// These operations benefit significantly from fusion to reduce memory bandwidth.
+    /// </para>
+    /// </remarks>
+    private FusedOp CreateAttentionOpForFusion(
+        HLIRNode<T> node,
+        int[] outputShape,
+        IRDataType outputDataType,
+        int[] inputIds)
+    {
+        var outputId = _llirGraph.AllocateBufferId();
+
+        var pattern = node.Operation == OperationType.MultiHeadAttention
+            ? "MultiHeadAttention"
+            : "Attention";
+
+        // Extract attention-specific parameters
+        var numHeads = GetAttributeInt(node, "numHeads", 8);
+        var headDim = GetAttributeInt(node, "headDim", 64);
+        var scale = GetAttributeDouble(node, "scale", 1.0 / Math.Sqrt(headDim));
+        var causal = GetAttributeBool(node, "causal", false);
+
+        var fusedOp = new FusedOp
+        {
+            OutputId = outputId,
+            Name = node.Name,
+            InputIds = inputIds,
+            OutputShape = outputShape,
+            OutputDataType = outputDataType,
+            Device = GetDeviceForNode(node),
+            FusionPattern = pattern,
+            SourceHLIRNodeId = node.Id
+        };
+
+        // Store attention parameters for runtime execution
+        fusedOp.Attributes["numHeads"] = numHeads;
+        fusedOp.Attributes["headDim"] = headDim;
+        fusedOp.Attributes["scale"] = scale;
+        fusedOp.Attributes["causal"] = causal;
+
+        return fusedOp;
+    }
+
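
The fallback scale here is the standard scaled-dot-product-attention factor 1/sqrt(headDim). A quick standalone check (not repository code):

    // With the default headDim of 64 used above:
    static double DefaultScale(int headDim) => 1.0 / Math.Sqrt(headDim);
    // DefaultScale(64) == 0.125, i.e. attention scores are divided by sqrt(64) = 8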
+    /// <summary>
+    /// Creates an identity ElementwiseOp for operations that are no-ops during inference.
+    /// </summary>
+    /// <param name="node">The HLIR node representing the no-op operation.</param>
+    /// <param name="outputShape">The output shape for this sub-operation.</param>
+    /// <param name="outputDataType">The output data type for this sub-operation.</param>
+    /// <param name="inputIds">The resolved input buffer IDs within the fusion context.</param>
+    /// <returns>An ElementwiseOp configured as an identity operation.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method handles operations like Dropout that become identity operations during
+    /// inference. Rather than special-casing these in the fusion executor, we emit an
+    /// explicit identity op that can be optimized away during execution planning.
+    /// </para>
+    /// </remarks>
+    private ElementwiseOp CreateIdentityOpForFusion(
+        HLIRNode<T> node,
+        int[] outputShape,
+        IRDataType outputDataType,
+        int[] inputIds)
+    {
+        var outputId = _llirGraph.AllocateBufferId();
+
+        return new ElementwiseOp
+        {
+            OutputId = outputId,
+            Name = node.Name,
+            InputIds = inputIds,
+            OutputShape = outputShape,
+            OutputDataType = outputDataType,
+            Device = GetDeviceForNode(node),
+            ElementwiseType = ElementwiseOpType.Identity,
+            SourceHLIRNodeId = node.Id
+        };
+    }
+
     #endregion
 
     #region Helpers
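
The remarks above note that the explicit identity can be optimized away during execution planning. A hypothetical sketch of such a pass follows; the LLIROp base type with mutable InputIds/OutputId is an assumption inferred from the diff, and this is not the repository's actual planner:

    // Hypothetical elision pass, illustrative only.
    static void ElideIdentityOps(List<LLIROp> ops)
    {
        // Map each identity op's output buffer to its sole input buffer.
        var redirect = new Dictionary<int, int>();
        foreach (var op in ops)
            if (op is ElementwiseOp e && e.ElementwiseType == ElementwiseOpType.Identity)
                redirect[e.OutputId] = e.InputIds[0];

        // Rewire every consumer, chasing chains of stacked identities.
        foreach (var op in ops)
            for (int i = 0; i < op.InputIds.Length; i++)
                while (redirect.TryGetValue(op.InputIds[i], out var src))
                    op.InputIds[i] = src;

        // The identity ops are now dead and can be dropped.
        ops.RemoveAll(op => op is ElementwiseOp e &&
                            e.ElementwiseType == ElementwiseOpType.Identity);
    }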
