-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtests.py
More file actions
4915 lines (4107 loc) · 221 KB
/
tests.py
File metadata and controls
4915 lines (4107 loc) · 221 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import datetime
import logging
import typing
from io import BytesIO
from django.contrib.auth.models import AnonymousUser
from django.core.files.uploadedfile import SimpleUploadedFile
from django.db import connection, models
from django.test import TestCase, override_settings
from guardian.shortcuts import assign_perm, get_perms, remove_perm
from PIL import Image
from rest_framework import status
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from rich import print
from ami.exports.models import DataExport
from ami.jobs.models import VALID_JOB_TYPES, Job
from ami.main.models import (
Classification,
Deployment,
Detection,
Device,
Event,
Identification,
Occurrence,
Project,
S3StorageSource,
Site,
SourceImage,
SourceImageCollection,
SourceImageUpload,
Tag,
TaxaList,
Taxon,
TaxonRank,
group_images_into_events,
)
from ami.ml.models.pipeline import Pipeline
from ami.ml.models.processing_service import ProcessingService
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project
from ami.tests.fixtures.storage import populate_bucket
from ami.users.models import User
from ami.users.roles import BasicMember, Identifier, MLDataManager, ProjectManager, create_roles_for_project
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TestProjectSetup(TestCase):
    """Checks on the objects that accompany a newly created ``Project``."""

    def test_project_creation(self):
        """Creating a project with ``create_defaults=True`` returns a ``Project``."""
        new_project = Project.objects.create(name="New Project with Defaults", create_defaults=True)
        self.assertIsInstance(new_project, Project)

    def test_default_related_models(self):
        """A project created with defaults owns at least one of each default related model."""
        new_project = Project.objects.create(name="New Project with Defaults", create_defaults=True)

        # (related manager, model class expected from ``.first()``), checked in this order:
        # deployment, site, device, source image collection.
        expectations = (
            (new_project.deployments, Deployment),
            (new_project.sites, Site),
            (new_project.devices, Device),
            (new_project.sourceimage_collections, SourceImageCollection),
        )
        for related_manager, model_class in expectations:
            self.assertGreaterEqual(related_manager.count(), 1)
            self.assertIsInstance(related_manager.first(), model_class)

    # Disabled for now (name does not start with ``test_``): needs a more complex setup.
    def no_test_default_permissions(self):
        """Placeholder for a default-permissions test; intentionally does nothing."""

    @override_settings(
        DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
        DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2009/",
    )
    def test_processing_service_if_configured(self):
        """The default processing service is returned when its name and endpoint settings are set."""
        from ami.ml.models.processing_service import get_or_create_default_processing_service

        target_project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False)
        default_service = get_or_create_default_processing_service(
            project=target_project,
            register_pipelines=False,
        )

        self.assertIsNotNone(
            default_service,
            "Default processing service should be created if environment variables are set.",
        )
        assert default_service is not None  # narrows the type for static checkers
        self.assertIsNotNone(default_service.endpoint_url)
        self.assertIsNotNone(default_service.name)
        self.assertGreaterEqual(target_project.processing_services.count(), 1)

    @override_settings(
        DEFAULT_PROCESSING_SERVICE_NAME=None,
        DEFAULT_PROCESSING_SERVICE_ENDPOINT=None,
    )
    def test_processing_service_if_not_configured(self):
        """No default processing service is returned when its name and endpoint settings are ``None``."""
        from ami.ml.models.processing_service import get_or_create_default_processing_service

        target_project = Project.objects.create(name="Test Project for Processing Service", create_defaults=False)
        default_service = get_or_create_default_processing_service(project=target_project)

        self.assertIsNone(
            default_service,
            "Default processing service should not be created if environment variables are not set.",
        )

    @override_settings(
        DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
        DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/",
        DEFAULT_PIPELINES_ENABLED=[],  # All pipelines DISABLED by default
    )
    def test_processing_service_with_disabled_pipelines(self):
        """
        With ``DEFAULT_PIPELINES_ENABLED`` set to an empty list, every pipeline
        config of a new project is disabled.
        """
        target_project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True)
        service = target_project.processing_services.first()
        assert service is not None

        # At least two pipelines are expected on the default service.
        self.assertGreaterEqual(service.pipelines.count(), 2)

        # None of the project's pipeline configs may be enabled.
        for pipeline_config in ProjectPipelineConfig.objects.filter(project=target_project):
            self.assertFalse(
                pipeline_config.enabled,
                f"Pipeline {pipeline_config.pipeline.name} should be disabled for project {target_project.name}.",
            )

    @override_settings(
        DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
        DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/",
        DEFAULT_PIPELINES_ENABLED=None,  # All pipelines ENABLED by default
    )
    def test_processing_service_with_enabled_pipelines(self):
        """
        With ``DEFAULT_PIPELINES_ENABLED`` set to ``None``, every pipeline
        config of a new project is enabled.
        """
        target_project = Project.objects.create(name="Test Project for Processing Service", create_defaults=True)
        service = target_project.processing_services.first()
        assert service is not None

        # At least two pipelines are expected on the default service.
        self.assertGreaterEqual(service.pipelines.count(), 2)

        # Every one of the project's pipeline configs must be enabled.
        for pipeline_config in ProjectPipelineConfig.objects.filter(project=target_project):
            self.assertTrue(
                pipeline_config.enabled,
                f"Pipeline {pipeline_config.pipeline.name} should be enabled for project {target_project.name}.",
            )

    @override_settings(
        DEFAULT_PROCESSING_SERVICE_NAME="Default Processing Service",
        DEFAULT_PROCESSING_SERVICE_ENDPOINT="http://ml_backend:2000/",  # should have at least two pipelines
        DEFAULT_PIPELINES_ENABLED=["constant"],
    )
    def test_existing_processing_service_new_project(self):
        """
        Two projects created in sequence share one processing service.

        All pipeline configs of the first project are switched on by hand; the
        second project must still end up with only the pipelines listed in
        ``DEFAULT_PIPELINES_ENABLED`` enabled.
        """
        enabled_slugs = ("constant",)

        first_project = Project.objects.create(name="Test Project One", create_defaults=True)
        # Switch on every pipeline for the first project.
        ProjectPipelineConfig.objects.filter(project=first_project).update(enabled=True)
        second_project = Project.objects.create(name="Test Project Two", create_defaults=True)

        first_service = first_project.processing_services.first()
        second_service = second_project.processing_services.first()
        assert first_service is not None
        assert second_service is not None

        # Both projects must point at the very same service.
        self.assertEqual(
            first_service,
            second_service,
            "Both projects should use the same processing service instance.",
        )

        # Only the configured slugs may be enabled on the second project.
        second_configs = ProjectPipelineConfig.objects.filter(project=second_project)
        self.assertGreaterEqual(second_configs.count(), 2, "Project should have at least two pipelines.")
        for pipeline_config in second_configs:
            if pipeline_config.pipeline.slug not in enabled_slugs:
                self.assertFalse(
                    pipeline_config.enabled,
                    f"Pipeline {pipeline_config.pipeline.name} should not be enabled "
                    f"for project {second_project.name}.",
                )
                continue
            self.assertTrue(
                pipeline_config.enabled,
                f"Pipeline {pipeline_config.pipeline.name} should be enabled for project {second_project.name}.",
            )
