Skip to content

Commit c8e6c8a

Browse files
mihowclaude
andcommitted
fix(jobs): close prerun-clobber + orphan-NATS-dispatch races
Review on #1324 surfaced two races that left the early-guard non-functional in production: 1. ``task_prerun`` (``pre_update_job_status``) wrote PENDING to the row before the ``run_job`` body inspected status. A canceled or redelivered message therefore had its REVOKED/CANCELING overwritten with PENDING, and the early-guard added in the parent commit never tripped. The existing tests passed only because they invoked ``run_job.apply(args=[…])`` while production uses ``kwargs={"job_id": …}`` — under args, the prerun handler raised ``KeyError`` and exited silently. Switching the tests to ``kwargs=`` reproduces the production code path; the prerun handler now short-circuits when ``Job.is_settled()`` is true, preserving the status the early-guard reads next. 2. For ASYNC_API jobs ``Job.cancel()`` revokes without ``terminate=True``, marks the row REVOKED, and tears down the NATS stream + Redis state. ``MLJob.run`` running in a worker that's still inside ``collect_images`` (slow for large collections) would then proceed to ``queue_images_to_nats`` and recreate the stream the cancel just deleted, dispatching real GPU work to ADC for a revoked job; the results came back to no Redis state and ``_fail_job`` silently overwrote REVOKED with FAILURE. The bootstrap now checks ``Job.status`` (via a values-only read so the in-memory ``progress`` mutations don't clobber the cancel's REVOKED) right after the collect stage and bails out before any dispatch. Adds ``Job.is_settled()`` to centralize the "terminal or being torn down" predicate that ``run_job``'s early-guard, the prerun handler, ``_fail_job``, and the bootstrap guard all needed. Adds two regression tests: one for the prerun-then-guard chain, one for the cancel-during-bootstrap race. Co-Authored-By: Claude <noreply@anthropic.com>
1 parent f772b71 commit c8e6c8a

4 files changed

Lines changed: 138 additions & 9 deletions

File tree

ami/jobs/models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,28 @@ def run(cls, job: "Job"):
529529
progress=1,
530530
)
531531

532+
# Mid-bootstrap cancel guard. ``collect_images`` above can run for many
533+
# minutes on large collections (S3 list + DB joins), and the user may
534+
# cancel during that window. ``Job.cancel()`` for ASYNC_API does
535+
# ``revoke(terminate=False)`` to avoid SIGKILL'ing this worker, then
536+
# writes REVOKED + tears down the NATS stream / Redis state. Without
537+
# this check we would (a) clobber the cancel's REVOKED via the next
538+
# full ``job.save()`` and (b) proceed to ``queue_images_to_nats``,
539+
# recreating the stream the cancel just deleted and dispatching real
540+
# GPU work to ADC for a revoked job. Refresh is read-only against the
541+
# ``status`` column; the in-memory ``progress`` mutations from the
542+
# collect stage are intentionally dropped on the bail path because the
543+
# job is settled — no further progress writes make sense. Covers
544+
# ASYNC_API (NATS dispatch) and SYNC paths (Celery sub-tasks in
545+
# ``process_images``); INTERNAL jobs benefit too. See
546+
# RolnickLab/antenna#1323.
547+
db_status = Job.objects.values_list("status", flat=True).get(pk=job.pk)
548+
if db_status in JobState.final_states() or db_status == JobState.CANCELING:
549+
job.logger.info(
550+
f"Job {job.pk} settled to {db_status} during bootstrap; " f"skipping dispatch of {len(images)} images"
551+
)
552+
return
553+
532554
# End image collection stage
533555
job.save()
534556

@@ -1041,6 +1063,21 @@ def setup(self, save=True):
10411063
if save:
10421064
self.save()
10431065

