Skip to content

Commit a4dd173

Browse files
mihowannavik
andauthored
Improve representation of generated clusters (#849)
* feat: don't show placeholder scores from human identifications * feat: prioritize determinations from clustering, regardless of score * feat: method to calculate score of cluster points * feat: track more details about cluster members * fix: score calculation and display of classifications for clusters * feat: use common ancestor of predicted taxon for name & parent * feat: use any occurrence's best detection as fallback cover image * feat: update automatic naming * feat: fix missing parents in taxa list view * copy: cleanup auto generated taxa list name * feat: remove species name from cluster name, expand notes. * fix: calculation of determination score --------- Co-authored-by: Anna Viklund <annamariaviklund@gmail.com>
1 parent a6bead3 commit a4dd173

9 files changed

Lines changed: 240 additions & 73 deletions

File tree

ami/main/api/serializers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
SourceImageUpload,
3333
TaxaList,
3434
Taxon,
35+
get_media_url,
3536
validate_filename_timestamp,
3637
)
3738

@@ -520,6 +521,7 @@ class TaxonListSerializer(DefaultSerializer):
520521
occurrences = serializers.SerializerMethodField()
521522
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
522523
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent")
524+
cover_image_url = serializers.SerializerMethodField()
523525
tags = serializers.SerializerMethodField()
524526

525527
def get_tags(self, obj):
@@ -549,6 +551,8 @@ class Meta:
549551
def get_occurrences(self, obj):
550552
"""
551553
Return URL to the occurrences endpoint filtered by this taxon.
554+
555+
Does not make a database query.
552556
"""
553557

554558
params = {}
@@ -561,6 +565,15 @@ def get_occurrences(self, obj):
561565
params=params,
562566
)
563567

568+
def get_cover_image_url(self, obj):
569+
if obj.cover_image_url:
570+
return obj.cover_image_url
571+
elif hasattr(obj, "best_detection_image_path") and obj.best_detection_image_path:
572+
# This attribute is added by an QuerySet annotation
573+
return get_media_url(obj.best_detection_image_path)
574+
else:
575+
return None
576+
564577

565578
class TaxaListSerializer(serializers.ModelSerializer):
566579
taxa = serializers.SerializerMethodField()
@@ -745,6 +758,7 @@ class TaxonSerializer(DefaultSerializer):
745758
parent = TaxonNoParentNestedSerializer(read_only=True)
746759
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
747760
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
761+
cover_image_url = serializers.SerializerMethodField()
748762
tags = serializers.SerializerMethodField()
749763

750764
def get_tags(self, obj):
@@ -774,6 +788,15 @@ class Meta:
774788
"unknown_species",
775789
]
776790

791+
def get_cover_image_url(self, obj):
792+
if obj.cover_image_url:
793+
return obj.cover_image_url
794+
elif hasattr(obj, "best_detection_image_path") and obj.best_detection_image_path:
795+
# This attribute is added by an QuerySet annotation
796+
return get_media_url(obj.best_detection_image_path)
797+
else:
798+
return None
799+
777800

778801
class CaptureOccurrenceSerializer(DefaultSerializer):
779802
determination = TaxonNoParentNestedSerializer(read_only=True)

ami/main/api/views.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,8 +1371,19 @@ def get_queryset(self) -> QuerySet:
13711371
qs = self.attach_tags_by_project(qs, project)
13721372

13731373
if project:
1374-
# Allow showing detail views for unobserved taxa
1375-
include_unobserved = True
1374+
include_unobserved = True # Show detail views for unobserved taxa instead of 404
1375+
# @TODO move to a QuerySet manager
1376+
qs = qs.annotate(
1377+
best_detection_image_path=models.Subquery(
1378+
Occurrence.objects.filter(
1379+
self.get_occurrence_filters(project),
1380+
determination_id=models.OuterRef("id"),
1381+
)
1382+
.order_by("-determination_score")
1383+
.values("best_detection__path")[:1],
1384+
output_field=models.TextField(),
1385+
)
1386+
)
13761387
if self.action == "list":
13771388
include_unobserved = self.request.query_params.get("include_unobserved", False)
13781389
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)