class TestImageGrouping(TestCase):
    """Tests for grouping captures into events, regrouping, pruning and image dimensions."""

    def setUp(self) -> None:
        """Create a fresh test project and deployment for each test."""
        # Log only the database name: the full ``settings_dict`` also holds the
        # connection credentials, which must not end up in test/CI output.
        logger.debug("Currently active database: %s", connection.settings_dict.get("NAME"))
        self.project, self.deployment = setup_test_project()
        return super().setUp()

    def test_grouping(self):
        """Captures spread over several nights are grouped into one event per night."""
        num_nights = 3
        images_per_night = 3

        create_captures(
            deployment=self.deployment,
            num_nights=num_nights,
            images_per_night=images_per_night,
            interval_minutes=10,
        )
        events = group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
        )

        assert len(events) == num_nights
        for event in events:
            assert event.captures.count() == images_per_night

    def _populate_continuous_captures(self, days: int = 3, interval_minutes: int = 10) -> int:
        """
        Create ``days`` of gap-free captures (no gap > ``interval_minutes``).

        Returns the number of ``SourceImage`` rows created.
        """
        import pathlib
        import uuid

        start = datetime.datetime(2023, 4, 24, 3, 22, 38)
        interval = datetime.timedelta(minutes=interval_minutes)
        count = int(datetime.timedelta(days=days) / interval)
        for i in range(count):
            SourceImage.objects.create(
                deployment=self.deployment,
                timestamp=start + i * interval,
                # Random prefix keeps paths unique across calls.
                path=pathlib.Path("test") / f"{uuid.uuid4().hex[:8]}_continuous_{i}.jpg",
            )
        return count

    def test_continuous_monitoring_capped_at_24_hours(self):
        """
        A deployment that captures images continuously (no gap > max_time_gap)
        should still be broken into daily events by the max_event_duration cap,
        not coalesced into one multi-day event.
        """
        self._populate_continuous_captures(days=3, interval_minutes=10)

        events = group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
            max_event_duration=datetime.timedelta(hours=24),
        )

        # 3 days × 24h / 10min = 432 captures; capped at 24h → exactly 3 events.
        # `== 3` (not `>= 3`) guards against over-splitting regressions too.
        assert len(events) == 3, f"expected exactly 3 daily events, got {len(events)}"
        for event in events:
            duration = event.end - event.start
            assert duration <= datetime.timedelta(hours=24), f"event {event.pk} spans {duration}, exceeds 24h cap"

    def test_regrouping_existing_long_event_refreshes_cached_fields(self):
        """
        Regression test for the regroup-existing-events path: a deployment
        already grouped into a single multi-day event should, after re-running
        grouping with the 24h cap, end up with no events exceeding 24h AND
        every reused event's cached start/end/captures_count must reflect its
        current captures (not its pre-regroup state).

        This is narrower than #904's refactor on purpose: it asserts the
        observable cap+refresh behavior without depending on the specific
        group_by reuse mechanics that #904 is expected to remove.
        """
        total_captures = self._populate_continuous_captures(days=3, interval_minutes=10)

        # First pass with the cap disabled → a single multi-day "mega-event".
        events_uncapped = group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
            max_event_duration=None,
        )
        assert len(events_uncapped) == 1
        mega_event = events_uncapped[0]
        assert (mega_event.end - mega_event.start) > datetime.timedelta(hours=24)

        # Second pass with the cap → must split the mega-event and refresh
        # cached fields on the reused event.
        group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
            max_event_duration=datetime.timedelta(hours=24),
        )

        all_events = Event.objects.filter(deployment=self.deployment)
        assert all_events.count() == 3, f"expected exactly 3 events after regroup, got {all_events.count()}"
        for event in all_events:
            duration = event.end - event.start
            assert duration <= datetime.timedelta(
                hours=24
            ), f"event {event.pk} spans {duration} after regroup; cached fields are stale"
            # Per-event cached-count check: catches reused events whose captures_count
            # was never refreshed after captures were reassigned away. A sum-only check
            # can miss this when two events' errors offset each other.
            actual_captures = SourceImage.objects.filter(event=event).count()
            assert event.captures_count == actual_captures, (
                f"event {event.pk} cached captures_count={event.captures_count} "
                f"does not match actual related count={actual_captures}; cached counters are stale"
            )

        # Orphan check: every capture must belong to some event.
        total_assigned = sum(e.captures_count for e in all_events)
        assert total_assigned == total_captures, (
            f"captures_count across events ({total_assigned}) does not match total captures ({total_captures}); "
            f"captures were orphaned during regroup"
        )

    def test_regrouping_realigns_occurrence_event_id(self):
        """
        Regression test for stale ``Occurrence.event_id`` after regroup.

        Occurrences are bound to an event once at creation time (from
        ``detection.source_image.event``). When the 24h cap runs against a
        deployment that already has detections + occurrences attached to a
        single mega-event, the source_images are reassigned but the
        occurrences' event_ids stay stuck at the mega-event unless we
        explicitly realign them. This test asserts the realignment plus the
        downstream ``occurrences_count`` consistency on the daily events.
        """
        self._populate_continuous_captures(days=3, interval_minutes=10)
        captures = list(SourceImage.objects.filter(deployment=self.deployment).order_by("timestamp"))

        # First pass with the cap disabled → one mega-event holding everything.
        group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
            max_event_duration=None,
        )
        mega_event = Event.objects.get(deployment=self.deployment)

        # One occurrence per day, picked at mid-day offsets (12h / 36h / 60h)
        # so each target sits well inside its event's window, far from the
        # exact 24h boundary where the cap's strict ``>`` semantics matter.
        # Index-based selection (e.g. ``captures[len // 3]``) lands at exactly
        # 24h offset, where Event 2 starts at 24h+10min (one capture past),
        # so two of the three targets would otherwise share an event.
        start_ts = captures[0].timestamp
        targets = [
            next(c for c in captures if c.timestamp >= start_ts + datetime.timedelta(hours=offset_hours))
            for offset_hours in (12, 36, 60)
        ]

        occurrences = []
        for capture in targets:
            detection = Detection.objects.create(
                source_image=capture,
                timestamp=capture.timestamp,
                bbox=[10, 10, 20, 20],
            )
            occurrence = Occurrence.objects.create(
                event=mega_event,
                deployment=self.deployment,
                project=self.project,
            )
            detection.occurrence = occurrence
            detection.save()
            occurrences.append(occurrence)

        # Sanity: all three occurrences point at the mega-event before regroup.
        for occurrence in occurrences:
            occurrence.refresh_from_db()
            assert occurrence.event_id == mega_event.pk

        # Second pass: 24h cap → 3 daily events, each occurrence must follow
        # its detection's source_image into the corresponding daily event.
        group_images_into_events(
            deployment=self.deployment,
            max_time_gap=datetime.timedelta(hours=2),
            max_event_duration=datetime.timedelta(hours=24),
        )

        for occurrence in occurrences:
            occurrence.refresh_from_db()
            first_detection = (
                Detection.objects.filter(occurrence=occurrence)
                .select_related("source_image")
                .order_by("source_image__timestamp")
                .first()
            )
            assert first_detection is not None
            expected_event_id = first_detection.source_image.event_id
            assert occurrence.event_id == expected_event_id, (
                f"occurrence {occurrence.pk}: stale event_id={occurrence.event_id} "
                f"(expected {expected_event_id} from first detection's source_image)"
            )

        # Realignment must move all three occurrences onto distinct daily
        # events. With targets at mid-day offsets, the three occurrences land
        # on three different events — one on each day.
        distinct_event_ids = {occ.event_id for occ in occurrences}
        assert (
            len(distinct_event_ids) == 3
        ), f"expected 3 distinct event_ids across occurrences, got {distinct_event_ids}"

        # Each daily event's cached ``occurrences_count`` must match the live
        # computation that ``update_calculated_fields`` itself uses (which
        # applies the project's default filters). Catches the case where
        # occurrences moved off an event but its cached counter wasn't
        # refreshed because the event wasn't tracked as touched.
        daily_events = Event.objects.filter(deployment=self.deployment)
        assert daily_events.count() == 3
        for event in daily_events:
            expected = event.get_occurrences_count()
            assert event.occurrences_count == expected, (
                f"event {event.pk} cached occurrences_count={event.occurrences_count} "
                f"!= live get_occurrences_count()={expected}; cached counter is stale"
            )

        # No occurrence should be left pointing at a deleted/missing event.
        assert Occurrence.objects.filter(deployment=self.deployment, event__isnull=True).count() == 0

    def test_pruning_empty_events(self):
        """Events whose captures have all been deleted are removed by ``delete_empty_events``."""
        from ami.main.models import delete_empty_events

        captures = create_captures(deployment=self.deployment)
        events = Event.objects.filter(captures__in=captures).distinct()

        # Iterating caches the queryset results, so the pks remain available
        # below even after the captures (and then the events) are deleted.
        for event in events:
            event.captures.all().delete()

        delete_empty_events(deployment=self.deployment)

        remaining_events = Event.objects.filter(pk__in=[event.pk for event in events])
        assert remaining_events.count() == 0

    def test_setting_image_dimensions(self):
        """Dimensions set on an event's first image are propagated to all of its captures."""
        from ami.main.models import set_dimensions_for_collection

        image_width, image_height = 100, 100

        captures = create_captures(deployment=self.deployment)
        events = Event.objects.filter(captures__in=captures).distinct()

        for event in events:
            first_image = event.captures.first()
            assert first_image is not None
            first_image.width, first_image.height = image_width, image_height
            first_image.save()
            set_dimensions_for_collection(event=event)

            for capture in event.captures.all():
                # print(capture.path, capture.width, capture.height)
                assert (capture.width == image_width) and (capture.height == image_height)
