Skip to content

Commit c90fb5f

Browse files
committed
Fix parents field in COCO JSON categories
1 parent 15eed82 commit c90fb5f

2 files changed

Lines changed: 56 additions & 4 deletions

File tree

ami/exports/format_types.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ami.exports.base import BaseExporter
1313
from ami.exports.utils import get_data_in_batches
14-
from ami.main.models import Occurrence, SourceImage, get_media_url
14+
from ami.main.models import Occurrence, SourceImage, Taxon, get_media_url
1515
from ami.ml.schemas import BoundingBox
1616

1717
logger = logging.getLogger(__name__)
@@ -236,6 +236,7 @@ class OccurrenceCocoTabularSerializer(OccurrenceTabularSerializer):
236236
capture_path = serializers.SerializerMethodField()
237237
capture_width = serializers.SerializerMethodField()
238238
capture_height = serializers.SerializerMethodField()
239+
best_machine_prediction_taxon_id = serializers.IntegerField(allow_null=True, default=None)
239240

240241
class Meta(OccurrenceTabularSerializer.Meta):
241242
fields = OccurrenceTabularSerializer.Meta.fields + [
@@ -244,6 +245,7 @@ class Meta(OccurrenceTabularSerializer.Meta):
244245
"capture_path",
245246
"capture_width",
246247
"capture_height",
248+
"best_machine_prediction_taxon_id",
247249
]
248250

249251
def get_source_image_id(self, obj):
@@ -281,6 +283,7 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
281283
categories_by_id: dict[int, dict] = {}
282284
images_by_id: dict[int, dict] = {}
283285
annotations: list[dict] = []
286+
category_taxon_ids: set[int] = set()
284287

285288
for row in rows:
286289
determination_id = row.get("determination_id")
@@ -305,6 +308,7 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
305308
"id": int(determination_id),
306309
"name": det_name,
307310
}
311+
category_taxon_ids.add(int(determination_id))
308312
else:
309313
assert (
310314
categories_by_id[int(determination_id)]["name"] == det_name
@@ -335,15 +339,63 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
335339
"iscrowd": 0, # TODO: Could we use this field to indiate crowd of insects?
336340
"determination_score": row.get("determination_score"),
337341
"verification_status": row.get("verification_status"),
338-
"best_machine_prediction_name": row.get("best_machine_prediction_name"),
342+
"determination_matches_machine_prediction": row.get("determination_matches_machine_prediction"),
339343
"best_machine_prediction_algorithm": row.get("best_machine_prediction_algorithm"),
340344
"best_machine_prediction_score": row.get("best_machine_prediction_score"),
341-
"determination_matches_machine_prediction": row.get("determination_matches_machine_prediction"),
345+
"best_machine_prediction_taxon_id": row.get("best_machine_prediction_taxon_id"),
342346
"best_detection_width": row.get("best_detection_width"),
343347
"best_detection_height": row.get("best_detection_height"),
344348
}
349+
prediction_taxon_id = row.get("best_machine_prediction_taxon_id")
350+
if prediction_taxon_id is not None:
351+
try:
352+
category_taxon_ids.add(int(prediction_taxon_id))
353+
except (TypeError, ValueError):
354+
logger.warning(f"Invalid best_machine_prediction_taxon_id for row: {row}")
345355
annotations.append(ann)
346356

357+
def _serialize_parents_json(parents_json):
358+
if not isinstance(parents_json, list):
359+
return []
360+
serialized = []
361+
for parent in parents_json:
362+
if isinstance(parent, dict):
363+
parent_id = parent.get("id")
364+
parent_name = parent.get("name")
365+
rank = parent.get("rank")
366+
else:
367+
# SchemaField(list[TaxonParent]) may return Pydantic objects rather than dicts.
368+
parent_id = getattr(parent, "id", None)
369+
parent_name = getattr(parent, "name", None)
370+
rank = getattr(parent, "rank", None)
371+
372+
if parent_id is None and parent_name is None and rank is None:
373+
continue
374+
375+
rank_value = rank.value if hasattr(rank, "value") else rank
376+
serialized.append(
377+
{
378+
"id": parent_id,
379+
"name": parent_name,
380+
"rank": str(rank_value) if rank_value is not None else None,
381+
}
382+
)
383+
return serialized
384+
385+
if category_taxon_ids:
386+
taxa = Taxon.objects.filter(id__in=category_taxon_ids).values(
387+
"id", "name", "rank", "parent_id", "parents_json"
388+
)
389+
for taxon in taxa:
390+
taxon_id = int(taxon["id"])
391+
categories_by_id[taxon_id] = {
392+
"id": taxon_id,
393+
"name": taxon.get("name") or categories_by_id.get(taxon_id, {}).get("name", ""),
394+
"rank": taxon.get("rank"),
395+
"parent_id": taxon.get("parent_id"),
396+
"parents": _serialize_parents_json(taxon.get("parents_json")),
397+
}
398+
347399
base = getattr(settings, "EXTERNAL_BASE_URL", "") or ""
348400
info_url = ""
349401
if base.strip():

ami/main/api/serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ class TaxaListTaxonSerializer(TaxonNoParentNestedSerializer):
702702

703703
class CaptureTaxonSerializer(DefaultSerializer):
704704
parent = TaxonNoParentNestedSerializer(read_only=True)
705-
parents = TaxonParentSerializer(many=True, read_only=True)
705+
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
706706

707707
def get_permissions(self, instance, instance_data):
708708
instance_data["user_permissions"] = []

0 commit comments

Comments
 (0)