Skip to content

Commit bedf7bd

Browse files
Create ai_trainer.py
1 parent 1ab8658 commit bedf7bd

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

src/ai/ai_training/ai_trainer.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import logging
2+
import tensorflow as tf
3+
import numpy as np
4+
from sklearn.model_selection import train_test_split
5+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
6+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
7+
8+
logging.basicConfig(level=logging.ERROR)
9+
10+
def train_model(training_data, model_path, config):
11+
logging.info("Starting AI model training")
12+
if not training_data:
13+
logging.error("Training data is empty.")
14+
return
15+
learning_rate = config['ai']['learning_rate']
16+
# Load data and preprocess
17+
X, y = preprocess_data(training_data)
18+
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
19+
20+
# Define model architecture
21+
model = create_model(learning_rate, X_train.shape[1], config)
22+
23+
# Train the model
24+
model.fit(X_train, y_train, epochs=100, validation_data=(X_val, y_val))
25+
26+
# Evaluate the model
27+
y_pred = model.predict(X_val)
28+
y_pred = (y_pred > 0.5).astype(int)
29+
accuracy = accuracy_score(y_val, y_pred)
30+
precision = precision_score(y_val, y_pred, average='weighted', zero_division=0)
31+
recall = recall_score(y_val, y_pred, average='weighted', zero_division=0)
32+
f1 = f1_score(y_val, y_pred, average='weighted', zero_division=0)
33+
logging.info(f"Model Evaluation: Accuracy={accuracy:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1-Score={f1:.4f}")
34+
35+
# Save the model
36+
model.save(model_path)
37+
logging.info("AI model training completed")
38+
39+
def preprocess_data(training_data):
40+
logging.info("Preprocessing training data")
41+
# Example data: [target_ip, target_port, exploit_type, outcome]
42+
X = np.array([[item[0], item[1], item[2]] for item in training_data])
43+
y = np.array([item[3] for item in training_data])
44+
45+
# One-hot encode exploit_type
46+
encoder = OneHotEncoder(handle_unknown='ignore')
47+
X_encoded = encoder.fit_transform(X[:, [2]]).toarray()
48+
49+
# Scale numerical features
50+
scaler = StandardScaler()
51+
X_scaled = scaler.fit_transform(X[:, :2])
52+
53+
# Combine encoded and scaled features
54+
X_processed = np.concatenate((X_scaled, X_encoded), axis=1)
55+
return X_processed, y
56+
57+
def create_model(learning_rate, input_shape, config):
58+
logging.info("Creating AI model")
59+
model = tf.keras.models.Sequential([
60+
tf.keras.layers.Dense(config['ai']['dense_layer_1'], activation='relu', input_shape=(input_shape,)),
61+
tf.keras.layers.Dense(config['ai']['dense_layer_2'], activation='relu'),
62+
tf.keras.layers.Dense(1, activation='sigmoid')
63+
])
64+
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
65+
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
66+
return model
67+
68+
def integrate_with_new_components(new_component_data, config):
69+
logging.info("Integrating with new components")
70+
# Placeholder for integration logic with new components
71+
integrated_data = {
72+
"new_component_training_data": new_component_data.get("training_data", []),
73+
"new_component_model_config": new_component_data.get("model_config", {})
74+
}
75+
return integrated_data
76+
77+
def ensure_compatibility(existing_data, new_component_data, config):
78+
logging.info("Ensuring compatibility with existing AI training logic")
79+
# Placeholder for compatibility logic
80+
compatible_data = {
81+
"existing_training_data": existing_data.get("training_data", []),
82+
"existing_model_config": existing_data.get("model_config", {}),
83+
"new_component_training_data": new_component_data.get("training_data", []),
84+
"new_component_model_config": new_component_data.get("model_config", {})
85+
}
86+
return compatible_data

0 commit comments

Comments
 (0)