Skip to content

Commit 78aee5c

Browse files
committed
Updated classification map to ensure items and device compatibility.
1 parent 747dc31 commit 78aee5c

File tree

1 file changed

+28
-219
lines changed

1 file changed

+28
-219
lines changed

torchsom/core.py

Lines changed: 28 additions & 219 deletions
Original file line numberDiff line numberDiff line change
@@ -782,44 +782,10 @@ def collect_samples(
782782
if closest_neuron in bmus_idx_map:
783783
sample_indices.extend(bmus_idx_map[closest_neuron])
784784

785-
# Retrieve historical features and outputs from indices, while ensuring device compatibility.
786-
# historical_data_buffer = historical_samples[sample_indices].to(self.device)
787-
# historical_output_buffer = (
788-
# historical_outputs[sample_indices].view(-1, 1).to(self.device)
789-
# )
790785
historical_data_buffer = historical_samples[sample_indices]
791786
historical_output_buffer = historical_outputs[sample_indices].view(-1, 1)
792-
793787
return historical_data_buffer, historical_output_buffer
794788

795-
# def build_hit_map(
796-
# self,
797-
# data: torch.Tensor,
798-
# ) -> torch.Tensor:
799-
# """Returns a matrix where element i,j is the number of times that neuron i,j has been the winner.
800-
801-
# Args:
802-
# data (torch.Tensor): input data tensor [batch_size, num_features]
803-
804-
# Returns:
805-
# torch.Tensor: Matrix indicating the number of times each neuron has been identified as bmu. [row_neurons, col_neurons]
806-
# """
807-
808-
# # Ensure device and batch compatibility
809-
# data = data.to(self.device)
810-
# if data.dim() == 1:
811-
# data = data.unsqueeze(0)
812-
813-
# # Get BMUs for all data points at once [batch_size, 2]
814-
# bmus = self.identify_bmus(data)
815-
816-
# # Build the map by counting the occurrence of each bmu
817-
# hit_map = torch.zeros((self.x, self.y), device=self.device)
818-
# for row, col in bmus:
819-
# hit_map[row, col] += 1
820-
821-
# return hit_map
822-
823789
def build_hit_map(
824790
self,
825791
data: torch.Tensor,
@@ -876,49 +842,6 @@ def build_hit_map(
876842

877843
return hit_map
878844

879-
# def build_bmus_data_map(
880-
# self,
881-
# data: torch.Tensor,
882-
# return_indices: bool = False,
883-
# ) -> Dict[Tuple[int, int], Any]:
884-
# """Create a mapping of winning neurons to their corresponding data points.
885-
886-
# Args:
887-
# data (torch.Tensor): input data tensor [batch_size, num_features] or [num_features]
888-
# return_indices (bool, optional): If True, return indices instead of data points. Defaults to False.
889-
890-
# Returns:
891-
# Dict[Tuple[int, int], Any]: Dictionary mapping bmus to data samples or indices
892-
# """
893-
894-
# # Ensure device and batch compatibility
895-
# data = data.to(self.device)
896-
# if data.dim() == 1:
897-
# data = data.unsqueeze(0)
898-
899-
# # Get BMUs for all data points at once [batch_size, 2]
900-
# bmus = self.identify_bmus(data)
901-
902-
# # Build the map
903-
# bmus_data_map = defaultdict(list)
904-
# for idx, (row, col) in enumerate(bmus):
905-
906-
# # Convert BMU coordinates to integer tuple for dictionary key
907-
# bmu_pos = (int(row.item()), int(col.item()))
908-
909-
# # Add the corresponding element to the dict
910-
# if return_indices:
911-
# bmus_data_map[bmu_pos].append(idx)
912-
# else:
913-
# bmus_data_map[bmu_pos].append(data[idx])
914-
915-
# # If you return data samples, then for each bmu you need to stack them to have a tensor of all data samples instead of multiple tensors
916-
# if not return_indices:
917-
# for bmu in bmus_data_map:
918-
# bmus_data_map[bmu] = torch.stack(bmus_data_map[bmu])
919-
920-
# return bmus_data_map
921-
922845
def build_bmus_data_map(
923846
self,
924847
data: torch.Tensor,
@@ -1010,10 +933,6 @@ def build_rank_map(
1010933
torch.Tensor: Rank map where each neuron's value is its rank (1 = lowest std = best)
1011934
"""
1012935

1013-
# # Ensure device compatibility
1014-
# data = data.to(self.device)
1015-
# target = target.to(self.device)
1016-
1017936
# ! Now bmus_map is by default on the CPU
1018937
bmus_map = self.build_bmus_data_map(data, return_indices=True)
1019938
# neuron_stds = torch.full((self.x, self.y), float("nan"), device=self.device)
@@ -1294,25 +1213,23 @@ def build_classification_map(
12941213
torch.Tensor: Classification map with the most frequent label for each neuron
12951214
"""
12961215

1297-
# # Ensure device compatibility
1298-
# data = data.to(self.device)
1299-
# target = target.to(self.device)
1300-
13011216
bmus_map = self.build_bmus_data_map(data, return_indices=True)
1302-
# classification_map = torch.full(
1303-
# (self.x, self.y), float("nan"), device=self.device
1304-
# )
13051217
classification_map = torch.full((self.x, self.y), float("nan"))
13061218

13071219
# Retrieve neighborhood offsets based on topology for tie-breaking
13081220
neighborhood_offsets = []
13091221
if self.topology == "hexagonal":
13101222
for order in range(1, neighborhood_order + 1):
1311-
offsets = self._get_hexagonal_offsets(order)
1312-
neighborhood_offsets.extend(
1313-
offsets["even"] if (row % 2 == 0) else offsets["odd"]
1314-
for row in range(self.x)
1315-
)
1223+
# offsets = self._get_hexagonal_offsets(order)
1224+
# neighborhood_offsets.extend(
1225+
# offsets["even"] if (row % 2 == 0) else offsets["odd"]
1226+
# for row in range(self.x)
1227+
# )
1228+
for row in range(self.x):
1229+
offsets = self._get_hexagonal_offsets(order)
1230+
neighborhood_offsets.extend(
1231+
offsets["even"] if (row % 2 == 0) else offsets["odd"]
1232+
)
13161233
else:
13171234
for order in range(1, neighborhood_order + 1):
13181235
neighborhood_offsets.extend(self._get_rectangular_offsets(order))
@@ -1326,8 +1243,8 @@ def build_classification_map(
13261243
Find the most common one
13271244
Check if there is a tie with another label
13281245
"""
1329-
neuron_labels = target[sample_indices]
1330-
label_counts = Counter(neuron_labels.cpu().numpy())
1246+
neuron_labels = target[sample_indices].cpu().numpy()
1247+
label_counts = Counter(neuron_labels)
13311248
max_count = max(label_counts.values())
13321249
top_labels = [
13331250
label for label, count in label_counts.items() if count == max_count
@@ -1338,7 +1255,10 @@ def build_classification_map(
13381255
In case of a tie, consider labels from neighboring neurons to break it.
13391256
"""
13401257
if len(top_labels) == 1:
1341-
classification_map[bmu_pos] = top_labels[0]
1258+
# classification_map[bmu_pos] = top_labels[0]
1259+
classification_map[bmu_pos] = torch.tensor(
1260+
top_labels[0], dtype=classification_map.dtype
1261+
) # Convert NumPy value to tensor scalar
13421262
else:
13431263
neighbor_labels = []
13441264
row, col = bmu_pos
@@ -1368,139 +1288,28 @@ def build_classification_map(
13681288
]
13691289
# If there is a tie with neighbor labels, choose randomly between top labels (including neighbors).
13701290
if len(top_neighbor_labels) == 1:
1371-
classification_map[bmu_pos] = top_neighbor_labels[0]
1291+
# classification_map[bmu_pos] = top_neighbor_labels[0]
1292+
classification_map[bmu_pos] = torch.tensor(
1293+
top_neighbor_labels[0], dtype=classification_map.dtype
1294+
)
13721295
else:
13731296
# classification_map[bmu_pos] = torch.tensor(
1374-
# random.choice(top_neighbor_labels), device=self.device
1297+
# random.choice(top_neighbor_labels)
13751298
# )
1299+
# Choose randomly and convert to tensor
1300+
chosen_label = random.choice(top_neighbor_labels)
13761301
classification_map[bmu_pos] = torch.tensor(
1377-
random.choice(top_neighbor_labels)
1302+
chosen_label, dtype=classification_map.dtype
13781303
)
13791304
# If there are no neighbor labels, choose randomly between previous top labels.
13801305
else:
13811306
# classification_map[bmu_pos] = torch.tensor(
1382-
# random.choice(top_labels), device=self.device
1307+
# random.choice(top_labels)
13831308
# )
1309+
# Choose randomly and convert to tensor
1310+
chosen_label = random.choice(top_labels)
13841311
classification_map[bmu_pos] = torch.tensor(
1385-
random.choice(top_labels)
1312+
chosen_label, dtype=classification_map.dtype
13861313
)
13871314

13881315
return classification_map
1389-
1390-
1391-
# def collect_samples(
1392-
# self,
1393-
# query_sample: torch.Tensor,
1394-
# historical_samples: torch.Tensor,
1395-
# historical_outputs: torch.Tensor,
1396-
# min_buffer_threshold: int = 50,
1397-
# bmus_idx_map: Dict[Tuple[int, int], List[int]] = None,
1398-
# ) -> Tuple[torch.Tensor, torch.Tensor]:
1399-
# """Collect historical samples similar to the query sample using SOM projection.
1400-
1401-
# Args:
1402-
# query_sample (torch.Tensor): The query data point [num_features]
1403-
# historical_samples (torch.Tensor): Historical input data [num_samples, num_features]
1404-
# historical_outputs (torch.Tensor): Historical output values [num_samples]
1405-
# min_buffer_threshold (int, optional): Minimum number of samples to collect. Defaults to 50.
1406-
1407-
# Returns:
1408-
# Tuple[torch.Tensor, torch.Tensor]: (historical_data_buffer, historical_output_buffer)
1409-
# """
1410-
1411-
# # Ensure device compatibility
1412-
# query_sample = query_sample.to(self.device)
1413-
# historical_samples = historical_samples.to(self.device)
1414-
# historical_outputs = historical_outputs.to(self.device)
1415-
1416-
# # Initialize collection lists and tracking set
1417-
# historical_data_list = []
1418-
# historical_output_list = []
1419-
# visited_neurons = set()
1420-
1421-
# # Find BMU for the query sample
1422-
# bmu_pos = self.identify_bmus(query_sample)
1423-
# bmu_tuple = (int(bmu_pos[0].item()), int(bmu_pos[1].item()))
1424-
1425-
# # Collect samples (features and outputs) from the query's BMU if any exist
1426-
# if bmu_tuple in bmus_idx_map and len(bmus_idx_map[bmu_tuple]) > 0:
1427-
# for sample_idx in bmus_idx_map[bmu_tuple]:
1428-
# historical_data_list.append(historical_samples[sample_idx])
1429-
# historical_output_list.append(historical_outputs[sample_idx])
1430-
1431-
# # Get neighbor offsets based on topology and neighborhood order
1432-
# neighbor_order = self.neighborhood_order
1433-
# neighbor_offsets = []
1434-
# for order in range(1, neighbor_order + 1):
1435-
# if self.topology == "hexagonal":
1436-
# if bmu_pos[0] % 2 == 0:
1437-
# nei_order_offsets = self._get_hexagonal_offsets(order)["even"]
1438-
# else:
1439-
# nei_order_offsets = self._get_hexagonal_offsets(order)["odd"]
1440-
# else:
1441-
# nei_order_offsets = self._get_rectangular_offsets(order)
1442-
# neighbor_offsets.extend(nei_order_offsets)
1443-
1444-
# """
1445-
# First, explore all neighbors of the current BMU and retrieve historical samples if they exist
1446-
# Only explore close neighbors in terms of distance in the grid, not in terms of distance of the weights.
1447-
# """
1448-
# for dx, dy in neighbor_offsets:
1449-
# neighbor_pos = (int(bmu_pos[0].item() + dx), int(bmu_pos[1].item() + dy))
1450-
# if neighbor_pos in visited_neurons:
1451-
# continue
1452-
# visited_neurons.add(neighbor_pos)
1453-
# # Check if the neighbor is 1) within SOM bounds, 2) activated, and 3) has samples
1454-
# if (
1455-
# 0 <= neighbor_pos[0] < self.x
1456-
# and 0 <= neighbor_pos[1] < self.y
1457-
# and neighbor_pos in bmus_idx_map
1458-
# and len(bmus_idx_map[neighbor_pos]) > 0
1459-
# ):
1460-
# for sample_idx in bmus_idx_map[neighbor_pos]:
1461-
# historical_data_list.append(historical_samples[sample_idx])
1462-
# historical_output_list.append(historical_outputs[sample_idx])
1463-
1464-
# """
1465-
# Secondly, ensure we have enough training samples.
1466-
# This time, explore neighbors that are close in terms of distance in the weights space.
1467-
# """
1468-
# historical_samples_count = len(historical_output_list)
1469-
# if historical_samples_count <= min_buffer_threshold:
1470-
# # Calculate distances from current BMU weights to all other neurons
1471-
# neurons_distance_map = self._calculate_distances_to_neurons(
1472-
# data=self.weights.data[bmu_pos[0], bmu_pos[1]]
1473-
# )
1474-
1475-
# # Build min heap of (distance, neuron_position) for neurons with samples
1476-
# distance_min_heap = []
1477-
# for row in range(self.x):
1478-
# for col in range(self.y):
1479-
# neuron_pos = (row, col)
1480-
# if neuron_pos in visited_neurons:
1481-
# continue
1482-
# if neuron_pos in bmus_idx_map and len(bmus_idx_map[neuron_pos]) > 0:
1483-
# distance = neurons_distance_map[row, col].item()
1484-
# heapq.heappush(distance_min_heap, (distance, neuron_pos))
1485-
1486-
# # Pop from min heap and collect samples until we reach the threshold
1487-
# while (
1488-
# distance_min_heap and historical_samples_count <= min_buffer_threshold
1489-
# ):
1490-
# _, closest_neuron = heapq.heappop(distance_min_heap)
1491-
# visited_neurons.add(closest_neuron)
1492-
# if (
1493-
# closest_neuron in bmus_idx_map
1494-
# and len(bmus_idx_map[closest_neuron]) > 0
1495-
# ):
1496-
# for sample_idx in bmus_idx_map[closest_neuron]:
1497-
# historical_data_list.append(historical_samples[sample_idx])
1498-
# historical_output_list.append(historical_outputs[sample_idx])
1499-
# historical_samples_count += 1
1500-
1501-
# # Concatenate collected historical samples (features and outputs)
1502-
# historical_data_buffer = torch.stack(historical_data_list, dim=0)
1503-
# historical_output_buffer = torch.tensor(
1504-
# historical_output_list, device=self.device
1505-
# ).view(-1, 1)
1506-
# return historical_data_buffer, historical_output_buffer

0 commit comments

Comments
 (0)