@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size
 
 
-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
-
-    if not batch_patch_data:
-        return {}, torch.empty(0)
-
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
-
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
-
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
-
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
-
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
+
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24 * 24)
+
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
+
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
+
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
+
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
         else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
+
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
 
-    return input_dict, labels_tensor
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets
 
 
 class VariableSeqMapWrapper(IterableDataset):
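For orientation, here is a minimal sketch of the new collator in isolation. The helper `_fake_sample`, the patch dimension of 768, and `max_seq_len=64` are illustrative assumptions, not part of the commit; they simply mimic the `(patch_dict, target)` tuples described in the `__call__` docstring.

import torch

# Hypothetical sample: num_patches flattened patches plus coords and validity flags.
def _fake_sample(num_patches, patch_dim=768, label=0):
    return {
        'patches': torch.randn(num_patches, patch_dim),
        'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
        'patch_valid': torch.ones(num_patches, dtype=torch.bool),
    }, label

collator = NaFlexCollator(max_seq_len=64)  # pad/truncate every batch to 64 patches
batch = [_fake_sample(25, label=0), _fake_sample(49, label=1)]
inputs, targets = collator(batch)

print(inputs['patches'].shape)       # torch.Size([2, 64, 768])
print(inputs['patch_valid'].sum(1))  # tensor([25, 49]) -- padded positions stay False
print(inputs['seq_len'], targets)    # 64 tensor([0, 1])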
@@ -161,15 +163,15 @@ def __init__(
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
-
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
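In effect, each candidate sequence length now owns both a transform and a collator. A rough sketch of the resulting mapping (the `seq_lens` values below are illustrative, not from this commit):

seq_lens = (256, 576, 784, 1024)  # illustrative values only

# one collator per target length; each pads/truncates its batches to exactly that length
collate_fns = {seq_len: NaFlexCollator(seq_len) for seq_len in seq_lens}

assert collate_fns[576].max_seq_len == 576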
@@ -417,6 +419,6 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
 
             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)
 
             # If batch_samples is empty after processing 'indices', an empty batch is skipped.
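Since the wrapper now yields batches that are already collated into `(input_dict, targets)` pairs, downstream consumption stays simple. A minimal sketch, assuming `wrapper` is a constructed `VariableSeqMapWrapper` instance and automatic batching is disabled in the `DataLoader`:

from torch.utils.data import DataLoader

# batch_size=None: the IterableDataset already emits whole batches,
# so the DataLoader must not batch again.
loader = DataLoader(wrapper, batch_size=None)

for input_dict, targets in loader:
    patches = input_dict['patches']    # [B, seq_len, patch_dim], zero-padded
    valid = input_dict['patch_valid']  # [B, seq_len] bool mask over real patches
    seq_len = input_dict['seq_len']    # varies from batch to batch per the schedule
    # a NaFlex-style model would use `valid` to mask attention/pooling over the padding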