Skip to content
6 changes: 6 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance();

switch (level) {
case TransformerLevel::Default: {
if (!session_options.free_dimension_overrides.empty()) {
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
session_options.free_dimension_overrides));
}
} break;
case TransformerLevel::Level1: {
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
// so run them first so there is potentially less for the more intensive optimizations like ConstantFolding,
Expand Down
65 changes: 41 additions & 24 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(ensure_unique_dq_for_node_unit, *session_logger_, graph));
}

// apply execution provider independent level 1 graph optimizations.
// apply execution provider independent level 0 and 1 graph optimizations.
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Default, *session_logger_));
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_));

// if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them.
Expand Down Expand Up @@ -3645,34 +3646,50 @@ common::Status InferenceSession::AddPredefinedTransformers(
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
const logging::Logger& logger) const {
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
for (int i = static_cast<int>(TransformerLevel::Default); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
TransformerLevel level = static_cast<TransformerLevel>(i);
if (graph_optimization_level >= level) {
// Generate and register transformers for level
auto transformers_to_register = [&]() {
const bool use_full_build_optimizations =
level == TransformerLevel::Level1 ||
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;

if (use_full_build_optimizations) {
std::function<onnxruntime::InlinedVector<std::unique_ptr<GraphTransformer>>()> transformers_to_register;

// Enable free dimension override even when the graph optimization level is 0.
// If the optimization level is above 0, the override will be applied during level 1 optimization.
if (level == TransformerLevel::Default) {
if (graph_optimization_level == TransformerLevel::Default) {
transformers_to_register = [&]() {
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger,
optimizers_to_disable_,
GetIntraOpThreadPoolToUse());
} else {
const auto sat_context =
minimal_build_optimization_handling ==
MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations
? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{
record_runtime_optimization_produced_op_schema_fn}}
: SatApplyContextVariant{SatDirectApplicationContext{}};
return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
logger,
optimizers_to_disable_,
GetIntraOpThreadPoolToUse());
}
}();
};
}
} else {
if (graph_optimization_level >= level) {
// Generate and register transformers for level
transformers_to_register = [&]() {
const bool use_full_build_optimizations =
level == TransformerLevel::Level1 ||
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;

if (use_full_build_optimizations) {
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger,
optimizers_to_disable_,
GetIntraOpThreadPoolToUse());
} else {
const auto sat_context =
minimal_build_optimization_handling ==
MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations
? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{
record_runtime_optimization_produced_op_schema_fn}}
: SatApplyContextVariant{SatDirectApplicationContext{}};
return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
logger,
optimizers_to_disable_,
GetIntraOpThreadPoolToUse());
}
};
}
}

for (auto& entry : transformers_to_register) {
if (transformers_to_register) { // Ensure the lambda is initialized before invoking it
for (auto& entry : transformers_to_register()) {
ORT_RETURN_IF_ERROR(transformer_manager.Register(std::move(entry), level));
}
}
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/test/optimizer/free_dimension_override_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {

void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
void TestFreeDimensions(FreeDimensionOverrideType overrideType, TransformerLevel level) {
auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx");

std::shared_ptr<Model> model;
Expand All @@ -43,9 +43,9 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
auto graph_transformer = std::make_unique<FreeDimensionOverrideTransformer>(overrides);

onnxruntime::GraphTransformerManager graph_transformation_mgr(5);
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), level));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1,
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level,
DefaultLoggingManager().DefaultLogger()));

// Verify that the shape of the input graph has the correct values
Expand Down Expand Up @@ -73,8 +73,10 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
}

// Runs TestFreeDimensions for both override kinds (Denotation and Name), with the transformer
// registered and applied at TransformerLevel::Level1 and again at TransformerLevel::Default.
TEST(FreeDimensionOverrideDenotationTransformerTest, Test) {
// NOTE(review): the two single-argument calls below look like the pre-change lines of the diff
// (TestFreeDimensions is also shown taking a TransformerLevel) -- confirm which calls are live.
TestFreeDimensions(FreeDimensionOverrideType::Denotation);
TestFreeDimensions(FreeDimensionOverrideType::Name);
// Same transformer exercised at Level1 ...
TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Level1);
TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Level1);
// ... and at the Default level.
TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Default);
TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Default);
}

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,50 @@ void CheckNhwcTransformerIsApplied(const PathString& ort_model_path,
graph_op_counts_checker,
graph_checker));
};

