|
| 1 | +import logging |
| 2 | +import tensorflow as tf |
| 3 | +import numpy as np |
| 4 | +from sklearn.model_selection import train_test_split |
| 5 | +from sklearn.preprocessing import StandardScaler, OneHotEncoder |
| 6 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
| 7 | + |
| 8 | +logging.basicConfig(level=logging.ERROR) |
| 9 | + |
| 10 | +def train_model(training_data, model_path, config): |
| 11 | + logging.info("Starting AI model training") |
| 12 | + if not training_data: |
| 13 | + logging.error("Training data is empty.") |
| 14 | + return |
| 15 | + learning_rate = config['ai']['learning_rate'] |
| 16 | + # Load data and preprocess |
| 17 | + X, y = preprocess_data(training_data) |
| 18 | + X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42) |
| 19 | + |
| 20 | + # Define model architecture |
| 21 | + model = create_model(learning_rate, X_train.shape[1], config) |
| 22 | + |
| 23 | + # Train the model |
| 24 | + model.fit(X_train, y_train, epochs=100, validation_data=(X_val, y_val)) |
| 25 | + |
| 26 | + # Evaluate the model |
| 27 | + y_pred = model.predict(X_val) |
| 28 | + y_pred = (y_pred > 0.5).astype(int) |
| 29 | + accuracy = accuracy_score(y_val, y_pred) |
| 30 | + precision = precision_score(y_val, y_pred, average='weighted', zero_division=0) |
| 31 | + recall = recall_score(y_val, y_pred, average='weighted', zero_division=0) |
| 32 | + f1 = f1_score(y_val, y_pred, average='weighted', zero_division=0) |
| 33 | + logging.info(f"Model Evaluation: Accuracy={accuracy:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1-Score={f1:.4f}") |
| 34 | + |
| 35 | + # Save the model |
| 36 | + model.save(model_path) |
| 37 | + logging.info("AI model training completed") |
| 38 | + |
| 39 | +def preprocess_data(training_data): |
| 40 | + logging.info("Preprocessing training data") |
| 41 | + # Example data: [target_ip, target_port, exploit_type, outcome] |
| 42 | + X = np.array([[item[0], item[1], item[2]] for item in training_data]) |
| 43 | + y = np.array([item[3] for item in training_data]) |
| 44 | + |
| 45 | + # One-hot encode exploit_type |
| 46 | + encoder = OneHotEncoder(handle_unknown='ignore') |
| 47 | + X_encoded = encoder.fit_transform(X[:, [2]]).toarray() |
| 48 | + |
| 49 | + # Scale numerical features |
| 50 | + scaler = StandardScaler() |
| 51 | + X_scaled = scaler.fit_transform(X[:, :2]) |
| 52 | + |
| 53 | + # Combine encoded and scaled features |
| 54 | + X_processed = np.concatenate((X_scaled, X_encoded), axis=1) |
| 55 | + return X_processed, y |
| 56 | + |
| 57 | +def create_model(learning_rate, input_shape, config): |
| 58 | + logging.info("Creating AI model") |
| 59 | + model = tf.keras.models.Sequential([ |
| 60 | + tf.keras.layers.Dense(config['ai']['dense_layer_1'], activation='relu', input_shape=(input_shape,)), |
| 61 | + tf.keras.layers.Dense(config['ai']['dense_layer_2'], activation='relu'), |
| 62 | + tf.keras.layers.Dense(1, activation='sigmoid') |
| 63 | + ]) |
| 64 | + optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) |
| 65 | + model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy']) |
| 66 | + return model |
| 67 | + |
| 68 | +def integrate_with_new_components(new_component_data, config): |
| 69 | + logging.info("Integrating with new components") |
| 70 | + # Placeholder for integration logic with new components |
| 71 | + integrated_data = { |
| 72 | + "new_component_training_data": new_component_data.get("training_data", []), |
| 73 | + "new_component_model_config": new_component_data.get("model_config", {}) |
| 74 | + } |
| 75 | + return integrated_data |
| 76 | + |
| 77 | +def ensure_compatibility(existing_data, new_component_data, config): |
| 78 | + logging.info("Ensuring compatibility with existing AI training logic") |
| 79 | + # Placeholder for compatibility logic |
| 80 | + compatible_data = { |
| 81 | + "existing_training_data": existing_data.get("training_data", []), |
| 82 | + "existing_model_config": existing_data.get("model_config", {}), |
| 83 | + "new_component_training_data": new_component_data.get("training_data", []), |
| 84 | + "new_component_model_config": new_component_data.get("model_config", {}) |
| 85 | + } |
| 86 | + return compatible_data |
0 commit comments