# This test is disabled because it requires certain data to be present in the database
# and data in a configured S3 bucket. Will require Minio or something like it to be running.
# from unittest import TestCase as UnitTestCase
# class TestExistingDatabase(UnitTestCase):
# def test_sync_source_images(self):
# from django.db import models
#
# from ami.main.models import Deployment
# from ami.tasks import sync_source_images
#
# deployment = Deployment.objects.get(
# name="Test",
# )
# sync_source_images(deployment.pk)
#
# # Get deployment with the most captures
# deployment = (
# Deployment.objects.annotate(captures_count=models.Count("captures")).order_by("-captures_count").first()
# )
# if deployment:
# sync_source_images(deployment.pk)
class TestEvents(TestCase):
    """Tests for the cached ("calculated") fields stored on events."""

    def setUp(self) -> None:
        """Build a project/deployment pair holding two nights of five captures each."""
        self.project, self.deployment = setup_test_project()
        create_captures(deployment=self.deployment, num_nights=2, images_per_night=5)
        return super().setUp()

    def test_event_calculated_fields(self):
        """Cached fields on a single event are filled in and follow added captures."""
        primary_event, donor_event = self.deployment.events.all()

        # Initial state of the cached fields.
        primary_event.update_calculated_fields(save=True)
        primary_event.refresh_from_db()
        self.assertEqual(primary_event.captures_count, 5)
        self.assertIsNotNone(primary_event.detections_count)
        self.assertIsNotNone(primary_event.occurrences_count)
        first_update_stamp = primary_event.calculated_fields_updated_at
        self.assertIsNotNone(first_update_stamp)

        # Move the other event's captures over, then recompute.
        for moved_capture in donor_event.captures.all():
            primary_event.captures.add(moved_capture)  # type: ignore
        primary_event.update_calculated_fields(save=True)
        primary_event.refresh_from_db()

        self.assertEqual(primary_event.captures_count, primary_event.get_captures_count())
        self.assertEqual(primary_event.captures_count, 10)
        self.assertGreater(primary_event.calculated_fields_updated_at, first_update_stamp)  # type: ignore

    def test_event_calculated_fields_batch(self):
        """The batch updater refreshes cached counts after detections and captures are deleted."""
        from ami.main.models import Detection, update_calculated_fields_for_events

        far_future = datetime.datetime(3000, 1, 1, 0, 0, 0)

        def assert_cached_counts_are_live(current_event):
            # Each cached counter must equal its freshly computed counterpart.
            self.assertEqual(current_event.captures_count, current_event.get_captures_count())
            self.assertEqual(current_event.detections_count, current_event.get_detections_count())
            self.assertEqual(current_event.occurrences_count, current_event.get_occurrences_count())

        # Baseline: counters are consistent and carry an update timestamp.
        baseline_stamps = []
        for current_event in self.deployment.events.all().order_by("pk"):
            assert_cached_counts_are_live(current_event)
            self.assertIsNotNone(current_event.calculated_fields_updated_at)
            baseline_stamps.append(current_event.calculated_fields_updated_at)

        # Remove every detection, run the batch update, and re-check.
        Detection.objects.all().delete()
        update_calculated_fields_for_events(last_updated=far_future)
        refreshed_events = self.deployment.events.all().order_by("pk")
        for current_event, baseline in zip(refreshed_events, baseline_stamps):
            assert_cached_counts_are_live(current_event)
            self.assertGreater(current_event.calculated_fields_updated_at, baseline)

        # Remove every capture, run the batch update, and re-check.
        self.deployment.captures.all().delete()
        update_calculated_fields_for_events(last_updated=far_future)
        refreshed_events = self.deployment.events.all().order_by("pk")
        for current_event, baseline in zip(refreshed_events, baseline_stamps):
            assert_cached_counts_are_live(current_event)
            self.assertGreater(current_event.calculated_fields_updated_at, baseline)  # type: ignore
