Skip to content

Commit 0a0ea5e

Browse files
authored
Fixes & features for higher taxon ranks (#496)
* Update admin to use new parents_json
* Add tests for parents_json & recursive counts
* Default taxon parents to list, update django-pydantic-fields
* Use real TaxonRank objects everywhere except API responses
* Show all occurrences under a taxon parent
* Placeholder hack to include a recursive occurrence count for detail views
* Give up on adding all occurrence counts in one query for now
* Link to occurrences list in species detail modal, even if occurrences count is not available
1 parent f7f3470 commit 0a0ea5e

9 files changed

Lines changed: 330 additions & 41 deletions

File tree

ami/main/admin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
242242
)
243243
list_filter = ("lists", "rank", TaxonParentFilter)
244244
search_fields = ("name",)
245-
exclude = ("parents",)
246245

247246
# annotate queryset with occurrence counts and allow sorting
248247
# https://docs.djangoproject.com/en/3.2/ref/contrib/admin/#django.contrib.admin.ModelAdmin.list_display
@@ -278,7 +277,10 @@ def update_display_names(self, request: HttpRequest, queryset: QuerySet[Taxon])
278277
ordering="parents",
279278
)
280279
def parent_names(self, obj) -> str:
281-
return ", ".join([str(taxon) for taxon in obj.parents.values_list("name", flat=True)])
280+
if obj.parents_json:
281+
return ", ".join([str(taxon.name) for taxon in obj.parents_json])
282+
else:
283+
return ""
282284

283285
actions = [update_species_parents, update_display_names]
284286

ami/main/api/serializers.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import datetime
22

33
from django.db.models import QuerySet
4-
from django_pydantic_field.rest_framework import SchemaField
54
from rest_framework import serializers
65

76
from ami.base.serializers import DefaultSerializer, get_current_user, reverse_with_params
@@ -29,7 +28,6 @@
2928
SourceImageCollection,
3029
SourceImageUpload,
3130
Taxon,
32-
TaxonParent,
3331
)
3432

3533

@@ -408,13 +406,22 @@ class Meta:
408406
]
409407

410408

409+
class TaxonParentSerializer(serializers.Serializer):
    """
    Read-only representation of one cached parent entry of a taxon.

    Each entry exposes the parent's ``id``, ``name`` and ``rank``. The rank is
    emitted as a plain value so API responses do not contain enum objects.
    """

    id = serializers.IntegerField()
    name = serializers.CharField()
    rank = serializers.SerializerMethodField()

    def get_rank(self, obj):
        """
        Return the rank of ``obj`` as a plain value.

        Accepts either an object with a ``rank`` attribute or a dict with a
        ``"rank"`` key. If the rank is an enum member its ``.value`` is
        returned; a rank that is already a plain value is passed through
        unchanged instead of raising ``AttributeError``.
        """
        rank = obj["rank"] if isinstance(obj, dict) else obj.rank
        return getattr(rank, "value", rank)
416+
417+
411418
class TaxonNestedSerializer(TaxonNoParentNestedSerializer):
412419
"""
413420
Simple Taxon serializer with 1 level of nested parents.
414421
"""
415422

416423
parent = TaxonNoParentNestedSerializer(read_only=True)
417-
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
424+
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
418425

