Skip to content

Commit 11de419

Browse files
committed
TorchSOM refactoring:
- Introduced core components including `SOM` and `BaseSOM` classes for SOM functionality. - Implemented configuration management using Pydantic with `SOMConfig` for parameter validation. - Added utility functions for distance calculations, neighborhood functions, and weight initialization. - Created visualization tools for SOM outputs, including a `SOMVisualizer` class. - Established versioning with `__version__` in `version.py`. - Organized code into modular files for better maintainability and clarity. - Included necessary utility functions for grid management and decay functions. - Set up initial structure for hierarchical and growing SOM variants. - Ensured compatibility with PyTorch for tensor operations and GPU support. - Added comprehensive inline documentation and type hints for improved code readability.
1 parent b08c357 commit 11de419

25 files changed

+2718
-0
lines changed

torchsom/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from .core import SOM, BaseSOM
2+
from .utils.decay import DECAY_FUNCTIONS
3+
from .utils.distances import DISTANCE_FUNCTIONS
4+
from .utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
5+
from .version import __version__
6+
from .visualization import SOMVisualizer, VisualizationConfig
7+
8+
# Define what should be imported when using 'from torchsom import *'
9+
__all__ = [
10+
"SOM",
11+
"BaseSOM",
12+
"DISTANCE_FUNCTIONS",
13+
"DECAY_FUNCTIONS",
14+
"NEIGHBORHOOD_FUNCTIONS",
15+
"SOMVisualizer",
16+
"VisualizationConfig",
17+
]

torchsom/configs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .som_config import SOMConfig
2+
3+
__all__ = ["SOMConfig"]

torchsom/configs/som_config.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Literal, Optional, Union
2+
3+
import torch
4+
from pydantic import BaseModel, Field, validator
5+
6+
7+
class SOMConfig(BaseModel):
8+
"""Configuration for SOM parameters using pydantic for validation."""
9+
10+
# Map structure parameters
11+
x: int = Field(..., description="Number of rows in the map", gt=0)
12+
y: int = Field(..., description="Number of columns in the map", gt=0)
13+
topology: Literal["rectangular", "hexagonal"] = Field(
14+
"rectangular", description="Grid topology"
15+
)
16+
17+
# Training parameters
18+
epochs: int = Field(10, description="Number of training epochs", ge=1)
19+
batch_size: int = Field(5, description="Batch size for training", ge=1)
20+
learning_rate: float = Field(0.5, description="Initial learning rate", gt=0)
21+
sigma: float = Field(1.0, description="Initial neighborhood radius", gt=0)
22+
23+
# Function choices
24+
neighborhood_function: Literal["gaussian", "mexican_hat", "bubble", "triangle"] = (
25+
Field(
26+
"gaussian",
27+
description="Function to determine neuron neighborhood influence",
28+
)
29+
)
30+
distance_function: Literal["euclidean", "cosine", "manhattan", "chebyshev"] = Field(
31+
"euclidean", description="Function to compute distances"
32+
)
33+
lr_decay_function: Literal[
34+
"lr_inverse_decay_to_zero", "lr_linear_decay_to_zero", "asymptotic_decay"
35+
] = Field("asymptotic_decay", description="Learning rate decay function")
36+
sigma_decay_function: Literal[
37+
"sig_inverse_decay_to_one", "sig_linear_decay_to_one", "asymptotic_decay"
38+
] = Field("asymptotic_decay", description="Sigma decay function")
39+
initialization_mode: Literal["random", "pca"] = Field(
40+
"random", description="Weight initialization method"
41+
)
42+
43+
# Other parameters
44+
neighborhood_order: int = Field(
45+
1, description="Neighborhood order for distance calculations", ge=1
46+
)
47+
device: str = Field(
48+
"cuda" if torch.cuda.is_available() else "cpu",
49+
description="Device for tensor computations",
50+
)
51+
random_seed: int = Field(42, description="Random seed for reproducibility")
52+
53+
54+
# SOM:
55+
56+
# x: 100 # Number of rows: 100
57+
# y: 75 # Number of columns: 75
58+
# topology: "rectangular" # Grid topology: "rectangular" or "hexagonal"
59+
# epochs: 50 # Number of epochs to train the model: 50
60+
# batch_size: 256 # Number of samples to train the model: 64
61+
# learning_rate: 0.95 # Initial learning rate: 0.95
62+
# sigma: 5.0 # Initial spread of the neighborhood function: 5.0
63+
# neighborhood_function: "gaussian" # Neighborhood function: "gaussian", "mexican_hat", "bubble", "triangle"
64+
# distance_function: "euclidean" # Function to calculate distance between data and neurons: "euclidean" "cosine" "manhattan" "chebyshev" "weighted_euclidean" (need to provide weights_proportion)
65+
# lr_decay_function: "asymptotic_decay" # Learning rate scheduler: "lr_inverse_decay_to_zero", "lr_linear_decay_to_zero", "asymptotic_decay"
66+
# sigma_decay_function: "asymptotic_decay" # Sigma scheduler: "sig_inverse_decay_to_one" "sig_linear_decay_to_one" "asymptotic_decay"
67+
# initialization_mode: "pca" # Weights initialization method: "random" or "pca"
68+
# neighborhood_order: 3 # Indicate which neighbors should be considered in SOM distance map and JITL buffer: 1
69+
# ! device: "cuda" # Device for tensor computations: "cuda" or "cpu"
70+
# ! random_seed: 42 # Random seed for reproducibility

