Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,28 @@ def run(cls, job: "Job"):
progress=1,
)

# Mid-bootstrap cancel guard. ``collect_images`` above can run for many
# minutes on large collections (S3 list + DB joins), and the user may
# cancel during that window. ``Job.cancel()`` for ASYNC_API does
# ``revoke(terminate=False)`` to avoid SIGKILL'ing this worker, then
# writes REVOKED + tears down the NATS stream / Redis state. Without
# this check we would (a) clobber the cancel's REVOKED via the next
# full ``job.save()`` and (b) proceed to ``queue_images_to_nats``,
# recreating the stream the cancel just deleted and dispatching real
# GPU work to ADC for a revoked job. Refresh is read-only against the
# ``status`` column; the in-memory ``progress`` mutations from the
# collect stage are intentionally dropped on the bail path because the
# job is settled — no further progress writes make sense. Covers
# ASYNC_API (NATS dispatch) and SYNC paths (Celery sub-tasks in
# ``process_images``); INTERNAL jobs benefit too. See
# RolnickLab/antenna#1323.
db_status = Job.objects.values_list("status", flat=True).get(pk=job.pk)
if db_status in JobState.final_states() or db_status == JobState.CANCELING:
job.logger.info(
f"Job {job.pk} settled to {db_status} during bootstrap; " f"skipping dispatch of {len(images)} images"
)
return

# End image collection stage
job.save()

Expand Down Expand Up @@ -1041,6 +1063,21 @@ def setup(self, save=True):
if save:
self.save()

def is_settled(self) -> bool:
    """Report whether this job must not be started or resumed.

    A job counts as settled when its ``status`` is one of
    :meth:`JobState.final_states` or is ``JobState.CANCELING``.

    This is the shared predicate behind the ``run_job`` early-guard, the
    ``task_prerun`` handler (so a redelivered or cancelled message does not
    get its status reset to PENDING) and ``_fail_job``. ``MLJob.run`` applies
    the same test inline to a status value it reads straight from the
    database, so keep the two in sync when a new "do not resume" status is
    introduced.

    Returns:
        ``True`` if the job is in a final state or is being cancelled,
        ``False`` otherwise.
    """
    if self.status in JobState.final_states():
        return True
    return self.status == JobState.CANCELING

def run(self):
"""
Run the job.
Expand All @@ -1067,25 +1104,39 @@ def retry(self, async_task=True):

def cancel(self):
"""
Cancel a job. For async_api jobs, clean up NATS/Redis resources
and transition through CANCELING → REVOKED. For other jobs,
revoke the Celery task.
Cancel a job.

For ASYNC_API jobs the long-running work is on remote ADC workers via
NATS, not in the local ``run_job`` celery task — by the time the user
clicks cancel, ``run_job`` has usually already finished
``queue_images_to_nats`` and returned. Tearing down the NATS stream +
Redis state (``cleanup_async_job_if_needed``) is what actually stops
further work: ADC stops being delivered tasks, and any in-flight
result handlers see no Redis state and fast-fail. Calling
``revoke(terminate=True)`` on the (likely-done) run_job would SIGTERM
the worker child if it happens to still be inside the bootstrap (e.g.
a slow ``filter_processed_images`` for a huge collection), which
prior to ``acks_late`` was an unrecoverable message loss. We revoke
without terminate so a not-yet-started copy is dropped without
killing in-flight bootstrap; an in-flight copy instead notices the
CANCELING / REVOKED status at the mid-bootstrap check in ``MLJob.run``
and returns before dispatching any work, while a copy that is invoked
later (e.g. on redelivery) is stopped by the early-guard in ``run_job``
and bails out cleanly.

