forked from henrylin99/quantitative_analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_risk_management.py
More file actions
445 lines (366 loc) · 14.4 KB
/
test_risk_management.py
File metadata and controls
445 lines (366 loc) · 14.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
#!/usr/bin/env python3
"""
风险管理模块测试脚本
测试实时风险管理的各项功能
"""
import requests
import json
import time
from datetime import datetime
# 测试配置
BASE_URL = "http://127.0.0.1:5001"
PORTFOLIO_ID = "demo_portfolio"
def test_api_endpoint(endpoint, method='GET', data=None, description=""):
"""测试API接口"""
url = f"{BASE_URL}{endpoint}"
try:
if method == 'GET':
response = requests.get(url, timeout=10)
elif method == 'POST':
response = requests.post(url, json=data, timeout=10)
elif method == 'PUT':
response = requests.put(url, json=data, timeout=10)
elif method == 'DELETE':
response = requests.delete(url, timeout=10)
result = response.json()
if response.status_code == 200 and result.get('success'):
print(f"✅ {description}: 成功")
return result
else:
print(f"❌ {description}: 失败 - {result.get('message', '未知错误')}")
return None
except requests.exceptions.RequestException as e:
print(f"❌ {description}: 网络错误 - {str(e)}")
return None
except json.JSONDecodeError as e:
print(f"❌ {description}: JSON解析错误 - {str(e)}")
return None
def test_create_sample_portfolio():
"""创建示例投资组合"""
print("\n=== 创建示例投资组合 ===")
# 示例持仓数据
positions = [
{
"portfolio_id": PORTFOLIO_ID,
"ts_code": "000001.SZ",
"position_size": 1000,
"avg_cost": 12.50,
"sector": "银行"
},
{
"portfolio_id": PORTFOLIO_ID,
"ts_code": "000002.SZ",
"position_size": 500,
"avg_cost": 25.80,
"sector": "房地产"
},
{
"portfolio_id": PORTFOLIO_ID,
"ts_code": "600000.SH",
"position_size": 800,
"avg_cost": 8.90,
"sector": "银行"
},
{
"portfolio_id": PORTFOLIO_ID,
"ts_code": "600036.SH",
"position_size": 300,
"avg_cost": 35.20,
"sector": "银行"
},
{
"portfolio_id": PORTFOLIO_ID,
"ts_code": "000858.SZ",
"position_size": 600,
"avg_cost": 18.60,
"sector": "食品饮料"
}
]
success_count = 0
for position in positions:
result = test_api_endpoint(
"/api/realtime-analysis/risk/portfolio",
method='POST',
data=position,
description=f"创建持仓 {position['ts_code']}"
)
if result:
success_count += 1
print(f"成功创建 {success_count}/{len(positions)} 个持仓")
return success_count > 0
def test_portfolio_positions():
"""测试获取投资组合持仓"""
print("\n=== 测试投资组合持仓 ===")
result = test_api_endpoint(
f"/api/realtime-analysis/risk/portfolio/{PORTFOLIO_ID}/positions",
description="获取投资组合持仓"
)
if result and result.get('data'):
positions = result['data']['positions']
print(f"📊 持仓数量: {len(positions)}")
for pos in positions[:3]: # 显示前3个持仓
print(f" {pos['ts_code']}: {pos['position_size']}股, 成本价¥{pos['avg_cost']}")
return result is not None
def test_portfolio_metrics():
"""测试投资组合指标"""
print("\n=== 测试投资组合指标 ===")
result = test_api_endpoint(
f"/api/realtime-analysis/risk/portfolio/{PORTFOLIO_ID}/metrics",
description="获取投资组合指标"
)
if result and result.get('data'):
metrics = result['data']
print(f"📈 总市值: ¥{metrics.get('total_market_value', 0):,.2f}")
print(f"📊 总盈亏: ¥{metrics.get('total_unrealized_pnl', 0):,.2f}")
print(f"📋 持仓数量: {metrics.get('total_positions', 0)}")
# 显示行业分布
sector_dist = metrics.get('sector_distribution', {})
if sector_dist:
print("🏭 行业分布:")
for sector, weight in sector_dist.items():
print(f" {sector}: {weight:.1f}%")
return result is not None
def test_position_monitor():
"""测试持仓风险监控"""
print("\n=== 测试持仓风险监控 ===")
result = test_api_endpoint(
f"/api/realtime-analysis/risk/position-monitor?portfolio_id={PORTFOLIO_ID}",
description="持仓风险监控"
)
if result and result.get('data'):
data = result['data']
risk_summary = data.get('risk_summary', {})
print(f"⚠️ 整体风险等级: {risk_summary.get('overall_risk_level', '未知')}")
print(f"🔴 高风险持仓: {risk_summary.get('high_risk_positions', 0)}")
print(f"🟡 中风险持仓: {risk_summary.get('medium_risk_positions', 0)}")
print(f"📊 风险评分: {risk_summary.get('risk_score', 0)}")
return result is not None
def test_portfolio_risk_calculation():
"""测试投资组合风险计算"""
print("\n=== 测试投资组合风险计算 ===")
data = {
"portfolio_id": PORTFOLIO_ID,
"period_days": 60 # 使用较短的周期进行测试
}
result = test_api_endpoint(
"/api/realtime-analysis/risk/portfolio-risk",
method='POST',
data=data,
description="投资组合风险计算"
)
if result and result.get('data'):
risk_data = result['data']
risk_metrics = risk_data.get('risk_metrics', {})
var_metrics = risk_data.get('var_metrics', {})
print(f"📊 年化收益率: {risk_metrics.get('annual_return', 0):.4f}")
print(f"📈 年化波动率: {risk_metrics.get('annual_volatility', 0):.4f}")
print(f"📉 最大回撤: {risk_metrics.get('max_drawdown', 0):.4f}")
print(f"🎯 夏普比率: {risk_metrics.get('sharpe_ratio', 0):.4f}")
if var_metrics:
print(f"⚠️ VaR(95%): {var_metrics.get('var_95', 0):.4f}")
print(f"⚠️ VaR(99%): {var_metrics.get('var_99', 0):.4f}")
return result is not None
def test_stop_loss_take_profit():
"""测试止损止盈管理"""
print("\n=== 测试止损止盈管理 ===")
data = {
"portfolio_id": PORTFOLIO_ID,
"stop_loss_method": "percentage",
"stop_loss_value": 0.10, # 10%止损
"take_profit_method": "percentage",
"take_profit_value": 0.20 # 20%止盈
}
result = test_api_endpoint(
"/api/realtime-analysis/risk/stop-loss-take-profit",
method='POST',
data=data,
description="止损止盈管理"
)
if result and result.get('data'):
data = result['data']
updated_positions = data.get('updated_positions', [])
triggered_orders = data.get('triggered_orders', [])
print(f"📝 更新持仓数: {len(updated_positions)}")
print(f"⚡ 触发订单数: {len(triggered_orders)}")
if updated_positions:
print("💰 止损止盈设置:")
for pos in updated_positions[:3]: # 显示前3个
print(f" {pos['ts_code']}: 止损¥{pos['stop_loss_price']:.2f}, 止盈¥{pos['take_profit_price']:.2f}")
return result is not None
def test_risk_alerts():
"""测试风险预警"""
print("\n=== 测试风险预警 ===")
# 获取现有预警
result = test_api_endpoint(
f"/api/realtime-analysis/risk/alerts?portfolio_id={PORTFOLIO_ID}",
description="获取风险预警"
)
if result and result.get('data'):
alerts_data = result['data']
alerts_by_level = alerts_data.get('alerts_by_level', {})
total_alerts = sum(len(alerts) for alerts in alerts_by_level.values())
print(f"📢 总预警数: {total_alerts}")
print(f"🔴 高风险预警: {len(alerts_by_level.get('high', []))}")
print(f"🟡 中风险预警: {len(alerts_by_level.get('medium', []))}")
print(f"🟢 低风险预警: {len(alerts_by_level.get('low', []))}")
# 创建测试预警
alert_data = {
"ts_code": "000001.SZ",
"alert_type": "test_alert",
"alert_level": "medium",
"alert_message": "测试预警消息",
"risk_value": 0.15,
"threshold_value": 0.10
}
create_result = test_api_endpoint(
"/api/realtime-analysis/risk/alerts",
method='POST',
data=alert_data,
description="创建测试预警"
)
return result is not None
def test_stress_test():
"""测试压力测试"""
print("\n=== 测试压力测试 ===")
data = {
"portfolio_id": PORTFOLIO_ID
}
result = test_api_endpoint(
"/api/realtime-analysis/risk/stress-test",
method='POST',
data=data,
description="压力测试"
)
if result and result.get('data'):
stress_data = result['data']
scenarios = stress_data.get('scenarios', [])
worst_case = stress_data.get('worst_case', {})
best_case = stress_data.get('best_case', {})
print(f"🧪 测试场景数: {len(scenarios)}")
print(f"📉 最坏情况: {worst_case.get('scenario_name', '未知')} ({worst_case.get('pnl_percentage', 0):.2f}%)")
print(f"📈 最好情况: {best_case.get('scenario_name', '未知')} ({best_case.get('pnl_percentage', 0):.2f}%)")
if scenarios:
print("📊 压力测试结果:")
for scenario in scenarios[:3]: # 显示前3个场景
print(f" {scenario['scenario_name']}: {scenario['pnl_percentage']:.2f}%")
return result is not None
def test_batch_update_prices():
"""测试批量更新价格"""
print("\n=== 测试批量更新价格 ===")
data = {
"portfolio_id": PORTFOLIO_ID
}
result = test_api_endpoint(
"/api/realtime-analysis/risk/batch-update-prices",
method='POST',
data=data,
description="批量更新价格"
)
if result and result.get('data'):
data = result['data']
print(f"📊 总持仓数: {data.get('total_positions', 0)}")
print(f"✅ 更新成功数: {data.get('updated_positions', 0)}")
return result is not None
def test_risk_thresholds():
"""测试风险阈值管理"""
print("\n=== 测试风险阈值管理 ===")
# 获取当前阈值
result = test_api_endpoint(
"/api/realtime-analysis/risk/risk-thresholds",
description="获取风险阈值"
)
if result and result.get('data'):
thresholds = result['data']
print("⚙️ 当前风险阈值:")
for key, value in thresholds.items():
print(f" {key}: {value}")
# 更新阈值
update_data = {
"position_weight": 0.25, # 调整单一持仓权重阈值
"var_limit": 0.06 # 调整VaR限制
}
update_result = test_api_endpoint(
"/api/realtime-analysis/risk/risk-thresholds",
method='PUT',
data=update_data,
description="更新风险阈值"
)
return result is not None
def test_frontend_access():
"""测试前端页面访问"""
print("\n=== 测试前端页面访问 ===")
try:
response = requests.get(f"{BASE_URL}/realtime-analysis/risk-management", timeout=10)
if response.status_code == 200:
print("✅ 风险管理页面访问: 成功")
# 检查页面关键元素
content = response.text
key_elements = [
"实时风险管理",
"持仓管理",
"风险分析",
"预警管理",
"止损止盈",
"压力测试"
]
missing_elements = []
for element in key_elements:
if element not in content:
missing_elements.append(element)
if missing_elements:
print(f"⚠️ 缺少页面元素: {', '.join(missing_elements)}")
else:
print("✅ 页面元素检查: 完整")
return True
else:
print(f"❌ 风险管理页面访问: 失败 (状态码: {response.status_code})")
return False
except requests.exceptions.RequestException as e:
print(f"❌ 风险管理页面访问: 网络错误 - {str(e)}")
return False
def main():
"""主测试函数"""
print("🚀 开始风险管理模块测试")
print(f"📍 测试地址: {BASE_URL}")
print(f"📁 测试组合: {PORTFOLIO_ID}")
print(f"⏰ 测试时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
# 测试项目列表
tests = [
("创建示例投资组合", test_create_sample_portfolio),
("投资组合持仓", test_portfolio_positions),
("投资组合指标", test_portfolio_metrics),
("持仓风险监控", test_position_monitor),
("投资组合风险计算", test_portfolio_risk_calculation),
("止损止盈管理", test_stop_loss_take_profit),
("风险预警", test_risk_alerts),
("压力测试", test_stress_test),
("批量更新价格", test_batch_update_prices),
("风险阈值管理", test_risk_thresholds),
("前端页面访问", test_frontend_access)
]
# 执行测试
passed = 0
total = len(tests)
for test_name, test_func in tests:
print(f"\n{'='*50}")
print(f"🧪 测试: {test_name}")
print('='*50)
try:
if test_func():
passed += 1
time.sleep(1) # 避免请求过于频繁
except Exception as e:
print(f"❌ 测试异常: {str(e)}")
# 测试总结
print(f"\n{'='*50}")
print("📊 测试总结")
print('='*50)
print(f"✅ 通过: {passed}/{total} ({passed/total*100:.1f}%)")
print(f"❌ 失败: {total-passed}/{total}")
if passed == total:
print("🎉 所有测试通过!风险管理模块运行正常。")
else:
print("⚠️ 部分测试失败,请检查相关功能。")
print(f"⏰ 测试完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
if __name__ == "__main__":
main()