class TestDuplicateFieldsOnChildren(TestCase):
    """Tests that the ``project`` stored on a deployment's children tracks the deployment."""

    def setUp(self) -> None:
        """Create two projects and a deployment in the first one with captures, taxa and one occurrence."""
        from ami.main.models import Deployment, Project

        first_project = Project.objects.create(name="Test Project One")
        second_project = Project.objects.create(name="Test Project Two")
        self.project_one = first_project
        self.project_two = second_project
        self.deployment = Deployment.objects.create(name="Test Deployment", project=first_project)

        create_captures(deployment=self.deployment)
        # Group events right away instead of asynchronously.
        self.deployment.save(regroup_async=False)

        create_taxa(project=first_project)
        create_taxa(project=second_project)
        create_occurrences(deployment=self.deployment, num=1)
        return super().setUp()

    def test_initial_project(self):
        """Everything hanging off the deployment starts out in the first project."""
        deployment = self.deployment
        expected_project = self.project_one

        assert deployment.project == expected_project
        assert deployment.captures.first().project == expected_project
        assert deployment.events.first().project == expected_project
        assert deployment.occurrences.first().project == expected_project
        assert deployment.occurrences.first().detections.first().source_image.project == expected_project

    def test_change_project(self):
        """Re-pointing the deployment at the second project carries its children along."""
        expected_project = self.project_two
        self.deployment.project = expected_project
        self.deployment.save()
        self.deployment.refresh_from_db()

        deployment = self.deployment
        assert deployment.project == expected_project
        assert deployment.captures.first().project == expected_project
        assert deployment.events.first().project == expected_project
        assert deployment.occurrences.first().project == expected_project

    def test_delete_project(self):
        """Deleting the project leaves the deployment and its children with no project."""
        self.project_one.delete()
        self.deployment.refresh_from_db()

        deployment = self.deployment
        assert deployment.project is None
        assert deployment.captures.first().project is None
        assert deployment.events.first().project is None
        assert deployment.occurrences.first().project is None