For INTERNAL / SYNC_API jobs the celery task body owns the entire
job lifecycle, so terminating it remains the only way to stop
active work.
"""
self.status = JobState.CANCELING
self.save()

is_async_api = self.dispatch_mode == JobDispatchMode.ASYNC_API
if self.task_id:
task = run_job.AsyncResult(self.task_id)
if task:
task.revoke(terminate=True)
if self.dispatch_mode == JobDispatchMode.ASYNC_API:
# For async jobs we need to set the status to revoked here since the task already
# finished (it only queues the images).
self.status = JobState.REVOKED
self.save()
else:
self.status = JobState.REVOKED
self.save()
task.revoke(terminate=not is_async_api)

self.status = JobState.REVOKED
Comment on lines 1134 to +1138
self.save()

cleanup_async_job_if_needed(self)

Expand Down
84 changes: 68 additions & 16 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,19 @@ def update_async_services_seen_for_project(project_id: int) -> None:
)


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
# acks_late + reject_on_worker_lost so a worker SIGKILL/OOM mid-task does not
# silently drop the job: the broker holds the message until the task body
# either completes successfully or raises, and redelivers if the worker dies.
# Pairs with the early-guard below — a redelivered run_job that finds the job
# already in a terminal state (or mid-cancellation) returns cleanly instead of
# re-running side effects. See RolnickLab/antenna#1323.
@celery_app.task(
bind=True,
soft_time_limit=default_soft_time_limit,
time_limit=default_time_limit,
acks_late=True,
reject_on_worker_lost=True,
)
def run_job(self, job_id: int) -> None:
from ami.jobs.models import Job

Expand All @@ -145,21 +157,34 @@ def run_job(self, job_id: int) -> None:
except Job.DoesNotExist as e:
raise e
# self.retry(exc=e, countdown=1, max_retries=1)

# Early-guard: under acks_late, the broker may redeliver this message after a
# worker SIGKILL/OOM, and Job.cancel() may also flip status to CANCELING /
# REVOKED while the message sits in the prefetch buffer. Don't re-run a job
# that's already settled or being torn down. The companion guard in
# pre_update_job_status (defined later in this module) prevents the
# task_prerun signal from overwriting that status with PENDING before we get
# here.
if job.is_settled():
job.logger.info(
f"Skipping run_job for job {job.pk}: already in status {job.status} "
f"(redelivery or cancellation in flight)"
)
return
Comment on lines +161 to +172
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Entry-only cancel guard is race-prone for ASYNC_API jobs.

Lines 165-170 guard only before job.run(). If cancel happens after that check, the task can still reach async dispatch and enqueue work under a canceled job because ASYNC_API cancel no longer terminates the worker process. Add a second DB refresh/status check immediately before async dispatch (e.g., right before queue_images_to_nats) and abort when status is CANCELING/terminal.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@ami/jobs/tasks.py` around lines 161 - 170, The pre-run guard using job.status
/ JobState.final_states() is insufficient for ASYNC_API jobs because
cancellation may occur after the initial check but before dispatch; to fix, add
a second status refresh and guard immediately before the async dispatch call
(right before queue_images_to_nats) by reloading the Job from the DB (e.g., call
the model refresh/get by PK) and aborting the task (return) if the reloaded
job.status is JobState.CANCELING or in JobState.final_states(), logging a
similar skip message; ensure you reference the same job PK/logger and perform
this check right before queue_images_to_nats to avoid enqueuing work for
canceled jobs.


job.logger.info(f"Running job {job}")
try:
job.run()
except Exception as e:
job.logger.error(f'Job #{job.pk} "{job.name}" failed: {e}')
raise
else:
job.logger.info(f"Running job {job}")
try:
job.run()
except Exception as e:
job.logger.error(f'Job #{job.pk} "{job.name}" failed: {e}')
raise
else:
from ami.jobs.models import JobDispatchMode
from ami.jobs.models import JobDispatchMode

job.refresh_from_db()
if job.dispatch_mode == JobDispatchMode.ASYNC_API and not job.progress.is_complete():
_log_worker_availability(job)
else:
job.logger.info(f"Finished job {job}")
job.refresh_from_db()
if job.dispatch_mode == JobDispatchMode.ASYNC_API and not job.progress.is_complete():
_log_worker_availability(job)
else:
job.logger.info(f"Finished job {job}")


def _log_worker_availability(job) -> None:
Expand Down Expand Up @@ -421,7 +446,7 @@ def _fail_job(job_id: int, reason: str) -> None:
try:
with transaction.atomic():
job = Job.objects.select_for_update().get(pk=job_id)
if job.status in (JobState.CANCELING, *JobState.final_states()):
if job.is_settled():
return
job.update_status(JobState.FAILURE, save=False)
job.finished_at = datetime.datetime.now()
Expand Down Expand Up @@ -1304,7 +1329,34 @@ def cleanup_async_job_if_needed(job) -> None:

@task_prerun.connect(sender=run_job)
def pre_update_job_status(sender, task_id, task, **kwargs):
    """Set the job status to PENDING in the prerun signal, unless it is settled.

    The job id is taken from the task request's ``job_id`` keyword argument,
    falling back to the first positional argument. If that job exists and
    ``is_settled()`` reports it as terminal or being cancelled, the PENDING
    write is skipped and only an info line is logged. This keeps a broker
    redelivery (acks_late + worker crash), or a cancel that landed while the
    message sat in the prefetch buffer, from having its REVOKED/CANCELING
    status overwritten, which would defeat the early-guard in ``run_job``.
    See RolnickLab/antenna#1323.

    When no job id can be determined, or the job row does not exist, the
    status update is still delegated to ``update_job_status``.
    """
    from ami.jobs.models import Job

    request = task.request

    # Resolve the job id: keyword argument first, then first positional.
    job_id = None
    if request.kwargs:
        job_id = request.kwargs.get("job_id")
    if job_id is None and request.args:
        job_id = request.args[0]

    # Load only the status column; a missing row is not an error here.
    job = None
    if job_id is not None:
        try:
            job = Job.objects.only("status").get(pk=job_id)
        except Job.DoesNotExist:
            job = None

    if job is not None and job.is_settled():
        logger.info(
            "task_prerun: skipping PENDING write for job %s in status %s "
            "(redelivery or cancel in flight)",
            job_id,
            job.status,
        )
        return

    update_job_status(sender, task_id, task, "PENDING", **kwargs)


Expand Down
134 changes: 130 additions & 4 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,136 @@ def test_run_job_unauthenticated(self):
# Accept either 401 (TokenAuthentication) or 403 (SessionAuthentication with AnonymousUser)
self.assertIn(resp.status_code, [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN])

def test_cancel_job(self):
    """Placeholder for a job-cancellation test; intentionally a no-op."""
    # This cannot be tested until we have a way to cancel jobs
    # and a way to run async tasks in tests.
    pass
def test_cancel_async_api_job_does_not_terminate_celery_task(self):
    """Cancelling an ASYNC_API job revokes its Celery task with ``terminate=False``.

    For this dispatch mode the actual work is done by remote ADC workers fed
    through NATS, so terminating the (likely already finished) local
    ``run_job`` bootstrap would not stop them, and SIGTERM'ing a child that
    is still bootstrapping risks losing the message. Cleaning up the
    NATS/Redis state is what stops further work, so the test checks that
    ``cancel()`` revokes without terminating, still invokes the async
    cleanup, and leaves the row in ``REVOKED``.
    """
    from unittest.mock import MagicMock, patch

    fake_task_id = "fake-async-task-id"
    job = Job.objects.create(
        project=self.project,
        name="Cancel async_api",
        task_id=fake_task_id,
        status=JobState.STARTED,
        dispatch_mode=JobDispatchMode.ASYNC_API,
    )
    async_result = MagicMock()

    with patch("ami.jobs.models.run_job") as run_job_stub:
        with patch("ami.jobs.models.cleanup_async_job_if_needed") as cleanup_stub:
            run_job_stub.AsyncResult.return_value = async_result
            job.cancel()

    # The stubs keep their call records after the patches are undone.
    run_job_stub.AsyncResult.assert_called_once_with(fake_task_id)
    async_result.revoke.assert_called_once_with(terminate=False)
    cleanup_stub.assert_called_once_with(job)

    job.refresh_from_db()
    self.assertEqual(job.status, JobState.REVOKED)

def test_cancel_sync_api_job_terminates_celery_task(self):
    """Cancelling a non-ASYNC_API job still revokes with ``terminate=True``.

    Exercised here with a SYNC_API job. For SYNC_API / INTERNAL jobs the
    Celery task body owns the whole job lifecycle, so terminating the task
    is the only way to stop work that is already running. The test also
    checks that the async cleanup hook is invoked and that the row ends up
    in ``REVOKED``.
    """
    from unittest.mock import MagicMock, patch

    job = Job.objects.create(
        project=self.project,
        name="Cancel sync_api",
        task_id="fake-sync-task-id",
        status=JobState.STARTED,
        dispatch_mode=JobDispatchMode.SYNC_API,
    )
    async_result = MagicMock()

    with patch("ami.jobs.models.run_job") as run_job_stub:
        with patch("ami.jobs.models.cleanup_async_job_if_needed") as cleanup_stub:
            run_job_stub.AsyncResult.return_value = async_result
            job.cancel()

    # The stubs keep their call records after the patches are undone.
    async_result.revoke.assert_called_once_with(terminate=True)
    cleanup_stub.assert_called_once_with(job)

    job.refresh_from_db()
    self.assertEqual(job.status, JobState.REVOKED)

def test_mljob_run_bails_when_cancelled_during_bootstrap(self):
    """``MLJob.run`` must not dispatch images for a job revoked mid-bootstrap.

    Regression test for the cancel-during-bootstrap race in ASYNC_API jobs.
    ``Job.cancel()`` no longer terminates the local worker for this dispatch
    mode, so a worker still inside ``MLJob.run`` (typically in the slow
    ``collect_images`` step) has to re-read the row and return BEFORE
    ``queue_images_to_nats`` is called. Otherwise it would recreate the
    stream and dispatch real GPU work to ADC for a revoked job.

    The race is simulated by having the patched ``collect_images`` flip the
    database row to ``REVOKED`` before returning an empty image list.
    """
    from unittest.mock import patch

    from ami.jobs.models import MLJob

    pipeline = Pipeline.objects.create(name="Cancel-race pipeline", slug="cancel-race-pipeline")
    pipeline.projects.add(self.project)
    collection = SourceImageCollection.objects.create(name="Cancel-race collection", project=self.project)
    job = Job.objects.create(
        project=self.project,
        name="Cancel-race",
        pipeline=pipeline,
        source_image_collection=collection,
        status=JobState.STARTED,
        dispatch_mode=JobDispatchMode.ASYNC_API,
    )
    job.setup()

    def revoke_during_collect(*_args, **_kwargs):
        # Stand-in for the user cancelling while collect_images is running:
        # the row changes underneath the in-memory ``job`` instance.
        Job.objects.filter(pk=job.pk).update(status=JobState.REVOKED)
        return []

    collect_patch = patch.object(pipeline, "collect_images", side_effect=revoke_during_collect)
    queue_patch = patch("ami.ml.orchestration.jobs.queue_images_to_nats")

    with collect_patch, queue_patch as queue_spy:
        MLJob.run(job)

    queue_spy.assert_not_called()
    job.refresh_from_db()
    self.assertEqual(job.status, JobState.REVOKED)

def test_cancel_job_without_task_id_still_revokes(self):
    """Cancelling a job with an empty ``task_id`` still ends in ``REVOKED``.

    Covers a job that never made it to enqueue: no ``AsyncResult`` lookup
    may happen, but the async-cleanup hook is still called (a no-op for
    non-ASYNC_API jobs) and the row is persisted as ``REVOKED``.
    """
    from unittest.mock import patch

    job = Job.objects.create(
        project=self.project,
        name="Cancel never-enqueued",
        task_id="",
        status=JobState.PENDING,
        dispatch_mode=JobDispatchMode.INTERNAL,
    )

    with patch("ami.jobs.models.run_job") as run_job_stub:
        with patch("ami.jobs.models.cleanup_async_job_if_needed") as cleanup_stub:
            job.cancel()

    # The stubs keep their call records after the patches are undone.
    run_job_stub.AsyncResult.assert_not_called()
    cleanup_stub.assert_called_once_with(job)

    job.refresh_from_db()
    self.assertEqual(job.status, JobState.REVOKED)

def test_list_jobs_with_ids_only(self):
"""Test the ids_only parameter returns only job IDs."""
Expand Down
Loading
Loading