Implement todos tensorboard #20874

Merged: 6 commits, Jun 11, 2025
Changes from all commits
tests/tests_fabric/loggers/test_tensorboard.py (41 changes: 32 additions & 9 deletions)
@@ -147,29 +147,52 @@ def test_tensorboard_log_hparams_and_metrics(tmp_path):


@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
-def test_tensorboard_log_graph(tmp_path, example_input_array):
-    """Test that log graph works with both model.example_input_array and if array is passed externally."""
-    # TODO(fabric): Test both nn.Module and LightningModule
-    # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
+def test_tensorboard_log_graph_plain_module(tmp_path, example_input_array):
model = BoringModel()
if example_input_array is not None:
model.example_input_array = None

logger = TensorBoardLogger(tmp_path)
logger._experiment = Mock()

logger.log_graph(model, example_input_array)
if example_input_array is not None:
logger.experiment.add_graph.assert_called_with(model, example_input_array)
else:
logger.experiment.add_graph.assert_not_called()

logger._experiment.reset_mock()

# model wrapped in `FabricModule`
wrapped = _FabricModule(model, strategy=Mock())
logger.log_graph(wrapped, example_input_array)
if example_input_array is not None:
logger.experiment.add_graph.assert_called_with(model, example_input_array)
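
# Annotation (not part of the PR diff): the test above covers the removed TODO
# "Test both nn.Module and LightningModule" for the plain-module case, checking that
# log_graph() hands a bare model straight to add_graph() and unwraps a _FabricModule
# first, so add_graph() always receives the underlying model.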


@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph_with_batch_transfer_hooks(tmp_path, example_input_array):
model = pytest.importorskip("lightning.pytorch.demos.boring_classes").BoringModel()
logger = TensorBoardLogger(tmp_path)
logger._experiment = Mock()

with (
mock.patch.object(model, "_on_before_batch_transfer", return_value=example_input_array) as before_mock,
mock.patch.object(model, "_apply_batch_transfer_handler", return_value=example_input_array) as transfer_mock,
):
logger.log_graph(model, example_input_array)
logger._experiment.reset_mock()

wrapped = _FabricModule(model, strategy=Mock())
logger.log_graph(wrapped, example_input_array)

if example_input_array is not None:
assert before_mock.call_count == 2
assert transfer_mock.call_count == 2
logger.experiment.add_graph.assert_called_with(model, example_input_array)
else:
before_mock.assert_not_called()
transfer_mock.assert_not_called()
logger.experiment.add_graph.assert_not_called()
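
# Annotation (not part of the PR diff): the test above implements the removed TODO
# "Assert _apply_batch_transfer_handler is calling the batch transfer hooks": with an
# example input available, _on_before_batch_transfer and _apply_batch_transfer_handler
# each run once per log_graph() call (twice here: once for the bare model, once for the
# _FabricModule-wrapped one); with no input, neither hook runs and add_graph() is never called.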


@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason="tensorboard is required")
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
"""Test that log graph throws warning if model.example_input_array is None."""
model = BoringModel()
Expand Down
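
For context, the following is a minimal, illustrative sketch of the log_graph() control flow that the new tests pin down. It is inferred from the assertions in the diff rather than copied from the logger's source; the log_graph_sketch name and the _forward_module attribute access are assumptions made for illustration.

from typing import Optional
from unittest.mock import Mock

import torch
from lightning.fabric.wrappers import _FabricModule


def log_graph_sketch(experiment, model, input_array: Optional[torch.Tensor] = None) -> None:
    # Unwrap a Fabric-wrapped model so the underlying module is the one that gets traced
    # (test_tensorboard_log_graph_plain_module asserts add_graph receives the unwrapped model).
    if isinstance(model, _FabricModule):
        model = model._forward_module  # assumption: the wrapper keeps the module here

    # Fall back to model.example_input_array when no input is passed explicitly.
    if input_array is None:
        input_array = getattr(model, "example_input_array", None)

    # With no example input at all, nothing is logged
    # (both tests assert add_graph is not called in this case).
    if input_array is None:
        return

    # LightningModule path: route the example input through the batch-transfer hooks
    # before tracing (test_tensorboard_log_graph_with_batch_transfer_hooks asserts each
    # hook runs once per log_graph call).
    if hasattr(model, "_apply_batch_transfer_handler"):
        input_array = model._on_before_batch_transfer(input_array)
        input_array = model._apply_batch_transfer_handler(input_array)

    experiment.add_graph(model, input_array)


# Quick self-check of the sketch with a plain nn.Module, mirroring how the tests use Mock.
experiment = Mock()
model = torch.nn.Linear(32, 2)
log_graph_sketch(experiment, model, torch.rand(2, 32))
experiment.add_graph.assert_called_once()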