@@ -23,6 +23,7 @@ import * as Utils from 'src/lib/Utils';
2323import { SelectedArtifact } from 'src/pages/CompareV2' ;
2424import { LinkedArtifact } from 'src/mlmd/MlmdUtils' ;
2525import * as jspb from 'google-protobuf' ;
26+ import { useState } from 'react' ;
2627import MetricsDropdown from './MetricsDropdown' ;
2728import { MetricsType , RunArtifact } from 'src/lib/v2/CompareUtils' ;
2829
@@ -76,6 +77,38 @@ function newMockLinkedArtifact(id: number, displayName?: string): LinkedArtifact
7677 } as LinkedArtifact ;
7778}
7879
80+ interface ControlledMetricsDropdownProps {
81+ filteredRunArtifacts : RunArtifact [ ] ;
82+ metricsTab : MetricsType ;
83+ selectedArtifacts : SelectedArtifact [ ] ;
84+ namespace ?: string ;
85+ onUpdateSelectedArtifacts ?: ( selectedArtifacts : SelectedArtifact [ ] ) => void ;
86+ }
87+
88+ function ControlledMetricsDropdown ( props : ControlledMetricsDropdownProps ) {
89+ const {
90+ filteredRunArtifacts,
91+ metricsTab,
92+ selectedArtifacts : initialSelectedArtifacts ,
93+ namespace,
94+ onUpdateSelectedArtifacts,
95+ } = props ;
96+ const [ selectedArtifacts , setSelectedArtifacts ] = useState ( initialSelectedArtifacts ) ;
97+
98+ return (
99+ < MetricsDropdown
100+ filteredRunArtifacts = { filteredRunArtifacts }
101+ metricsTab = { metricsTab }
102+ selectedArtifacts = { selectedArtifacts }
103+ updateSelectedArtifacts = { ( nextSelectedArtifacts ) => {
104+ setSelectedArtifacts ( nextSelectedArtifacts ) ;
105+ onUpdateSelectedArtifacts ?.( nextSelectedArtifacts ) ;
106+ } }
107+ namespace = { namespace }
108+ />
109+ ) ;
110+ }
111+
79112testBestPractices ( ) ;
80113describe ( 'MetricsDropdown' , ( ) => {
81114 const updateSelectedArtifactsSpy = vi . fn ( ) ;
@@ -85,6 +118,7 @@ describe('MetricsDropdown', () => {
85118 let scalarMetricsArtifacts : RunArtifact [ ] ;
86119
87120 beforeEach ( ( ) => {
121+ updateSelectedArtifactsSpy . mockReset ( ) ;
88122 emptySelectedArtifacts = [
89123 {
90124 selectedItem : { itemName : '' , subItemName : '' } ,
@@ -157,11 +191,11 @@ describe('MetricsDropdown', () => {
157191 it ( 'Dropdown loaded when content is present' , async ( ) => {
158192 render (
159193 < CommonTestWrapper >
160- < MetricsDropdown
194+ < ControlledMetricsDropdown
161195 filteredRunArtifacts = { scalarMetricsArtifacts }
162196 metricsTab = { MetricsType . CONFUSION_MATRIX }
163197 selectedArtifacts = { emptySelectedArtifacts }
164- updateSelectedArtifacts = { updateSelectedArtifactsSpy }
198+ onUpdateSelectedArtifacts = { updateSelectedArtifactsSpy }
165199 />
166200 </ CommonTestWrapper > ,
167201 ) ;
@@ -251,11 +285,11 @@ describe('MetricsDropdown', () => {
251285
252286 render (
253287 < CommonTestWrapper >
254- < MetricsDropdown
288+ < ControlledMetricsDropdown
255289 filteredRunArtifacts = { scalarMetricsArtifacts }
256290 metricsTab = { MetricsType . HTML }
257291 selectedArtifacts = { emptySelectedArtifacts }
258- updateSelectedArtifacts = { updateSelectedArtifactsSpy }
292+ onUpdateSelectedArtifacts = { updateSelectedArtifactsSpy }
259293 />
260294 </ CommonTestWrapper > ,
261295 ) ;
@@ -296,11 +330,11 @@ describe('MetricsDropdown', () => {
296330
297331 render (
298332 < CommonTestWrapper >
299- < MetricsDropdown
333+ < ControlledMetricsDropdown
300334 filteredRunArtifacts = { scalarMetricsArtifacts }
301335 metricsTab = { MetricsType . MARKDOWN }
302336 selectedArtifacts = { emptySelectedArtifacts }
303- updateSelectedArtifacts = { updateSelectedArtifactsSpy }
337+ onUpdateSelectedArtifacts = { updateSelectedArtifactsSpy }
304338 />
305339 </ CommonTestWrapper > ,
306340 ) ;
@@ -345,11 +379,11 @@ describe('MetricsDropdown', () => {
345379
346380 render (
347381 < CommonTestWrapper >
348- < MetricsDropdown
382+ < ControlledMetricsDropdown
349383 filteredRunArtifacts = { scalarMetricsArtifacts }
350384 metricsTab = { MetricsType . HTML }
351385 selectedArtifacts = { emptySelectedArtifacts }
352- updateSelectedArtifacts = { updateSelectedArtifactsSpy }
386+ onUpdateSelectedArtifacts = { updateSelectedArtifactsSpy }
353387 namespace = 'namespaceInput'
354388 />
355389 </ CommonTestWrapper > ,
@@ -391,11 +425,11 @@ describe('MetricsDropdown', () => {
391425
392426 render (
393427 < CommonTestWrapper >
394- < MetricsDropdown
428+ < ControlledMetricsDropdown
395429 filteredRunArtifacts = { scalarMetricsArtifacts }
396430 metricsTab = { MetricsType . CONFUSION_MATRIX }
397431 selectedArtifacts = { newSelectedArtifacts }
398- updateSelectedArtifacts = { updateSelectedArtifactsSpy }
432+ onUpdateSelectedArtifacts = { updateSelectedArtifactsSpy }
399433 />
400434 </ CommonTestWrapper > ,
401435 ) ;
@@ -404,4 +438,118 @@ describe('MetricsDropdown', () => {
404438 screen . getByText ( 'Choose a first Confusion Matrix artifact' ) ;
405439 screen . getByLabelText ( 'run1 > execution1 > artifact1' ) ;
406440 } ) ;
441+
442+ it ( 'updates the displayed selection when parent-selected artifacts change after mount' , async ( ) => {
443+ const { rerender } = render (
444+ < CommonTestWrapper >
445+ < MetricsDropdown
446+ filteredRunArtifacts = { scalarMetricsArtifacts }
447+ metricsTab = { MetricsType . CONFUSION_MATRIX }
448+ selectedArtifacts = { emptySelectedArtifacts }
449+ updateSelectedArtifacts = { updateSelectedArtifactsSpy }
450+ />
451+ </ CommonTestWrapper > ,
452+ ) ;
453+ await TestUtils . flushPromises ( ) ;
454+
455+ const nextSelectedArtifacts : SelectedArtifact [ ] = [
456+ {
457+ linkedArtifact : firstLinkedArtifact ,
458+ selectedItem : {
459+ itemName : 'run1' ,
460+ subItemName : 'execution1' ,
461+ subItemSecondaryName : 'artifact1' ,
462+ } ,
463+ } ,
464+ emptySelectedArtifacts [ 1 ] ,
465+ ] ;
466+
467+ rerender (
468+ < CommonTestWrapper >
469+ < MetricsDropdown
470+ filteredRunArtifacts = { scalarMetricsArtifacts }
471+ metricsTab = { MetricsType . CONFUSION_MATRIX }
472+ selectedArtifacts = { nextSelectedArtifacts }
473+ updateSelectedArtifacts = { updateSelectedArtifactsSpy }
474+ />
475+ </ CommonTestWrapper > ,
476+ ) ;
477+
478+ expect ( await screen . findByLabelText ( 'run1 > execution1 > artifact1' ) ) . toBeInTheDocument ( ) ;
479+ } ) ;
480+
481+ it ( 're-resolves the selected artifact from the current compare data when labels stay the same' , async ( ) => {
482+ const getHtmlViewerConfigSpy = vi . spyOn ( metricsVisualizations , 'getHtmlViewerConfig' ) ;
483+ getHtmlViewerConfigSpy . mockResolvedValue ( [ ] ) ;
484+
485+ const staleLinkedArtifact = newMockLinkedArtifact ( 10 , 'artifact1' ) ;
486+ const freshLinkedArtifact = newMockLinkedArtifact ( 11 , 'artifact1' ) ;
487+ const selectedArtifactsWithStaleArtifact : SelectedArtifact [ ] = [
488+ {
489+ linkedArtifact : staleLinkedArtifact ,
490+ selectedItem : {
491+ itemName : 'run1' ,
492+ subItemName : 'execution1' ,
493+ subItemSecondaryName : 'artifact1' ,
494+ } ,
495+ } ,
496+ emptySelectedArtifacts [ 1 ] ,
497+ ] ;
498+
499+ const { rerender } = render (
500+ < CommonTestWrapper >
501+ < MetricsDropdown
502+ filteredRunArtifacts = { [
503+ {
504+ run : {
505+ run_id : '1' ,
506+ display_name : 'run1' ,
507+ } ,
508+ executionArtifacts : [
509+ {
510+ execution : newMockExecution ( 1 , 'execution1' ) ,
511+ linkedArtifacts : [ staleLinkedArtifact ] ,
512+ } ,
513+ ] ,
514+ } ,
515+ ] }
516+ metricsTab = { MetricsType . HTML }
517+ selectedArtifacts = { selectedArtifactsWithStaleArtifact }
518+ updateSelectedArtifacts = { updateSelectedArtifactsSpy }
519+ />
520+ </ CommonTestWrapper > ,
521+ ) ;
522+ await TestUtils . flushPromises ( ) ;
523+ await waitFor ( ( ) => {
524+ expect ( getHtmlViewerConfigSpy ) . toHaveBeenLastCalledWith ( [ staleLinkedArtifact ] , undefined ) ;
525+ } ) ;
526+
527+ rerender (
528+ < CommonTestWrapper >
529+ < MetricsDropdown
530+ filteredRunArtifacts = { [
531+ {
532+ run : {
533+ run_id : '2' ,
534+ display_name : 'run1' ,
535+ } ,
536+ executionArtifacts : [
537+ {
538+ execution : newMockExecution ( 2 , 'execution1' ) ,
539+ linkedArtifacts : [ freshLinkedArtifact ] ,
540+ } ,
541+ ] ,
542+ } ,
543+ ] }
544+ metricsTab = { MetricsType . HTML }
545+ selectedArtifacts = { selectedArtifactsWithStaleArtifact }
546+ updateSelectedArtifacts = { updateSelectedArtifactsSpy }
547+ />
548+ </ CommonTestWrapper > ,
549+ ) ;
550+
551+ await waitFor ( ( ) => {
552+ expect ( getHtmlViewerConfigSpy ) . toHaveBeenLastCalledWith ( [ freshLinkedArtifact ] , undefined ) ;
553+ } ) ;
554+ } ) ;
407555} ) ;
0 commit comments