|
34 | 34 | EDGE_DIALECT_GRAPH_KEY,
|
35 | 35 | find_populated_event,
|
36 | 36 | gen_graphs_from_etrecord,
|
| 37 | + get_aot_debug_handle_to_op_name_mapping, |
37 | 38 | is_inference_output_equal,
|
38 | 39 | map_runtime_aot_intermediate_outputs,
|
39 | 40 | merge_overlapping_debug_handles,
|
| 41 | + NodeFilter, |
40 | 42 | TimeScale,
|
41 | 43 | )
|
42 | 44 |
|
@@ -364,6 +366,112 @@ class X:
|
364 | 366 | msg = str(cm.exception)
|
365 | 367 | self.assertIn("Cannot convert value of type", msg)
|
366 | 368 |
|
| 369 | + def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self): |
| 370 | + # Create a simple graph module with one node |
| 371 | + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) |
| 372 | + node = graph_module.graph.create_node( |
| 373 | + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" |
| 374 | + ) |
| 375 | + node.meta["debug_handle"] = 1 |
| 376 | + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) |
| 377 | + expected_result = {(1,): "op1"} |
| 378 | + self.assertEqual(debug_handle_to_op_name, expected_result) |
| 379 | + |
| 380 | + def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self): |
| 381 | + # Create a simple graph module with two nodes |
| 382 | + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) |
| 383 | + node1 = graph_module.graph.create_node( |
| 384 | + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" |
| 385 | + ) |
| 386 | + node1.meta["debug_handle"] = (1, 2) |
| 387 | + node2 = graph_module.graph.create_node( |
| 388 | + "call_function", target=torch.mul, args=(), kwargs={}, name="op2" |
| 389 | + ) |
| 390 | + node2.meta["debug_handle"] = 3 |
| 391 | + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) |
| 392 | + expected_result = { |
| 393 | + ( |
| 394 | + 1, |
| 395 | + 2, |
| 396 | + ): "op1", |
| 397 | + (3,): "op2", |
| 398 | + } |
| 399 | + self.assertEqual(debug_handle_to_op_name, expected_result) |
| 400 | + |
| 401 | + def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self): |
| 402 | + # Create a simple graph module with no nodes |
| 403 | + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) |
| 404 | + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) |
| 405 | + expected_result = {} |
| 406 | + self.assertEqual(debug_handle_to_op_name, expected_result) |
| 407 | + |
| 408 | + def test_node_filter_match(self): |
| 409 | + node_filter = NodeFilter( |
| 410 | + "debug_handle", "call_function", exclude_ops=["getitem"] |
| 411 | + ) |
| 412 | + |
| 413 | + # Create a mock node that matches the filter criteria |
| 414 | + mock_node = torch.fx.Node( |
| 415 | + graph=torch.fx.Graph(), |
| 416 | + name="mock_node", |
| 417 | + op="call_function", |
| 418 | + target=torch.nn.functional.relu, |
| 419 | + args=(), |
| 420 | + kwargs={}, |
| 421 | + ) |
| 422 | + mock_node.meta["debug_handle"] = (1, 2) |
| 423 | + # Test that the filter matches the mock node |
| 424 | + self.assertTrue(node_filter.matches(mock_node)) |
| 425 | + |
| 426 | + def test_node_filter_key_mismatch(self): |
| 427 | + node_filter = NodeFilter( |
| 428 | + "debug_handle", "call_function", exclude_ops=["getitem"] |
| 429 | + ) |
| 430 | + mock_node_metadata_key_mismatch = torch.fx.Node( |
| 431 | + graph=torch.fx.Graph(), |
| 432 | + name="mock_node_metadata_key_mismatch", |
| 433 | + op="call_function", |
| 434 | + target=torch.nn.functional.relu, |
| 435 | + args=(), |
| 436 | + kwargs={}, |
| 437 | + ) |
| 438 | + # Test that the filter doesn't match the mock node (meta doesn't have debug_handle key) |
| 439 | + self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch)) |
| 440 | + |
| 441 | + def test_node_filter_ops_mismatch(self): |
| 442 | + node_filter = NodeFilter( |
| 443 | + "debug_handle", "call_function", exclude_ops=["getitem"] |
| 444 | + ) |
| 445 | + |
| 446 | + mock_node_exclude_ops_mismatch = torch.fx.Node( |
| 447 | + graph=torch.fx.Graph(), |
| 448 | + name="getitem", |
| 449 | + op="call_function", |
| 450 | + target=torch.nn.functional.relu, |
| 451 | + args=(), |
| 452 | + kwargs={}, |
| 453 | + ) |
| 454 | + mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2) |
| 455 | + # Test that the filter doesn't match the mock node (exclude_ops mismatch) |
| 456 | + self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch)) |
| 457 | + |
| 458 | + def test_node_op_type_mismatch(self): |
| 459 | + node_filter = NodeFilter( |
| 460 | + "debug_handle", "call_function", exclude_ops=["getitem"] |
| 461 | + ) |
| 462 | + |
| 463 | + mock_node_op_type_mismatch = torch.fx.Node( |
| 464 | + graph=torch.fx.Graph(), |
| 465 | + name="mock_node_op_type_mismatch", |
| 466 | + op="get_attr", |
| 467 | + target="torch.nn.functional.relu", |
| 468 | + args=(), |
| 469 | + kwargs={}, |
| 470 | + ) |
| 471 | + mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2) |
| 472 | + # Test that the filter doesn't match the mock node (op_type mismatch) |
| 473 | + self.assertFalse(node_filter.matches(mock_node_op_type_mismatch)) |
| 474 | + |
367 | 475 |
|
368 | 476 | def gen_mock_operator_graph_with_expected_map() -> (
|
369 | 477 | Tuple[OperatorGraph, Dict[int, OperatorNode]]
|
|
0 commit comments