#if !defined(ORT_MINIMAL_BUILD)
// If level 0 optimization is enabled the free dimension override should still be applied.
//
// Builds SessionOptions with `level` as the graph optimization level and two free dimension
// overrides of kind `overrideType` (values 1 and 42), loads and initializes a session for
// `model_path`, and verifies that the first two dimensions of the single graph input carry
// those concrete values.
//
// model_path:   model to load; the checks below expect one 3-D input whose first two dimensions
//               are denoted DATA_BATCH and DATA_CHANNEL.
// level:        value assigned to SessionOptions::graph_optimization_level.
// overrideType: whether the overrides are keyed by denotation ("DATA_BATCH"/"DATA_CHANNEL")
//               or by dimension name ("Dim1"/"Dim2").
void CheckFreeDimensionOverrideIsApplied(const PathString& model_path,
                                         TransformerLevel level,
                                         FreeDimensionOverrideType overrideType) {
  SessionOptions so{};
  so.graph_optimization_level = level;
  if (overrideType == FreeDimensionOverrideType::Denotation) {
    so.free_dimension_overrides.push_back(
        onnxruntime::FreeDimensionOverride{"DATA_BATCH", overrideType, 1});
    so.free_dimension_overrides.push_back(
        onnxruntime::FreeDimensionOverride{"DATA_CHANNEL", overrideType, 42});
  } else {
    so.free_dimension_overrides.push_back(
        onnxruntime::FreeDimensionOverride{"Dim1", overrideType, 1});
    so.free_dimension_overrides.push_back(
        onnxruntime::FreeDimensionOverride{"Dim2", overrideType, 42});
  }

  GraphCheckerFn graph_checker = [](const Graph& graph) {
    // Verify that the shape of the input graph has the correct values

    const auto& graph_inputs = graph.GetInputs();
    ASSERT_EQ(graph_inputs.size(), 1u);  // This model only has a single input ('x')

    // Guard against a missing shape so the test fails cleanly instead of dereferencing null.
    const auto* input_shape = graph_inputs[0]->Shape();
    ASSERT_NE(input_shape, nullptr);
    // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions
    ASSERT_EQ(input_shape->dim_size(), 3);

    // The denotations are checked for both override kinds; only the lookup key differs.
    ASSERT_EQ(input_shape->dim(0).denotation(), "DATA_BATCH");
    ASSERT_TRUE(input_shape->dim(0).has_dim_value());
    ASSERT_EQ(input_shape->dim(0).dim_value(), 1);

    ASSERT_EQ(input_shape->dim(1).denotation(), "DATA_CHANNEL");
    ASSERT_TRUE(input_shape->dim(1).has_dim_value());
    ASSERT_EQ(input_shape->dim(1).dim_value(), 42);
  };

  ASSERT_NO_FATAL_FAILURE(LoadAndInitializeSession(
      so, model_path,
      nullptr,
      graph_checker));
}
#endif // !defined(ORT_MINIMAL_BUILD)
} // namespace

TEST(GraphRuntimeOptimizationTest, QDQConv) {
Expand Down Expand Up @@ -374,8 +418,14 @@ TEST(GraphRuntimeOptimizationTest, TestNhwcTransformerDirectlyUpdatesQLinearConv
{"com.microsoft.QLinearConv", n}}));
});
}

#if !defined(ORT_MINIMAL_BUILD)
// Confirms that free dimension overrides, given either by denotation or by name, are applied
// when the session's graph optimization level is Default as well as when it is Level1.
TEST(GraphRuntimeOptimizationTest, TestFreeDimensionOverride) {
  const PathString model_file = ORT_TSTR("testdata/abs_free_dimensions.onnx");

  // Outer loop over levels, inner loop over override kinds: same call order as listing the
  // four combinations by hand.
  for (const TransformerLevel opt_level : {TransformerLevel::Default, TransformerLevel::Level1}) {
    for (const FreeDimensionOverrideType override_kind :
         {FreeDimensionOverrideType::Denotation, FreeDimensionOverrideType::Name}) {
      CheckFreeDimensionOverrideIsApplied(model_file, opt_level, override_kind);
    }
  }
}

TEST(GraphRuntimeOptimizationTest, TestOnlyApplyMinimalBuildOptimizations) {
// This test assumes that AttentionFusion is not included in the minimal build optimizations.
// Update it if that changes.
Expand Down
Loading