1066+
def is_settled(self) -> bool:
1067+
"""Return True when the job is in a terminal state or being cancelled.
1068+
1069+
Used by every code path that must not start (or continue) work for a
1070+
job whose lifecycle has effectively ended: the ``run_job`` early-guard
1071+
(after acks_late redelivery), the ``MLJob.run`` mid-bootstrap cancel
1072+
check (before ``queue_images_to_nats`` dispatches GPU work to ADC),
1073+
the prerun signal handler (so a redelivered or canceled message does
1074+
not get its status reset to PENDING), and ``_fail_job``. Centralized
1075+
so the predicate stays in one place — adding a new "do not resume"
1076+
status only requires touching :meth:`JobState.final_states` or this
1077+
method.
1078+
"""
1079+
return self.status in JobState.final_states() or self.status == JobState.CANCELING
1080+
10441081
def run(self):
10451082
"""
10461083
Run the job.

ami/jobs/tasks.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def update_async_services_seen_for_project(project_id: int) -> None:
150150
reject_on_worker_lost=True,
151151
)
152152
def run_job(self, job_id: int) -> None:
153-
from ami.jobs.models import Job, JobState
153+
from ami.jobs.models import Job
154154

155155
try:
156156
job = Job.objects.get(pk=job_id)
@@ -161,8 +161,10 @@ def run_job(self, job_id: int) -> None:
161161
# Early-guard: under acks_late, the broker may redeliver this message after a
162162
# worker SIGKILL/OOM, and Job.cancel() may also flip status to CANCELING /
163163
# REVOKED while the message sits in the prefetch buffer. Don't re-run a job
164-
# that's already settled or being torn down.
165-
if job.status in JobState.final_states() or job.status == JobState.CANCELING:
164+
# that's already settled or being torn down. The companion guard in
165+
# pre_update_job_status above prevents the task_prerun signal from
166+
# overwriting that status with PENDING before we get here.
167+
if job.is_settled():
166168
job.logger.info(
167169
f"Skipping run_job for job {job.pk}: already in status {job.status} "
168170
f"(redelivery or cancellation in flight)"
@@ -444,7 +446,7 @@ def _fail_job(job_id: int, reason: str) -> None:
444446
try:
445447
with transaction.atomic():
446448
job = Job.objects.select_for_update().get(pk=job_id)
447-
if job.status in (JobState.CANCELING, *JobState.final_states()):
449+
if job.is_settled():
448450
return
449451
job.update_status(JobState.FAILURE, save=False)
450452
job.finished_at = datetime.datetime.now()
@@ -1327,7 +1329,34 @@ def cleanup_async_job_if_needed(job) -> None:
13271329

13281330
@task_prerun.connect(sender=run_job)
13291331
def pre_update_job_status(sender, task_id, task, **kwargs):
1330-
# in the prerun signal, set the job status to PENDING
1332+
"""Bump the job to PENDING when a worker picks the message up.
1333+
1334+
Skipped when the job is already settled (terminal state) or being
1335+
cancelled. Without that guard, a broker redelivery (acks_late + worker
1336+
crash) or a cancel that arrived while the message was still in the
1337+
prefetch buffer would have its REVOKED/CANCELING status silently
1338+
overwritten with PENDING here, and the ``run_job`` early-guard
1339+
(which reads ``Job.status`` after this signal fires) would then fail
1340+
to short-circuit and re-run the job. See RolnickLab/antenna#1323.
1341+
"""
1342+
from ami.jobs.models import Job
1343+
1344+
job_id = task.request.kwargs.get("job_id") if task.request.kwargs else None
1345+
if job_id is None and task.request.args:
1346+
job_id = task.request.args[0]
1347+
if job_id is not None:
1348+
try:
1349+
job = Job.objects.only("status").get(pk=job_id)
1350+
except Job.DoesNotExist:
1351+
pass
1352+
else:
1353+
if job.is_settled():
1354+
logger.info(
1355+
"task_prerun: skipping PENDING write for job %s in status %s " "(redelivery or cancel in flight)",
1356+
job_id,
1357+
job.status,
1358+
)
1359+
return
13311360
update_job_status(sender, task_id, task, "PENDING", **kwargs)
13321361

13331362

ami/jobs/tests/test_jobs.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,50 @@ def test_cancel_sync_api_job_terminates_celery_task(self):
439439
job.refresh_from_db()
440440
self.assertEqual(job.status, JobState.REVOKED)
441441

442+
def test_mljob_run_bails_when_cancelled_during_bootstrap(self):
443+
"""Regression for the cancel-during-bootstrap race in ASYNC_API jobs.
444+
445+
``Job.cancel()`` revokes without terminate=True (so the local worker is
446+
not SIGTERM'd), marks the row REVOKED, and tears down NATS/Redis state.
447+
If the worker is still inside ``MLJob.run`` (typically blocked in the
448+
slow ``collect_images`` step), it must refresh the row and bail BEFORE
449+
calling ``queue_images_to_nats`` — otherwise it would recreate the
450+
stream and dispatch real GPU work to ADC for a revoked job.
451+
"""
452+
from unittest.mock import patch
453+
454+
from ami.jobs.models import MLJob
455+
456+
pipeline = Pipeline.objects.create(name="Cancel-race pipeline", slug="cancel-race-pipeline")
457+
pipeline.projects.add(self.project)
458+
collection = SourceImageCollection.objects.create(name="Cancel-race collection", project=self.project)
459+
job = Job.objects.create(
460+
project=self.project,
461+
name="Cancel-race",
462+
pipeline=pipeline,
463+
source_image_collection=collection,
464+
status=JobState.STARTED,
465+
dispatch_mode=JobDispatchMode.ASYNC_API,
466+
)
467+
job.setup()
468+
469+
def cancel_mid_collect(*_args, **_kwargs):
470+
# Simulate the user clicking cancel while collect_images is still
471+
# running: rewrite the DB row out from under this in-flight task.
472+
Job.objects.filter(pk=job.pk).update(status=JobState.REVOKED)
473+
return []
474+
475+
with patch.object(
476+
pipeline,
477+
"collect_images",
478+
side_effect=cancel_mid_collect,
479+
), patch("ami.ml.orchestration.jobs.queue_images_to_nats") as mock_queue:
480+
MLJob.run(job)
481+
482+
mock_queue.assert_not_called()
483+
job.refresh_from_db()
484+
self.assertEqual(job.status, JobState.REVOKED)
485+
442486
def test_cancel_job_without_task_id_still_revokes(self):
443487
"""A job that never made it to enqueue (no task_id) still transitions
444488
to REVOKED and triggers async-cleanup (a no-op for non-ASYNC_API)."""

ami/jobs/tests/test_tasks.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ def test_skips_when_job_already_revoked(self):
667667
job = self._make_job(JobState.REVOKED)
668668

669669
with patch.object(Job, "run") as mock_run:
670-
result = run_job.apply(args=[job.pk])
670+
result = run_job.apply(kwargs={"job_id": job.pk})
671671

672672
self.assertTrue(result.successful(), msg=f"task should succeed, got {result.state}: {result.traceback}")
673673
mock_run.assert_not_called()
@@ -679,7 +679,7 @@ def test_skips_when_job_canceling(self):
679679
job = self._make_job(JobState.CANCELING)
680680

681681
with patch.object(Job, "run") as mock_run:
682-
result = run_job.apply(args=[job.pk])
682+
result = run_job.apply(kwargs={"job_id": job.pk})
683683

684684
self.assertTrue(result.successful(), msg=f"task should succeed, got {result.state}: {result.traceback}")
685685
mock_run.assert_not_called()
@@ -691,7 +691,7 @@ def test_skips_when_job_already_success(self):
691691
job = self._make_job(JobState.SUCCESS)
692692

693693
with patch.object(Job, "run") as mock_run:
694-
result = run_job.apply(args=[job.pk])
694+
result = run_job.apply(kwargs={"job_id": job.pk})
695695

696696
self.assertTrue(result.successful(), msg=f"task should succeed, got {result.state}: {result.traceback}")
697697
mock_run.assert_not_called()
@@ -703,10 +703,29 @@ def test_runs_when_job_pending(self):
703703
job = self._make_job(JobState.PENDING)
704704

705705
with patch.object(Job, "run") as mock_run:
706-
run_job.apply(args=[job.pk])
706+
run_job.apply(kwargs={"job_id": job.pk})
707707

708708
mock_run.assert_called_once()
709709

710+
def test_prerun_signal_does_not_clobber_revoked_status(self):
711+
"""
712+
Regression: the ``task_prerun`` signal would otherwise call
713+
``update_job_status(state="PENDING")`` and overwrite a REVOKED/CANCELING
714+
status before the ``run_job`` early-guard reads it. With the prerun
715+
guard in place, the status survives the signal, the early-guard fires,
716+
and ``Job.run()`` is not called.
717+
"""
718+
from ami.jobs.tasks import run_job
719+
720+
job = self._make_job(JobState.REVOKED)
721+
722+
with patch.object(Job, "run") as mock_run:
723+
run_job.apply(kwargs={"job_id": job.pk})
724+
725+
job.refresh_from_db()
726+
self.assertEqual(job.status, JobState.REVOKED)
727+
mock_run.assert_not_called()
728+
710729

711730
class TestResultEndpointWithError(APITestCase):
712731
"""Integration test for the result API endpoint with error results."""

0 commit comments

Comments
 (0)