class TestSourceImageCollections(TestCase):
def setUp(self) -> None:
    """Create one project and one deployment with two nights of ten captures, one minute apart."""
    from ami.main.models import Deployment, Project

    project = Project.objects.create(name="Test Project One")
    self.project_one = project
    deployment = Deployment.objects.create(name="Test Deployment", project=project)
    self.deployment = deployment
    create_captures(
        deployment=deployment,
        num_nights=2,
        images_per_night=10,
        interval_minutes=1,
    )
    return super().setUp()
def test_random_sample(self):
    """A ``random`` collection of size N ends up holding exactly N images."""
    from ami.main.models import SourceImageCollection

    expected_size = 10
    sampled = SourceImageCollection.objects.create(
        name="Test Random Source Image Collection",
        project=self.project_one,
        method="random",
        kwargs=dict(size=expected_size),
    )
    sampled.save()
    sampled.populate_sample()

    image_total = sampled.images.count()
    assert image_total == expected_size
def test_manual_sample(self):
    """A ``manual`` collection contains exactly the images listed in ``image_ids``."""
    from ami.main.models import SourceImageCollection

    images = self.deployment.captures.all()
    collection = SourceImageCollection.objects.create(
        name="Test Manual Source Image Collection",
        project=self.project_one,
        method="manual",
        kwargs={"image_ids": [image.pk for image in images]},
    )
    collection.save()
    collection.populate_sample()

    assert collection.images.count() == len(images)
    # Fetch the collection's image ids once; testing ``image in collection.images.all()``
    # inside the loop would issue a new query for every image.
    collected_pks = set(collection.images.values_list("pk", flat=True))
    for image in images:
        assert image.pk in collected_pks
