Skip to content

Commit 01d826a

Browse files
chilo-msSanket Kale
authored andcommitted
Enable free dimension override for graph optimization level 0 (microsoft#25425)
Enable free dimension override even when the graph optimization level is 0. If the optimization level is above 0, the override will be applied during level 1 optimization.
1 parent 85e8fb7 commit 01d826a

File tree

4 files changed

+105
-30
lines changed

4 files changed

+105
-30
lines changed

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
220220
AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance();
221221

222222
switch (level) {
223+
case TransformerLevel::Default: {
224+
if (!session_options.free_dimension_overrides.empty()) {
225+
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
226+
session_options.free_dimension_overrides));
227+
}
228+
} break;
223229
case TransformerLevel::Level1: {
224230
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
225231
// so run them first so there is potentially less for the more intensive optimizations like ConstantFolding,

onnxruntime/core/session/inference_session.cc

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
13491349
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(ensure_unique_dq_for_node_unit, *session_logger_, graph));
13501350
}
13511351

1352-
// apply execution provider independent level 1 graph optimizations.
1352+
// apply execution provider independent level 0 and 1 graph optimizations.
1353+
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Default, *session_logger_));
13531354
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_));
13541355

13551356
// if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them.
@@ -3645,34 +3646,50 @@ common::Status InferenceSession::AddPredefinedTransformers(
36453646
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
36463647
const logging::Logger& logger) const {
36473648
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
3648-
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
3649+
for (int i = static_cast<int>(TransformerLevel::Default); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
36493650
TransformerLevel level = static_cast<TransformerLevel>(i);
3650-
if (graph_optimization_level >= level) {
3651-
// Generate and register transformers for level
3652-
auto transformers_to_register = [&]() {
3653-
const bool use_full_build_optimizations =
3654-
level == TransformerLevel::Level1 ||
3655-
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
3656-
3657-
if (use_full_build_optimizations) {
3651+
std::function<onnxruntime::InlinedVector<std::unique_ptr<GraphTransformer>>()> transformers_to_register;
3652+
3653+
// Enable free dimension override even when the graph optimization level is 0.
3654+
// If the optimization level is above 0, the override will be applied during level 1 optimization.
3655+
if (level == TransformerLevel::Default) {
3656+
if (graph_optimization_level == TransformerLevel::Default) {
3657+
transformers_to_register = [&]() {
36583658
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger,
36593659
optimizers_to_disable_,
36603660
GetIntraOpThreadPoolToUse());
3661-
} else {
3662-
const auto sat_context =
3663-
minimal_build_optimization_handling ==
3664-
MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations
3665-
? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{
3666-
record_runtime_optimization_produced_op_schema_fn}}
3667-
: SatApplyContextVariant{SatDirectApplicationContext{}};
3668-
return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
3669-
logger,
3670-
optimizers_to_disable_,
3671-
GetIntraOpThreadPoolToUse());
3672-
}
3673-
}();
3661+
};
3662+
}
3663+
} else {
3664+
if (graph_optimization_level >= level) {
3665+
// Generate and register transformers for level
3666+
transformers_to_register = [&]() {
3667+
const bool use_full_build_optimizations =
3668+
level == TransformerLevel::Level1 ||
3669+
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
3670+
3671+
if (use_full_build_optimizations) {
3672+
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger,
3673+
optimizers_to_disable_,
3674+
GetIntraOpThreadPoolToUse());
3675+
} else {
3676+
const auto sat_context =
3677+
minimal_build_optimization_handling ==
3678+
MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations
3679+
? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{
3680+
record_runtime_optimization_produced_op_schema_fn}}
3681+
: SatApplyContextVariant{SatDirectApplicationContext{}};
3682+
return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
3683+
logger,
3684+
optimizers_to_disable_,
3685+
GetIntraOpThreadPoolToUse());
3686+
}
3687+
};
3688+
}
3689+
}
36743690

3675-
for (auto& entry : transformers_to_register) {
3691+
if (transformers_to_register) { // Ensure the lambda is initialized before invoking it
3692+
for (auto& entry : transformers_to_register()) {
36763693
ORT_RETURN_IF_ERROR(transformer_manager.Register(std::move(entry), level));
36773694
}
36783695
}

onnxruntime/test/optimizer/free_dimension_override_test.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using namespace ONNX_NAMESPACE;
1818
namespace onnxruntime {
1919
namespace test {
2020

21-
void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
21+
void TestFreeDimensions(FreeDimensionOverrideType overrideType, TransformerLevel level) {
2222
auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx");
2323

2424
std::shared_ptr<Model> model;
@@ -43,9 +43,9 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
4343
auto graph_transformer = std::make_unique<FreeDimensionOverrideTransformer>(overrides);
4444

4545
onnxruntime::GraphTransformerManager graph_transformation_mgr(5);
46-
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1));
46+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), level));
4747