torchsom/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .base_som import BaseSOM
2+
from .som import SOM
3+
4+
__all__ = ["SOM", "BaseSOM"]

torchsom/core/base_som.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict, List, Optional, Tuple, Union
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
8+
class BaseSOM(nn.Module, ABC):
9+
"""Abstract base class for all SOM variants."""
10+
11+
@abstractmethod
12+
def fit(
13+
self,
14+
data: torch.Tensor,
15+
) -> Tuple[List[float], List[float]]:
16+
"""Train the SOM on the given data.
17+
18+
Args:
19+
data (torch.Tensor): Input data tensor [batch_size, num_features]
20+
21+
Returns:
22+
Tuple[List[float], List[float]]: Quantization and topographic errors [epoch]
23+
"""
24+
pass
25+
26+
@abstractmethod
27+
def identify_bmus(
28+
self,
29+
data: torch.Tensor,
30+
) -> torch.Tensor:
31+
"""Find best matching units for input data.
32+
33+
Args:
34+
data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features]
35+
36+
Returns:
37+
torch.Tensor: For single sample: Tensor of shape [2] with [row, col].
38+
For batch: Tensor of shape [batch_size, 2] with [row, col] pairs
39+
"""
40+
pass
41+
42+
@abstractmethod
43+
def quantization_error(
44+
self,
45+
data: torch.Tensor,
46+
) -> float:
47+
"""Calculate quantization error.
48+
49+
Args:
50+
data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features]
51+
52+
Returns:
53+
float: Average quantization error value
54+
"""
55+
pass
56+
57+
@abstractmethod
58+
def topographic_error(
59+
self,
60+
data: torch.Tensor,
61+
) -> float:
62+
"""Calculate topographic error.
63+
64+
Args:
65+
data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features]
66+
67+
Returns:
68+
float: Topographic error ratio
69+
"""
70+
pass
71+
72+
@abstractmethod
73+
def initialize_weights(
74+
self,
75+
data: torch.Tensor,
76+
mode: str = None,
77+
) -> None:
78+
"""Initialize the SOM weights.
79+
80+
Args:
81+
data (torch.Tensor): Input data tensor [batch_size, num_features]
82+
mode (str, optional): Weight initialization method. Defaults to None.
83+
"""
84+
pass
85+
86+
@abstractmethod
87+
def build_hit_map(
88+
self,
89+
data: torch.Tensor,
90+
) -> torch.Tensor:
91+
"""Build a hit map showing neuron activation frequencies.
92+
93+
Args:
94+
data (torch.Tensor): Input data tensor [batch_size, num_features] or [num_features]
95+
96+
Returns:
97+
torch.Tensor: Hit map [row_neurons, col_neurons]
98+
"""
99+
pass
100+
101+
@abstractmethod
102+
def build_distance_map(
103+
self,
104+
scaling: str = "sum",
105+
) -> torch.Tensor:
106+
"""Build a distance map (U-matrix) showing neuron similarities.
107+
108+
Args:
109+
scaling (str, optional): Scaling method for distances. Defaults to "sum".
110+
111+
Returns:
112+
torch.Tensor: Distance map [row_neurons, col_neurons]
113+
"""
114+
pass

torchsom/core/growing/__init__.py

Whitespace-only changes.

torchsom/core/growing/components.py

Whitespace-only changes.

torchsom/core/growing/growing_som.py

Whitespace-only changes.

torchsom/core/hierarchical/__init__.py

Whitespace-only changes.

torchsom/core/hierarchical/components.py

Whitespace-only changes.

torchsom/core/hierarchical/hierarchical_som.py

Whitespace-only changes.

0 commit comments

Comments
 (0)