diff --git a/truss/cli/train/deploy_checkpoints.py b/truss/cli/train/deploy_checkpoints.py index aa2996899..4cda489f8 100644 --- a/truss/cli/train/deploy_checkpoints.py +++ b/truss/cli/train/deploy_checkpoints.py @@ -284,6 +284,10 @@ def _get_compute(compute: Optional[Compute]) -> Compute: if not compute: compute = Compute(cpu_count=0, memory="0Mi") compute.accelerator = _get_accelerator_if_specified(compute.accelerator) + # User did not specify an accelerator, so we default to CPU. + if not compute.accelerator: + compute.cpu_count = int(truss_config.DEFAULT_CPU) + compute.memory = truss_config.DEFAULT_MEMORY return compute diff --git a/truss/tests/cli/train/test_deploy_checkpoints.py b/truss/tests/cli/train/test_deploy_checkpoints.py index 35ef802bc..bcaa0a6ab 100644 --- a/truss/tests/cli/train/test_deploy_checkpoints.py +++ b/truss/tests/cli/train/test_deploy_checkpoints.py @@ -57,6 +57,13 @@ def deploy_checkpoints_mock_select(create_mock_prompt): yield mock +@pytest.fixture +def deploy_checkpoints_mock_select_cpu(create_mock_prompt): + with patch("truss.cli.train.deploy_checkpoints.inquirer.select") as mock: + mock.side_effect = lambda message, **kwargs: create_mock_prompt(None) + yield mock + + @pytest.fixture def deploy_checkpoints_mock_text(create_mock_prompt): with patch("truss.cli.train.deploy_checkpoints.inquirer.text") as mock: @@ -164,6 +171,28 @@ def test_prepare_checkpoint_deploy_empty_config( ) +def test_prepare_checkpoint_deploy_empty_config_cpu( + mock_remote, + deploy_checkpoints_mock_select_cpu, + deploy_checkpoints_mock_text, + deploy_checkpoints_mock_checkbox, +): + # Create empty config + empty_config = definitions.DeployCheckpointsConfig() + result = prepare_checkpoint_deploy( + remote_provider=mock_remote, + checkpoint_deploy_config=empty_config, + project_id="project123", + job_id="job123", + ) + assert isinstance(result, PrepareCheckpointResult) + assert result.checkpoint_deploy_config.compute.cpu_count == int( + truss_config.DEFAULT_CPU + ) + assert result.checkpoint_deploy_config.compute.memory == truss_config.DEFAULT_MEMORY + assert result.checkpoint_deploy_config.compute.accelerator is None + + def test_prepare_checkpoint_deploy_complete_config( mock_remote, deploy_checkpoints_mock_select,