48-
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1,
48+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level,
4949
DefaultLoggingManager().DefaultLogger()));
5050

5151
// Verify that the shape of the input graph has the correct values
@@ -73,8 +73,10 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
7373
}
7474

7575
TEST(FreeDimensionOverrideDenotationTransformerTest, Test) {
76-
TestFreeDimensions(FreeDimensionOverrideType::Denotation);
77-
TestFreeDimensions(FreeDimensionOverrideType::Name);
76+
TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Level1);
77+
TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Level1);
78+
TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Default);
79+
TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Default);
7880
}
7981

8082
} // namespace test

onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,50 @@ void CheckNhwcTransformerIsApplied(const PathString& ort_model_path,
280280
graph_op_counts_checker,
281281
graph_checker));
282282
};
283+
284+
#if !defined(ORT_MINIMAL_BUILD)
285+
// if level 0 optimization is enabled the free dimension override should be enabled.
286+
void CheckFreeDimensionOverrideIsApplied(const PathString& model_path,
287+
TransformerLevel level,
288+
FreeDimensionOverrideType overrideType) {
289+
SessionOptions so{};
290+
so.graph_optimization_level = level;
291+
if (overrideType == FreeDimensionOverrideType::Denotation) {
292+
so.free_dimension_overrides.push_back(
293+
onnxruntime::FreeDimensionOverride{"DATA_BATCH", overrideType, 1});
294+
so.free_dimension_overrides.push_back(
295+
onnxruntime::FreeDimensionOverride{"DATA_CHANNEL", overrideType, 42});
296+
} else {
297+
so.free_dimension_overrides.push_back(
298+
onnxruntime::FreeDimensionOverride{"Dim1", overrideType, 1});
299+
so.free_dimension_overrides.push_back(
300+
onnxruntime::FreeDimensionOverride{"Dim2", overrideType, 42});
301+
}
302+
303+
GraphCheckerFn graph_checker = [](const Graph& graph) {
304+
// Verify that the shape of the input graph has the correct values
305+
306+
const auto& graph_inputs = graph.GetInputs();
307+
ASSERT_TRUE(graph_inputs.size() == 1); // This model only has a single input ('x')
308+
309+
const auto* input_shape = graph_inputs[0]->Shape();
310+
ASSERT_TRUE(input_shape->dim_size() == 3); // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions
311+
312+
ASSERT_TRUE(input_shape->dim(0).denotation() == "DATA_BATCH");
313+
ASSERT_TRUE(input_shape->dim(0).has_dim_value());
314+
ASSERT_TRUE(input_shape->dim(0).dim_value() == 1);
315+
316+
ASSERT_TRUE(input_shape->dim(1).denotation() == "DATA_CHANNEL");
317+
ASSERT_TRUE(input_shape->dim(1).has_dim_value());
318+
ASSERT_TRUE(input_shape->dim(1).dim_value() == 42);
319+
};
320+
321+
ASSERT_NO_FATAL_FAILURE(LoadAndInitializeSession(
322+
so, model_path,
323+
nullptr,
324+
graph_checker));
325+
};
326+
#endif // !defined(ORT_MINIMAL_BUILD)
283327
} // namespace
284328

285329
TEST(GraphRuntimeOptimizationTest, QDQConv) {
@@ -374,8 +418,14 @@ TEST(GraphRuntimeOptimizationTest, TestNhwcTransformerDirectlyUpdatesQLinearConv
374418
{"com.microsoft.QLinearConv", n}}));
375419
});
376420
}
377-
378421
#if !defined(ORT_MINIMAL_BUILD)
422+
TEST(GraphRuntimeOptimizationTest, TestFreeDimensionOverride) {
423+
CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Default, FreeDimensionOverrideType::Denotation);
424+
CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Default, FreeDimensionOverrideType::Name);
425+
CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Level1, FreeDimensionOverrideType::Denotation);
426+
CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Level1, FreeDimensionOverrideType::Name);
427+
}
428+
379429
TEST(GraphRuntimeOptimizationTest, TestOnlyApplyMinimalBuildOptimizations) {
380430
// This test assumes that AttentionFusion is not included in the minimal build optimizations.
381431
// Update it if that changes.

0 commit comments

Comments
 (0)