def test_interval_sample(self):
    """
    Interval sampling spaces the sampled images of each event ``minute_interval`` apart.

    Within each event, consecutive sampled images (in timestamp order) must be at
    least ``minute_interval`` minutes apart and less than ``minute_interval + 1``
    minutes apart. The upper bound depends on the test setUp creating captures
    with a 1 minute interval.
    """
    from ami.main.models import SourceImageCollection

    minute_interval = 10
    min_gap = datetime.timedelta(minutes=minute_interval)
    max_gap = datetime.timedelta(minutes=minute_interval + 1)

    collection = SourceImageCollection.objects.create(
        name="Test Interval Source Image Collection",
        project=self.project_one,
        method="interval",
        kwargs={"minute_interval": minute_interval},
    )
    collection.save()
    collection.populate_sample()

    events = collection.images.values_list("event", flat=True).distinct()
    for event in events:
        last_image = None
        # Order explicitly: the gap assertions compare chronological neighbours
        # and must not depend on whatever ordering the queryset defaults to.
        for image in collection.images.filter(event=event).order_by("timestamp"):
            if last_image is not None:
                interval = image.timestamp - last_image.timestamp
                assert interval >= min_gap
                assert interval < max_gap
            last_image = image
def test_interval_with_excluded_events(self):
    """Interval sampling with ``exclude_events`` picks no image from the excluded event."""
    from ami.main.models import SourceImageCollection

    minute_interval = 5
    skipped_event = self.deployment.events.all().first()
    assert skipped_event is not None

    sampling_kwargs = {
        "minute_interval": minute_interval,
        "exclude_events": [skipped_event.pk],
    }
    collection = SourceImageCollection.objects.create(
        name="Test Interval With Excluded Events",
        project=self.project_one,
        method="interval",
        kwargs=sampling_kwargs,
    )
    collection.save()
    collection.populate_sample()

    # Nothing that was sampled may belong to the skipped event.
    for sampled_image in collection.images.all():
        assert sampled_image.event != skipped_event
def test_extra_arguments(self):
    """Populating a collection whose kwargs the sampling method does not accept raises ``TypeError``."""
    from ami.main.models import SourceImageCollection

    unexpected_kwargs = {"birthday": True, "cake": "chocolate"}
    collection = SourceImageCollection.objects.create(
        name="Test Extra Arguments Collection",
        project=self.project_one,
        method="interval",
        kwargs=unexpected_kwargs,
    )
    collection.save()

    # The failure is expected when sampling runs, not when the collection is created.
    self.assertRaises(TypeError, collection.populate_sample)
def test_last_and_random(self):
    """``last_and_random_from_each_event`` keeps each event's last capture plus ``num_each`` others."""
    from ami.main.models import SourceImageCollection

    collection = SourceImageCollection.objects.create(
        name="Test Last and Random Collection",
        project=self.project_one,
        method="last_and_random_from_each_event",
        kwargs={"num_each": 2},
    )
    collection.save()
    collection.populate_sample()

    sampled_images = collection.images.all()
    # Two nights: the last image of each, plus two extra random images of each.
    self.assertEqual(sampled_images.count(), 6)

    for night_event in self.project_one.events.all():
        final_capture = night_event.captures.last()
        assert final_capture
        # The event's last capture has to be part of the sample.
        self.assertIn(final_capture, sampled_images)
        # Besides it, exactly two more images must come from the same event.
        others_from_event = sampled_images.filter(event=night_event).exclude(pk=final_capture.pk)
        self.assertEqual(others_from_event.count(), 2)
def test_random_from_each_event(self):
    """``random_from_each_event`` samples ``num_each`` images out of every event."""
    from ami.main.models import SourceImageCollection

    collection = SourceImageCollection.objects.create(
        name="Test Random From Each Event Collection",
        project=self.project_one,
        method="random_from_each_event",
        kwargs={"num_each": 2},
    )
    collection.save()
    collection.populate_sample()

    sampled_images = collection.images.all()
    # Two nights with two random images each.
    assert sampled_images.count() == 4

    # Every event contributes exactly two images.
    for night_event in self.project_one.events.all():
        from_this_event = sampled_images.filter(event=night_event)
        assert from_this_event.count() == 2