ami/main/management/commands/import_taxa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def create_taxon(self, taxon_data: dict, root_taxon_parent: Taxon) -> tuple[set[
359359
parent = None
360360
if taxon.parent != parent:
361361
if not created:
362-
logger.warn(f"Changing parent of {taxon} from {taxon.parent} to more specific {parent}")
362+
logger.warning(f"Changing parent of {taxon} from {taxon.parent} to more specific {parent}")
363363
taxon.parent = parent
364364
taxon.save(update_calculated_fields=False)
365365
if not created:

ami/main/models.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2430,31 +2430,22 @@ def get_best_detection(self) -> Detection | None:
24302430

24312431
def get_best_predictions(self, filters: dict = {}) -> models.QuerySet[Classification]:
24322432
"""
2433-
Retrieve the classification with the max score for each algorithm
2434-
from any detection belonging to this occurrence.
2433+
Retrieve all classifications for this occurrence in chronological order.
24352434
24362435
This data is for the list of predictions in the Identification tab of the Occurrence Detail view
24372436
in the UI. See the OccurrenceSerializer for the serializer method.
24382437
24392438
If this is need for a list view (multiple occurrenes) it should be overriden
24402439
in the viewset to use the pre-fetched classifications instead of hitting the database
24412440
for each occurrence (n+1 query problem).
2441+
2442+
In the past, this was a more complext query that returned a single result
2443+
for each algorithm, but now it returns all classifications for the occurrence
24422444
"""
2443-
# Get the highest scoring classification for each algorithm
2444-
# Use a subquery to find the max score for each algorithm
2445-
subquery = (
2446-
Classification.objects.filter(detection__occurrence=self, **filters)
2447-
.values("algorithm")
2448-
.annotate(max_score=models.Max("score"))
2449-
)
24502445

2451-
# Join the subquery results to get the classifications with those max scores
2452-
# This ensures we get one classification per algorithm (the one with highest score)
24532446
classifications = Classification.objects.filter(
24542447
detection__occurrence=self,
24552448
**filters,
2456-
algorithm__in=models.Subquery(subquery.values("algorithm")),
2457-
score__in=models.Subquery(subquery.values("max_score")),
24582449
).order_by("-created_at")
24592450

24602451
return classifications
@@ -2473,6 +2464,22 @@ def get_best_prediction(self, filters: dict = {}) -> Classification | None:
24732464
# Get all classifications for this occurrence to choose from
24742465
all_classifications = Classification.objects.filter(detection__occurrence=self, **filters)
24752466

2467+
# Prioritize derived classifications (e.g. clustering) regardless of score
2468+
derived_classification_task_types = (
2469+
"clustering",
2470+
"tracking",
2471+
)
2472+
derived_classification = (
2473+
all_classifications.filter(
2474+
algorithm__task_type__in=derived_classification_task_types,
2475+
terminal=True,
2476+
)
2477+
.order_by("-created_at")
2478+
.first()
2479+
)
2480+
if derived_classification:
2481+
return derived_classification
2482+
24762483
# First try to get a terminal classification
24772484
terminal_classification = all_classifications.filter(terminal=True).order_by("-score", "-created_at").first()
24782485
if terminal_classification:
@@ -2481,6 +2488,9 @@ def get_best_prediction(self, filters: dict = {}) -> Classification | None:
24812488
# If no terminal classification exists, fall back to non-terminal
24822489
return all_classifications.filter(terminal=False).order_by("-score").first()
24832490

2491+
def get_best_ood_prediction(self) -> Classification | None:
2492+
return self.get_best_prediction(filters={"ood_score__isnull": False})
2493+
24842494
def get_best_identification(self) -> Identification | None:
24852495
"""
24862496
The most recent human identification is used as the best identification.
@@ -2489,17 +2499,17 @@ def get_best_identification(self) -> Identification | None:
24892499
"""
24902500
return Identification.objects.filter(occurrence=self, withdrawn=False).order_by("-created_at").first()
24912501

2492-
def get_determination_score(self) -> float | None:
2493-
if not self.determination:
2502+
def get_determination_score(self, prediction: Classification | None = None) -> float | None:
2503+
"""
2504+
Always return a score from an algorithm, even if a human has identified the occurrence.
2505+
"""
2506+
best_prediction = prediction or self.get_best_prediction()
2507+
if not best_prediction:
24942508
return None
2495-
elif self.best_identification:
2496-
return self.best_identification.score
2497-
elif self.best_prediction:
2498-
return self.best_prediction.score
24992509
else:
2500-
return None
2510+
return best_prediction.score
25012511

2502-
def get_determination_ood_score(self) -> float | None:
2512+
def get_determination_ood_score(self, prediction: Classification | None = None) -> float | None:
25032513
"""
25042514
Calculate the OOD score for the whole occurrence.
25052515
Uses the average OOD score of all detections belonging to this occurrence.
@@ -2508,16 +2518,16 @@ def get_determination_ood_score(self) -> float | None:
25082518
"""
25092519
# Get the best prediction that has an OOD score
25102520
# this should be the last classification before the clustering algorithm
2511-
# @TODO copy the OOD score from the best classification to the clustering classification during clustering
2512-
best_prediction = self.get_best_prediction(filters={"ood_score__isnull": False})
2521+
best_prediction = prediction or self.get_best_ood_prediction()
25132522
if not best_prediction:
25142523
return None
2515-
mean_ood_score = Classification.objects.filter(
2516-
detection__occurrence=self,
2517-
ood_score__isnull=False,
2518-
algorithm=best_prediction.algorithm,
2519-
).aggregate(models.Avg("ood_score"),)["ood_score__avg"]
2520-
return mean_ood_score
2524+
else:
2525+
mean_ood_score = Classification.objects.filter(
2526+
detection__occurrence=self,
2527+
ood_score__isnull=False,
2528+
algorithm=best_prediction.algorithm,
2529+
).aggregate(models.Avg("ood_score"),)["ood_score__avg"]
2530+
return mean_ood_score
25212531

25222532
def context_url(self):
25232533
detection = self.best_detection
@@ -2540,16 +2550,6 @@ def save(self, update_determination=True, *args, **kwargs):
25402550
save=True,
25412551
)
25422552

2543-
if self.determination and not self.determination_score:
2544-
# This may happen for legacy occurrences that were created
2545-
# before the determination_score field was added
2546-
# @TODO remove
2547-
self.determination_score = self.get_determination_score()
2548-
if not self.determination_score:
2549-
logger.warning(f"Could not determine score for {self}")
2550-
else:
2551-
self.save(update_determination=False)
2552-
25532553
class Meta:
25542554
ordering = ["-determination_score"]
25552555

@@ -2584,23 +2584,24 @@ def update_occurrence_determination(
25842584

25852585
# Collect all necessary values first
25862586
best_identification = occurrence.get_best_identification()
2587-
best_prediction = occurrence.get_best_prediction() if not best_identification else None
2587+
best_prediction = occurrence.get_best_prediction()
2588+
best_ood_prediction = occurrence.get_best_ood_prediction()
25882589

25892590
# Best detection is used as the representative image for the occurrence in either case
25902591
best_detection = occurrence.get_best_detection()
25912592

2592-
# Determine values for all attributes
2593+
# Update the determination (Taxon) first
25932594
new_determination = None
2594-
new_determination_score = None
2595-
new_determination_ood_score = occurrence.get_determination_ood_score()
25962595

25972596
# Identifications take precedence over machine predictions
25982597
if best_identification:
25992598
new_determination = best_identification.taxon
2600-
new_determination_score = best_identification.score
26012599
elif best_prediction:
26022600
new_determination = best_prediction.taxon
2603-
new_determination_score = best_prediction.score
2601+
2602+
# Update scores, which may or may not come from the same source as the determination
2603+
new_determination_score = occurrence.get_determination_score(prediction=best_prediction)
2604+
new_determination_ood_score = occurrence.get_determination_ood_score(prediction=best_ood_prediction)
26042605

26052606
# Prepare fields that need to be updated (using a dictionary for bulk update)
26062607
update_fields = {}
@@ -2862,6 +2863,7 @@ class Config:
28622863
# so we can sort by rank. The DRF serializer will convert it to a string.
28632864
# just for the API responses.
28642865
use_enum_values = False
2866+
frozen = True # Allow hashing for use in a set
28652867

28662868

28672869
@final
@@ -3099,6 +3101,60 @@ def save(self, update_calculated_fields=True, *args, **kwargs):
30993101
self.update_calculated_fields(save=True)
31003102

31013103

3104+
def find_common_ancestor_taxon(
3105+
taxa: list["Taxon"],
3106+
ignore_missing_parents: bool = True,
3107+
) -> typing.Optional["Taxon"]:
3108+
"""
3109+
Find the common ancestor taxon for a list of taxa.
3110+
Args:
3111+
taxa (list[Taxon]): A list of Taxon objects.
3112+
ignore_rootless (bool): If True, ignore taxa without parents. Defaults to True.
3113+
Returns:
3114+
Taxon | None: The common ancestor taxon, or None if no common ancestor exists.
3115+
"""
3116+
if not taxa:
3117+
return None
3118+
3119+
# Filter taxa based on whether they have parents
3120+
valid_taxa = taxa
3121+
if ignore_missing_parents:
3122+
valid_taxa = [t for t in taxa if t.parents_json]
3123+
rootless_count = len(taxa) - len(valid_taxa)
3124+
if rootless_count:
3125+
logger.warning(f"Ignoring {rootless_count} rootless taxa")
3126+
3127+
if not valid_taxa:
3128+
logger.error("No taxa with parents found")
3129+
return None
3130+
3131+
# Build ancestor sets for each taxon
3132+
ancestor_sets = []
3133+
for taxon in valid_taxa:
3134+
ancestors = set(taxon.parents_json)
3135+
# Include the taxon itself
3136+
ancestors.add(TaxonParent(id=taxon.pk, name=taxon.name, rank=TaxonRank(taxon.rank)))
3137+
ancestor_sets.append(ancestors)
3138+
3139+
# Find common ancestors
3140+
common_ancestors = set.intersection(*ancestor_sets)
3141+
3142+
if not common_ancestors:
3143+
logger.info("No common ancestor found")
3144+
return None
3145+
3146+
# Find the most specific common ancestor (highest rank index)
3147+
best_ancestor = max(common_ancestors, key=lambda a: list(TaxonRank).index(a.rank))
3148+
3149+
logger.info(f"Common ancestor: {best_ancestor.name} ({best_ancestor.rank})")
3150+
3151+
# Return the actual Taxon object
3152+
from .models import Taxon
3153+
3154+
result = Taxon.objects.get(id=best_ancestor.id)
3155+
return result
3156+
3157+
31023158
@final
31033159
class TaxaList(BaseModel):
31043160
"""A checklist of taxa"""

ami/main/tests/test_occurrence_determination.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_update_with_identification(self):
9696

9797
# Check that the determination is set to the identification's taxon
9898
self.assertEqual(self.occurrence.determination, self.taxon2)
99-
self.assertEqual(self.occurrence.determination_score, 1.0) # Human identifications have score 1.0
99+
self.assertEqual(self.occurrence.determination_score, None) # Human identifications have no score
100100
self.assertEqual(self.occurrence.best_identification, identification)
101101

102102
def test_identification_overrides_classification(self):

ami/ml/clustering_algorithms/agglomerative.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from scipy.spatial.distance import pdist, squareform
66
from sklearn.cluster import AgglomerativeClustering
7+
from sklearn.metrics import silhouette_samples
78

89
from .base_clusterer import BaseClusterer
910
from .preprocessing_features import dimension_reduction, standardize
@@ -69,7 +70,7 @@ def setup(self, data_dict):
6970
data_dict["val"]["feat_list"], data_dict["val"]["label_list"]
7071
)
7172

72-
def cluster(self, features):
73+
def cluster(self, features) -> tuple[np.ndarray, np.ndarray]:
7374
logger.info(f"distance threshold: {self.distance_threshold}")
7475
logger.info("features shape: %s", features.shape)
7576
logger.info(f"self.n_components: {self.n_components}")
@@ -84,8 +85,21 @@ def cluster(self, features):
8485
linkage = self.config.get("algorithm_kwargs", {}).get("linkage", "ward")
8586
logger.info(f" features shape after PCA: {features.shape}")
8687

87-
clusters = AgglomerativeClustering(
88+
cluster_ids = AgglomerativeClustering(
8889
n_clusters=None, distance_threshold=self.distance_threshold, linkage=linkage
8990
).fit_predict(features)
9091

91-
return clusters
92+
try:
93+
silhouette_scores = silhouette_samples(features, cluster_ids)
94+
silhouette_scores = np.asarray(silhouette_scores)
95+
# Scale from -1 to 1 to 0 to 1
96+
silhouette_scores = (silhouette_scores + 1) / 2
97+
except ValueError:
98+
# If silhouette scores cannot be computed, return an array of zeros
99+
logger.warning(
100+
f"Returned {len(cluster_ids)} clusters for {len(features)} features. "
101+
"Cannot compute silhouette scores so setting them to zero."
102+
)
103+
silhouette_scores = np.zeros(features.shape[0], dtype=np.float32)
104+
105+
return cluster_ids, silhouette_scores

0 commit comments

Comments
 (0)