419426
class Meta(TaxonNoParentNestedSerializer.Meta):
420427
fields = TaxonNoParentNestedSerializer.Meta.fields + [
@@ -492,7 +499,7 @@ def get_occurrence_images(self, obj):
492499

493500
class CaptureTaxonSerializer(DefaultSerializer):
494501
parent = TaxonNoParentNestedSerializer(read_only=True)
495-
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
502+
parents = TaxonParentSerializer(many=True, read_only=True)
496503

497504
class Meta:
498505
model = Taxon
@@ -649,7 +656,7 @@ class TaxonSerializer(DefaultSerializer):
649656
parent = TaxonNoParentNestedSerializer(read_only=True)
650657
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
651658
# parents = TaxonParentNestedSerializer(many=True, read_only=True, source="parents_json")
652-
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
659+
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
653660

654661
class Meta:
655662
model = Taxon

ami/main/api/views.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from django.utils import timezone
1212
from django_filters.rest_framework import DjangoFilterBackend
1313
from rest_framework import exceptions as api_exceptions
14-
from rest_framework import serializers, status, viewsets
14+
from rest_framework import filters, serializers, status, viewsets
1515
from rest_framework.decorators import action
1616
from rest_framework.exceptions import NotFound
1717
from rest_framework.filters import SearchFilter
@@ -626,6 +626,20 @@ def get_serializer_class(self):
626626
# "detection_algorithm").all()
627627

628628

629+
class CustomDeterminationFilter(filters.BaseFilterBackend):
    """
    Filter a queryset by the ``determination`` query parameter, including descendants.

    Rows match when their determination is the requested taxon itself, or when
    the determination's ``parents_json`` contains an entry with that taxon's id.
    """

    def filter_queryset(self, request, queryset, view):
        """
        Apply the determination filter if the ``determination`` parameter is present.

        Returns the queryset unchanged when the parameter is missing or empty,
        and an empty queryset when the value does not identify an existing
        taxon (unknown id or a value that is not a valid id).
        """
        determination_id = request.query_params.get("determination")
        if not determination_id:
            return queryset

        try:
            taxon = Taxon.objects.get(id=determination_id)
        except (Taxon.DoesNotExist, ValueError):
            # ValueError: the raw query-string value could not be converted to an id
            # (e.g. "?determination=abc"). Treat it like an unknown taxon rather
            # than letting it surface as a server error.
            return queryset.none()

        return queryset.filter(
            models.Q(determination=taxon) | models.Q(determination__parents_json__contains=[{"id": taxon.id}])
        )
641+
642+
629643
class OccurrenceViewSet(DefaultViewSet):
630644
"""
631645
API endpoint that allows occurrences to be viewed or edited.
@@ -634,7 +648,9 @@ class OccurrenceViewSet(DefaultViewSet):
634648
queryset = Occurrence.objects.all()
635649

636650
serializer_class = OccurrenceSerializer
637-
filterset_fields = ["event", "deployment", "determination", "project", "determination__rank"]
651+
# filter_backends = [CustomDeterminationFilter, DjangoFilterBackend, NullsLastOrderingFilter, SearchFilter]
652+
filter_backends = DefaultViewSetMixin.filter_backends + [CustomDeterminationFilter]
653+
filterset_fields = ["event", "deployment", "project", "determination__rank"]
638654
ordering_fields = [
639655
"created_at",
640656
"updated_at",
@@ -681,6 +697,7 @@ def get_queryset(self) -> QuerySet:
681697
.exclude(first_appearance_timestamp=None) # This must come after annotations
682698
.order_by("-determination_score")
683699
)
700+
684701
else:
685702
qs = qs.prefetch_related(
686703
Prefetch(
@@ -895,6 +912,15 @@ def get_queryset(self) -> QuerySet:
895912

896913
return qs
897914

915+
# def retrieve(self, request: Request, *args, **kwargs) -> Response:
916+
# """
917+
# Override the serializer to include the recursive occurrences count
918+
# """
919+
# taxon: Taxon = self.get_object()
920+
# taxon.occurrences_count = taxon.occurrences_count_recursive() # type: ignore
921+
# response = Response(TaxonSerializer(taxon, context={"request": request}).data)
922+
# return response
923+
898924

899925
class ClassificationViewSet(DefaultViewSet):
900926
"""
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Generated by Django 4.2.10 on 2024-08-06 17:47
2+
3+
import ami.main.models
4+
from django.db import migrations, models
5+
import django_pydantic_field._migration_serializers
6+
import django_pydantic_field.fields
7+
8+
9+
class Migration(migrations.Migration):
10+
dependencies = [
11+
("main", "0034_remove_taxon_parents_taxon_parents_json"),
12+
]
13+
14+
operations = [
15+
migrations.AlterField(
16+
model_name="taxon",
17+
name="parents_json",
18+
field=django_pydantic_field.fields.PydanticSchemaField(
19+
blank=True,
20+
config=None,
21+
default=list,
22+
schema=django_pydantic_field._migration_serializers.GenericContainer(
23+
list, (ami.main.models.TaxonParent,)
24+
),
25+
),
26+
),
27+
migrations.AlterField(
28+
model_name="taxon",
29+
name="rank",
30+
field=models.CharField(
31+
choices=[
32+
("ORDER", "ORDER"),
33+
("SUPERFAMILY", "SUPERFAMILY"),
34+
("FAMILY", "FAMILY"),
35+
("SUBFAMILY", "SUBFAMILY"),
36+
("TRIBE", "TRIBE"),
37+
("SUBTRIBE", "SUBTRIBE"),
38+
("GENUS", "GENUS"),
39+
("SPECIES", "SPECIES"),
40+
("UNKNOWN", "UNKNOWN"),
41+
],
42+
default="SPECIES",
43+
max_length=255,
44+
),
45+
),
46+
]

ami/main/models.py

Lines changed: 95 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,45 +2106,93 @@ def root(self):
21062106

21072107
def update_all_parents(self):
21082108
"""Efficiently update all parents for all taxa."""
2109-
21102109
taxa = self.get_queryset().select_related("parent")
2111-
21122110
logging.info(f"Updating the cached parent tree for {taxa.count()} taxa")
21132111

21142112
# Build a dictionary of taxon parents
2115-
parents = {taxon: taxon.parent for taxon in taxa}
2113+
parents = {taxon.id: taxon.parent_id for taxon in taxa}
2114+
2115+
# Precompute all parents in a single pass
2116+
all_parents = {}
2117+
for taxon_id in parents:
2118+
if taxon_id not in all_parents:
2119+
taxon_parents = []
2120+
current_id = taxon_id
2121+
while current_id in parents:
2122+
current_id = parents[current_id]
2123+
taxon_parents.append(current_id)
2124+
all_parents[taxon_id] = taxon_parents
2125+
2126+
# Prepare bulk update data
2127+
bulk_update_data = []
2128+
for taxon in taxa:
2129+
taxon_parents = all_parents[taxon.id]
2130+
parent_taxa = list(taxa.filter(id__in=taxon_parents))
2131+
taxon_parents = [
2132+
TaxonParent(
2133+
id=taxon.id,
2134+
name=taxon.name,
2135+
rank=taxon.rank,
2136+
)
2137+
for taxon in parent_taxa
2138+
]
2139+
taxon_parents.sort(key=lambda t: t.rank)
21162140

2117-
# Update all parents
2118-
for taxon, parent in parents.items():
2119-
logging.info(f"Updating parents for {taxon}")
2141+
bulk_update_data.append(taxon)
21202142

2121-
taxon_parents = []
2122-
while parent:
2123-
taxon_parents.append(parent)
2124-
# If this is None, the parent is the root taxon, so we stop here.
2125-
parent = parents.get(parent)
2143+
# Perform bulk update
2144+
# with transaction.atomic():
2145+
# self.bulk_update(bulk_update_data, ["parents_json"], batch_size=1000)
2146+
# There is a bug that causes the bulk update to fail with a custom JSONField
2147+
# https://code.djangoproject.com/ticket/35167
2148+
# So we have to update each taxon individually
2149+
for taxon in bulk_update_data:
2150+
taxon.save(update_fields=["parents_json"])
21262151

2127-
# Convert the taxa to the JSON TaxonParent type
2128-
taxon_parents = [TaxonParent(id=t.pk, name=t.name, rank=TaxonRank(t.rank)) for t in taxon_parents]
2152+
logging.info(f"Updated parents for {len(bulk_update_data)} taxa")
21292153

2130-
# Sort the parents by rank (achievable because TaxonRank is an ordered enum)
2131-
taxon_parents.sort(key=lambda t: t.rank)
2154+
def with_children(self):
    """
    Annotate the manager's queryset with a ``children`` value taken from ``parents_json``.

    NOTE(review): this looks like an unfinished placeholder. The taxon used to
    look up children is hard-coded (``pk=1``) instead of being derived from
    each row of the queryset — confirm the intended behaviour before relying on it.
    """
    qs = self.get_queryset()
    # Add Taxon rows that are children of this Taxon using the parents_json field (not direct_children)

    # Example for a single taxon:
    # NOTE(review): raises Taxon.DoesNotExist if no taxon with pk=1 exists.
    taxon = Taxon.objects.get(pk=1)
    # All taxa whose cached parent list contains an entry with this taxon's id.
    taxa = Taxon.objects.filter(parents_json__contains=[{"id": taxon.id}])
    # Add them to the queryset
    # NOTE(review): the subquery can presumably yield more than one row, which a
    # scalar Subquery annotation may not accept — verify.
    qs = qs.annotate(children=models.Subquery(taxa.values("id")))
    return qs
2164+
2165+
def with_occurrence_counts(self) -> models.QuerySet:
2166+
"""
2167+
Count the number of occurrences for a taxon and all occurrences of the taxon's children.
2168+
2169+
@TODO Try a recursive CTE in a raw SQL query,
2170+
or count the occurrences in a separate query and attach them to the Taxon objects.
2171+
"""
2172+
2173+
raise NotImplementedError(
2174+
"Occurrence counts can not be calculated in a subquery with the current JSONField schema. "
2175+
"Fetch them per taxon."
2176+
)
21352177

21362178

21372179
class TaxonParent(pydantic.BaseModel):
21382180
"""
21392181
Should contain all data needed for TaxonParentSerializer
2182+
2183+
Needs a custom encoder and decoder for the TaxonRank enum
2184+
because it is an OrderedEnum and not a standard str Enum.
21402185
"""
21412186

21422187
id: int
21432188
name: str
21442189
rank: TaxonRank
21452190

21462191
class Config:
2147-
use_enum_values = True
2192+
# Make sure the TaxonRank is retrieved as an object and not a string
2193+
# so we can sort by rank. The DRF serializer will convert it to a string
2194+
# just for the API responses.
2195+
use_enum_values = False
21482196

21492197

21502198
@final
@@ -2161,7 +2209,7 @@ class Taxon(BaseModel):
21612209
# Examples how to query this JSON array field
21622210
# Taxon.objects.filter(parents_json__contains=[{"id": 1}])
21632211
# https://stackoverflow.com/a/53942463/966058
2164-
parents_json = SchemaField(list[TaxonParent], null=True, blank=True)
2212+
parents_json = SchemaField(list[TaxonParent], null=False, blank=True, default=list)
21652213

21662214
active = models.BooleanField(default=True)
21672215
synonym_of = models.ForeignKey("self", on_delete=models.SET_NULL, null=True, blank=True, related_name="synonyms")
@@ -2174,7 +2222,6 @@ class Taxon(BaseModel):
21742222

21752223
projects = models.ManyToManyField("Project", related_name="taxa")
21762224
direct_children: models.QuerySet["Taxon"]
2177-
children: models.QuerySet["Taxon"]
21782225
occurrences: models.QuerySet[Occurrence]
21792226
classifications: models.QuerySet["Classification"]
21802227
lists: models.QuerySet["TaxaList"]
@@ -2186,6 +2233,9 @@ class Taxon(BaseModel):
21862233

21872234
objects: TaxaManager = TaxaManager()
21882235

2236+
# Type hints for auto-generated fields
2237+
parent_id: int | None
2238+
21892239
def __str__(self) -> str:
21902240
name_with_rank = f"{self.name} ({self.rank})"
21912241
return name_with_rank
@@ -2213,13 +2263,24 @@ def num_direct_children(self) -> int:
22132263
return self.direct_children.count()
22142264

22152265
def num_children_recursive(self) -> int:
2216-
# @TODO how to do this with a single query?
2217-
return self.children.count() + sum(child.num_children_recursive() for child in self.children.all())
2266+
# Use the parents_json field to get all children
2267+
return Taxon.objects.filter(parents_json__contains=[{"id": self.pk}]).count()
22182268

22192269
def occurrences_count(self) -> int:
22202270
# return self.occurrences.count()
22212271
return 0
22222272

2273+
def occurrences_count_recursive(self) -> int:
    """
    Return the number of occurrences of this taxon plus those of all taxa below it.

    Descendants are found through the ``parents_json`` field: any taxon whose
    cached parent list contains an entry with this taxon's id. Occurrences are
    counted per taxon and then summed; 0 is returned when the sum is empty.
    """
    is_descendant = models.Q(parents_json__contains=[{"id": self.pk}])
    is_self = models.Q(id=self.pk)

    per_taxon = Taxon.objects.filter(is_descendant | is_self).annotate(
        occurrences_count=models.Count("occurrences")
    )
    totals = per_taxon.aggregate(models.Sum("occurrences_count"))

    # The aggregate is falsy (None or 0) when nothing was counted; normalise to 0.
    return totals["occurrences_count__sum"] or 0
2283+
22232284
def detections_count(self) -> int:
22242285
# return Detection.objects.filter(occurrence__determination=self).count()
22252286
return 0
@@ -2288,21 +2349,25 @@ def list_names(self) -> str:
22882349

22892350
def update_parents(self, save=True):
22902351
"""
2291-
Populate the cached "parents" list by recursively following the "parent" field.
2352+
Populate the cached `parents_json` list by recursively following the `parent` field.
22922353
2293-
@TODO this requires the parents' parents already being up-to-date, which may not always be the case.
2354+
@TODO this requires all of the taxon's parent taxa to have the `parent` attribute set correctly.
22942355
"""
22952356

2296-
taxon = self
2357+
current_taxon = self
22972358
parents = []
2298-
while taxon.parent is not None:
2299-
parents.append(TaxonParent(id=taxon.parent.id, name=taxon.parent.name, rank=taxon.parent.rank))
2300-
taxon = taxon.parent
2359+
while current_taxon.parent is not None:
2360+
parents.append(
2361+
TaxonParent(id=current_taxon.parent.id, name=current_taxon.parent.name, rank=current_taxon.parent.rank)
2362+
)
2363+
current_taxon = current_taxon.parent
23012364
# Sort parents by rank using ordered enum
23022365
parents = sorted(parents, key=lambda t: t.rank)
2303-
taxon.parents_json = parents
2366+
self.parents_json = parents
23042367
if save:
2305-
taxon.save()
2368+
self.save()
2369+
2370+
return parents
23062371

23072372
class Meta:
23082373
ordering = [

0 commit comments

Comments
 (0)