def test_common_combined_deployment_ids(self):
    """``common_combined`` with ``deployment_ids`` samples only from the listed deployments."""
    from ami.main.models import Deployment, SourceImageCollection

    # Two extra deployments in the same project as the setUp deployment.
    extra_deployments = [
        Deployment.objects.create(name=deployment_name, project=self.project_one)
        for deployment_name in ("Test Deployment Two", "Test Deployment Three")
    ]

    # Give each extra deployment its own captures.
    for extra in extra_deployments:
        create_captures(deployment=extra, num_nights=2, images_per_night=10, interval_minutes=1)

    # Both extra deployments must actually hold images.
    for extra in extra_deployments:
        assert extra.captures.count() > 0

    second_deployment, third_deployment = extra_deployments

    # Sample from the two extra deployments only.
    collection = SourceImageCollection.objects.create(
        name="Test Common Combined Deployment IDs",
        project=self.project_one,
        method="common_combined",
        kwargs={
            "deployment_ids": [second_deployment.pk, third_deployment.pk],
            "shuffle": True,
            "max_num": 100,
        },
    )
    collection.save()
    collection.populate_sample()

    sampled_images = collection.images.all()

    # Every sampled image belongs to one of the two listed deployments ...
    from_listed = sampled_images.filter(deployment__in=[second_deployment, third_deployment])
    self.assertEqual(from_listed.count(), sampled_images.count())
    # ... and none to the deployment created in setUp.
    self.assertEqual(sampled_images.filter(deployment=self.deployment).count(), 0)

    # Each listed deployment contributes at least one image.
    self.assertGreater(sampled_images.filter(deployment=second_deployment).count(), 0)
    self.assertGreater(sampled_images.filter(deployment=third_deployment).count(), 0)
def test_interval_sample_multiple_deployments(self):
    """
    Interval sampling should be applied to each deployment (station) on its own.

    Two deployments receive captures spaced 1 minute apart. The collection is then
    sampled with `minute_interval=60`, and the total number of sampled images must
    equal the sum of sampling each deployment's captures separately.
    """
    from ami.main.models import SourceImage, SourceImageCollection, sample_captures_by_interval

    minute_interval = 60
    # 3 hours worth of captures at 1-minute intervals (~180 images)
    captures_per_deployment = 180

    # A fresh project with two deployments
    project = Project.objects.create(name="Multi Dep Project", create_defaults=False)
    deployments = [
        Deployment.objects.create(name="Dep One", project=project),
        Deployment.objects.create(name="Dep Two", project=project),
    ]
    for deployment in deployments:
        create_captures(
            deployment=deployment,
            num_nights=1,
            images_per_night=captures_per_deployment,
            interval_minutes=1,
        )

    collection = SourceImageCollection.objects.create(
        name="Test Multi-Dep Interval",
        project=project,
        method="interval",
        kwargs={"minute_interval": minute_interval},
    )
    collection.save()
    collection.populate_sample()
    sampled_count = collection.images.count()

    def _expected_for(deployment) -> int:
        """Count the captures selected when sampling a single deployment by itself."""
        captures = SourceImage.objects.filter(deployment=deployment).exclude(timestamp=None).order_by("timestamp")
        return len(list(sample_captures_by_interval(minute_interval, captures)))

    # Expected total is the sum of the per-deployment samples
    expected = sum(_expected_for(deployment) for deployment in deployments)

    self.assertEqual(sampled_count, expected)
