1111
1212from ami .exports .base import BaseExporter
1313from ami .exports .utils import get_data_in_batches
14- from ami .main .models import Occurrence , SourceImage , get_media_url
14+ from ami .main .models import Occurrence , SourceImage , Taxon , get_media_url
1515from ami .ml .schemas import BoundingBox
1616
1717logger = logging .getLogger (__name__ )
@@ -236,6 +236,7 @@ class OccurrenceCocoTabularSerializer(OccurrenceTabularSerializer):
236236 capture_path = serializers .SerializerMethodField ()
237237 capture_width = serializers .SerializerMethodField ()
238238 capture_height = serializers .SerializerMethodField ()
239+ best_machine_prediction_taxon_id = serializers .IntegerField (allow_null = True , default = None )
239240
240241 class Meta (OccurrenceTabularSerializer .Meta ):
241242 fields = OccurrenceTabularSerializer .Meta .fields + [
@@ -244,6 +245,7 @@ class Meta(OccurrenceTabularSerializer.Meta):
244245 "capture_path" ,
245246 "capture_width" ,
246247 "capture_height" ,
248+ "best_machine_prediction_taxon_id" ,
247249 ]
248250
249251 def get_source_image_id (self , obj ):
@@ -281,6 +283,7 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
281283 categories_by_id : dict [int , dict ] = {}
282284 images_by_id : dict [int , dict ] = {}
283285 annotations : list [dict ] = []
286+ category_taxon_ids : set [int ] = set ()
284287
285288 for row in rows :
286289 determination_id = row .get ("determination_id" )
@@ -305,6 +308,7 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
305308 "id" : int (determination_id ),
306309 "name" : det_name ,
307310 }
311+ category_taxon_ids .add (int (determination_id ))
308312 else :
309313 assert (
310314 categories_by_id [int (determination_id )]["name" ] == det_name
@@ -335,15 +339,63 @@ def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
335339 "iscrowd" : 0 , # TODO: Could we use this field to indiate crowd of insects?
336340 "determination_score" : row .get ("determination_score" ),
337341 "verification_status" : row .get ("verification_status" ),
338- "best_machine_prediction_name " : row .get ("best_machine_prediction_name " ),
342+ "determination_matches_machine_prediction " : row .get ("determination_matches_machine_prediction " ),
339343 "best_machine_prediction_algorithm" : row .get ("best_machine_prediction_algorithm" ),
340344 "best_machine_prediction_score" : row .get ("best_machine_prediction_score" ),
341- "determination_matches_machine_prediction " : row .get ("determination_matches_machine_prediction " ),
345+ "best_machine_prediction_taxon_id " : row .get ("best_machine_prediction_taxon_id " ),
342346 "best_detection_width" : row .get ("best_detection_width" ),
343347 "best_detection_height" : row .get ("best_detection_height" ),
344348 }
349+ prediction_taxon_id = row .get ("best_machine_prediction_taxon_id" )
350+ if prediction_taxon_id is not None :
351+ try :
352+ category_taxon_ids .add (int (prediction_taxon_id ))
353+ except (TypeError , ValueError ):
354+ logger .warning (f"Invalid best_machine_prediction_taxon_id for row: { row } " )
345355 annotations .append (ann )
346356
def _serialize_parents_json(parents_json: object) -> list[dict]:
    """Convert a ``parents_json`` value into a list of plain dictionaries.

    Each usable entry becomes ``{"id": ..., "name": ..., "rank": ...}``.
    Entries may be dicts (read by key) or arbitrary objects (read by
    attribute).

    Args:
        parents_json: The raw value to convert. Anything that is not a
            ``list`` (including ``None``) is treated as "no parents".

    Returns:
        A new list of dicts with the keys ``id``, ``name`` and ``rank``.
        ``rank`` is always a string or ``None``; ``id`` and ``name`` are
        passed through unchanged. Entries for which all three values are
        ``None`` are dropped.
    """
    if not isinstance(parents_json, list):
        return []

    serialized = []
    for parent in parents_json:
        if isinstance(parent, dict):
            parent_id = parent.get("id")
            parent_name = parent.get("name")
            rank = parent.get("rank")
        else:
            # SchemaField(list[TaxonParent]) may return Pydantic objects rather than dicts.
            parent_id = getattr(parent, "id", None)
            parent_name = getattr(parent, "name", None)
            rank = getattr(parent, "rank", None)

        # An entry that yields nothing at all (e.g. ``None`` or ``{}``) carries
        # no information, so it is left out of the output.
        if parent_id is None and parent_name is None and rank is None:
            continue

        # Ranks that expose ``.value`` (enum-style) are unwrapped so the
        # output holds a plain string instead of the wrapper object.
        rank_value = rank.value if hasattr(rank, "value") else rank
        serialized.append(
            {
                "id": parent_id,
                "name": parent_name,
                "rank": str(rank_value) if rank_value is not None else None,
            }
        )
    return serialized
384+
385+ if category_taxon_ids :
386+ taxa = Taxon .objects .filter (id__in = category_taxon_ids ).values (
387+ "id" , "name" , "rank" , "parent_id" , "parents_json"
388+ )
389+ for taxon in taxa :
390+ taxon_id = int (taxon ["id" ])
391+ categories_by_id [taxon_id ] = {
392+ "id" : taxon_id ,
393+ "name" : taxon .get ("name" ) or categories_by_id .get (taxon_id , {}).get ("name" , "" ),
394+ "rank" : taxon .get ("rank" ),
395+ "parent_id" : taxon .get ("parent_id" ),
396+ "parents" : _serialize_parents_json (taxon .get ("parents_json" )),
397+ }
398+
347399 base = getattr (settings , "EXTERNAL_BASE_URL" , "" ) or ""
348400 info_url = ""
349401 if base .strip ():
0 commit comments