forked from henrylin99/quantitative_analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_fixed_ml_training.py
More file actions
132 lines (107 loc) · 4.61 KB
/
test_fixed_ml_training.py
File metadata and controls
132 lines (107 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试修复后的机器学习训练功能
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app import create_app
from app.extensions import db
from app.models import MLModelDefinition
from app.services.ml_models import MLModelManager
import requests
import json
def test_fixed_ml_training():
"""测试修复后的机器学习训练功能"""
print("🧪 测试修复后的机器学习训练功能")
print("=" * 60)
try:
# 1. 测试通过API训练模型
print("1️⃣ 测试API训练功能...")
# 准备训练请求
train_data = {
'model_id': 'simple_demo_model',
'start_date': '2023-01-01',
'end_date': '2023-12-31'
}
# 发送训练请求
response = requests.post(
'http://localhost:5001/api/ml-factor/models/train',
json=train_data,
headers={'Content-Type': 'application/json'}
)
print(f" API响应状态: {response.status_code}")
if response.status_code == 200:
result = response.json()
print(" ✅ API训练成功!")
print(f" 📊 训练指标:")
if 'metrics' in result:
for key, value in result['metrics'].items():
if isinstance(value, (int, float)):
print(f" {key}: {value:.4f}")
elif isinstance(value, dict):
print(f" {key}: {len(value)} 项")
else:
print(f" {key}: {value}")
else:
print(f" ❌ API训练失败: {response.text}")
print()
# 2. 测试通过API预测功能
print("2️⃣ 测试API预测功能...")
predict_data = {
'model_id': 'simple_demo_model',
'trade_date': '2025-05-23'
}
response = requests.post(
'http://localhost:5001/api/ml-factor/models/predict',
json=predict_data,
headers={'Content-Type': 'application/json'}
)
print(f" API响应状态: {response.status_code}")
if response.status_code == 200:
result = response.json()
print(" ✅ API预测成功!")
if 'predictions' in result:
predictions = result['predictions']
print(f" 📊 预测结果: {len(predictions)} 只股票")
# 显示前5名
if predictions:
sorted_predictions = sorted(predictions, key=lambda x: x['predicted_return'], reverse=True)
print(" 🏆 预测收益率前5名:")
for i, pred in enumerate(sorted_predictions[:5]):
print(f" {i+1}. {pred['ts_code']}: {pred['predicted_return']:+.4f}")
else:
print(f" ❌ API预测失败: {response.text}")
print()
# 3. 测试直接调用服务
print("3️⃣ 测试直接调用服务...")
app = create_app()
with app.app_context():
ml_manager = MLModelManager()
# 测试训练
print(" 🚀 测试直接训练...")
result = ml_manager.train_model('simple_demo_model', '2023-01-01', '2023-12-31')
if result['success']:
print(" ✅ 直接训练成功!")
print(f" 📊 R²分数: {result['metrics']['test_r2']:.4f}")
else:
print(f" ❌ 直接训练失败: {result['error']}")
# 测试预测
print(" 🔮 测试直接预测...")
predictions = ml_manager.predict('simple_demo_model', '2025-05-23')
if not predictions.empty:
print(f" ✅ 直接预测成功!预测了 {len(predictions)} 只股票")
top_5 = predictions.nlargest(5, 'predicted_return')
print(" 🏆 预测收益率前5名:")
for i, (_, row) in enumerate(top_5.iterrows()):
print(f" {i+1}. {row['ts_code']}: {row['predicted_return']:+.4f}")
else:
print(" ❌ 直接预测失败")
print("\n🎉 测试完成!")
except Exception as e:
print(f"❌ 测试过程中出现错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_fixed_ml_training()