class TestTaxonomy(TestCase):
    """Tests for the taxonomy tree, rank formatting and the ``parents_json`` of each taxon."""

    def setUp(self) -> None:
        """Create a test project and populate it with taxa."""
        project, _deployment = setup_test_project()
        create_taxa(project=project)
        return super().setUp()

    def test_tree(self):
        """
        The top-level ``taxon`` entry of the tree should be Lepidoptera.

        example_tree = {
            'taxon': <Taxon: Lepidoptera (order)>,
            'children': [
                {
                    'taxon': <Taxon: Vanessa (genus)>,
                    'children': [
                        {'taxon': <Taxon: Vanessa atalanta (species)>, 'children': []},
                        {'taxon': <Taxon: Vanessa cardui (species)>, 'children': []},
                        {'taxon': <Taxon: Vanessa itea (species)>, 'children': []}
                    ]
                }
            ]
        }
        """
        from ami.main.models import Taxon

        tree = Taxon.objects.tree()

        # ``assertDictContainsSubset`` was deprecated in Python 3.2 and removed in 3.12,
        # so the single expected key/value pair is checked explicitly instead.
        self.assertIn("taxon", tree)
        self.assertEqual(tree["taxon"], Taxon.objects.get(name="Lepidoptera"))

    def test_rank_formatting(self):
        """
        Test that all ranks in the DB are uppercase and match a TaxonRank value
        """
        from ami.main.models import Taxon

        # The valid names do not change between taxa, so build the list once
        valid_rank_names = [rank.name for rank in TaxonRank]

        for taxon in Taxon.objects.all():
            self.assertIn(taxon.rank, valid_rank_names)
            self.assertEqual(taxon.rank, taxon.rank.upper())

    def _test_filtered_tree(self, filter_ranks: list[TaxonRank]):
        """
        Build the tree restricted to ``filter_ranks`` and compare its taxa to the DB.

        The taxa collected from the tree (each node before its children) must equal
        the list of taxa whose rank is one of ``filter_ranks``, in the same order.

        :param filter_ranks: ranks to keep when building the tree.
        """
        filter_rank_names = [rank.name for rank in filter_ranks]
        expected_taxa = list(Taxon.objects.filter(rank__in=filter_rank_names).all())
        tree = Taxon.objects.tree(filter_ranks=filter_ranks)

        # collect all Taxon objects in tree to test against expected
        def _tree_taxa(tree: dict) -> list[Taxon]:
            taxa = [tree["taxon"]]
            for child in tree["children"]:
                taxa.extend(_tree_taxa(child))
            return taxa

        taxa_in_tree = _tree_taxa(tree)
        self.assertListEqual(taxa_in_tree, expected_taxa)

    def test_tree_filtered_families(self):
        """The tree can be built while skipping the family rank."""
        # Try skipping over family
        filter_ranks = [TaxonRank.ORDER, TaxonRank.GENUS, TaxonRank.SPECIES]
        self._test_filtered_tree(filter_ranks)

    def test_tree_filtered_genera(self):
        """The tree can be built while skipping the genus rank."""
        # Try skipping over genus
        filter_ranks = [TaxonRank.ORDER, TaxonRank.FAMILY, TaxonRank.SPECIES]
        self._test_filtered_tree(filter_ranks)

    def test_tree_filtered_species(self):
        """The tree can be built while skipping the species rank."""
        # Try skipping over species
        filter_ranks = [TaxonRank.ORDER, TaxonRank.FAMILY, TaxonRank.GENUS]
        self._test_filtered_tree(filter_ranks)

    def test_tree_filtered_root(self):
        """Filtering out the rank of the root taxon should raise a ValueError."""
        # Try skipping over order
        root = Taxon.objects.root()
        filter_ranks = [rank for rank in TaxonRank if rank != root.get_rank()]
        with self.assertRaises(ValueError):
            self._test_filtered_tree(filter_ranks)

    def test_update_parents(self):
        """Update the parents of every taxon one at a time and validate the stored result."""
        for taxon in Taxon.objects.all():
            taxon.update_parents(save=True)
            # Reload so the checks run against what was persisted
            taxon.refresh_from_db()
            self._test_parents_json(taxon)

    def test_update_all_parents(self):
        """Update the parents of all taxa at once and validate every taxon that has a parent."""
        from ami.main.models import Taxon

        Taxon.objects.update_all_parents()

        for taxon in Taxon.objects.exclude(parent=None):
            self._test_parents_json(taxon)

    def _test_parents_json(self, taxon):
        """
        Validate the ``parents_json`` list of a single taxon.

        Checks that the list is populated only when the taxon has a parent, that every
        entry is a ``TaxonParent`` with a ``TaxonRank`` rank different from the taxon's
        own rank, that entries are sorted by strictly increasing rank, and that the last
        entry is the taxon's direct parent.

        :param taxon: the taxon whose ``parents_json`` is checked.
        """
        from ami.main.models import TaxonParent, TaxonRank

        # Ensure all taxon have parents_json populated
        if taxon.parent:
            self.assertGreater(
                len(taxon.parents_json),
                0,
                f"Taxon {taxon} has no parents_json, even though it has the parent {taxon.parent}",
            )
        else:
            self.assertEqual(
                len(taxon.parents_json),
                0,
                f"Taxon {taxon} has parents_json, even though it has no parent",
            )

        for parent_taxon in taxon.parents_json:
            # Ensure all parents_json are TaxonParent objects
            self.assertIsInstance(parent_taxon, TaxonParent)
            self.assertIsInstance(parent_taxon.rank, TaxonRank)

            # Ensure a parent rank is not the same as the taxon itself
            self.assertNotEqual(taxon.rank, parent_taxon.rank)

        # Ensure the order of all parents is correct
        sorted_parents = sorted(taxon.parents_json, key=lambda x: x.rank)
        self.assertListEqual(taxon.parents_json, sorted_parents)

        # For each rank, test that it is lower than the previous rank
        previous_rank = None
        for parent in taxon.parents_json:
            # Compare against None explicitly so the check never depends on the
            # truthiness of a rank value
            if previous_rank is not None:
                self.assertGreater(parent.rank, previous_rank)
            previous_rank = parent.rank

        # Ensure last item in parents_json is the taxon's direct parent
        if taxon.parent:
            direct_parent = taxon.parents_json[-1]
            self.assertEqual(
                direct_parent.id,
                taxon.parent_id,
                (
                    f"Taxon {taxon} has incorrect direct parent: {direct_parent.name} != {taxon.parent.name}. "
                    f"All parents: {taxon.parents_json}"
                ),
            )
class TestTaxonomyViews(TestCase):
def setUp(self) -> None:
project_one, deployment_one = setup_test_project(reuse=False)
project_two, deployment_two = setup_test_project(reuse=False)
create_taxa(project=project_one)
create_taxa(project=project_two)
# Show project & deployment IDs
print(f"Project One: {project_one}")
print(f"Project Two: {project_two}")
print(f"Deployment One: {deployment_one.pk}")
print(f"Deployment Two: {deployment_two.pk}")
create_captures(deployment=deployment_one)