Skip to content

Commit c38662c

Browse files
authored
Merge pull request #232 from lonelam/feat/1m-index
feat: add 1m kdata for qmt recording
2 parents 21b09fd + 880348a commit c38662c

File tree

9 files changed

+294
-10
lines changed

9 files changed

+294
-10
lines changed

src/zvt/broker/qmt/qmt_quote.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import pandas as pd
77
from xtquant import xtdata
88

9+
from zvt.contract import Exchange
910
from zvt.contract import IntervalLevel, AdjustType
1011
from zvt.contract.api import decode_entity_id, df_to_db, get_db_session
1112
from zvt.domain import StockQuote, Stock, Stock1dKdata
1213
from zvt.domain.quotes.stock.stock_quote import Stock1mQuote, StockQuoteLog
14+
from zvt.recorders.em import em_api
1315
from zvt.utils.pd_utils import pd_is_not_null
1416
from zvt.utils.time_utils import (
1517
to_time_str,
@@ -84,7 +86,12 @@ def _qmt_instrument_detail_to_stock(stock_detail):
8486

8587

8688
def get_qmt_stocks():
87-
return xtdata.get_stock_list_in_sector("沪深A股")
89+
df = em_api.get_tradable_list(exchange=Exchange.bj)
90+
bj_stock_list = df["entity_id"].map(_to_qmt_code).tolist()
91+
92+
stock_list = xtdata.get_stock_list_in_sector("沪深A股")
93+
stock_list += bj_stock_list
94+
return stock_list
8895

8996

9097
def get_entity_list():
@@ -127,7 +134,10 @@ def get_entity_list():
127134

128135
tick = xtdata.get_full_tick(code_list=[stock])
129136
if tick and tick[stock]:
130-
if code.startswith("300") or code.startswith("688"):
137+
if code.startswith(("83", "87", "88", "889", "82", "920")):
138+
limit_up_price = tick[stock]["lastClose"] * 1.3
139+
limit_down_price = tick[stock]["lastClose"] * 0.7
140+
elif code.startswith("300") or code.startswith("688"):
131141
limit_up_price = tick[stock]["lastClose"] * 1.2
132142
limit_down_price = tick[stock]["lastClose"] * 0.8
133143
else:
@@ -150,9 +160,15 @@ def get_kdata(
150160
):
151161
code = _to_qmt_code(entity_id=entity_id)
152162
period = level.value
163+
start_time = to_time_str(start_timestamp, fmt="YYYYMMDDHHmmss")
164+
end_time = to_time_str(end_timestamp, fmt="YYYYMMDDHHmmss")
153165
# download比较耗时,建议单独定时任务来做
154166
if download_history:
155-
xtdata.download_history_data(stock_code=code, period=period)
167+
print(f"download from {start_time} to {end_time}")
168+
xtdata.download_history_data(
169+
stock_code=code, period=period,
170+
start_time=start_time, end_time=end_time
171+
)
156172
records = xtdata.get_market_data(
157173
stock_list=[code],
158174
period=period,

src/zvt/contract/recorder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def __init__(
196196
end_timestamp=None,
197197
return_unfinished=False,
198198
) -> None:
199+
self.start_timestamp = to_pd_timestamp(start_timestamp)
200+
self.end_timestamp = to_pd_timestamp(end_timestamp)
199201
super().__init__(
200202
force_update,
201203
sleeping_time,
@@ -213,8 +215,6 @@ def __init__(
213215
self.real_time = real_time
214216
self.close_hour, self.close_minute = self.entity_schema.get_close_hour_and_minute()
215217
self.fix_duplicate_way = fix_duplicate_way
216-
self.start_timestamp = to_pd_timestamp(start_timestamp)
217-
self.end_timestamp = to_pd_timestamp(end_timestamp)
218218

219219
def get_latest_saved_record(self, entity):
220220
order = eval("self.data_schema.{}.desc()".format(self.get_evaluated_time_field()))

src/zvt/domain/quotes/index/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@
2020
from .index_1wk_kdata import __all__ as _index_1wk_kdata_all
2121

2222
__all__ += _index_1wk_kdata_all
23+
24+
from .index_1m_kdata import *
25+
from .index_1m_kdata import __all__ as _index_1m_kdata_all
26+
__all__ += _index_1m_kdata_all
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
# this file is generated by gen_kdata_schema function, don't change it
3+
from sqlalchemy.orm import declarative_base
4+
5+
from zvt.contract import TradableEntity
6+
from zvt.contract.register import register_schema
7+
from zvt.domain.quotes import IndexKdataCommon
8+
9+
KdataBase = declarative_base()
10+
11+
12+
class Index1mKdata(KdataBase, IndexKdataCommon, TradableEntity):
13+
__tablename__ = "index_1m_kdata"
14+
15+
16+
register_schema(providers=["em", "sina", "qmt"], db_name="index_1m_kdata", schema_base=KdataBase, entity_type="index")
17+
18+
19+
# the __all__ is generated
20+
__all__ = ["Index1mKdata"]

src/zvt/recorders/qmt/__init__.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
1-
# -*- coding: utf-8 -*-
2-
1+
# -*- coding: utf-8 -*-#
32

43
# the __all__ is generated
54
__all__ = []
5+
6+
# __init__.py structure:
7+
# common code of the package
8+
# export interface in __all__ which contains __all__ of its sub modules
9+
10+
# import all from submodule quotes
11+
from .quotes import *
12+
from .quotes import __all__ as _quotes_all
13+
14+
__all__ += _quotes_all
15+
16+
# import all from submodule index
17+
from .index import *
18+
from .index import __all__ as _index_all
19+
20+
__all__ += _index_all
21+
22+
# import all from submodule meta
23+
from .meta import *
24+
from .meta import __all__ as _meta_all
25+
26+
__all__ += _meta_all
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
# the __all__ is generated
5+
__all__ = []
6+
7+
# __init__.py structure:
8+
# common code of the package
9+
# export interface in __all__ which contains __all__ of its sub modules
10+
11+
# import all from submodule qmt_index_recorder
12+
from .qmt_index_recorder import *
13+
from .qmt_index_recorder import __all__ as _qmt_index_recorder_all
14+
15+
__all__ += _qmt_index_recorder_all
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# -*- coding: utf-8 -*-
2+
import pandas as pd
3+
from zvt.api.kdata import get_kdata_schema
4+
from zvt.broker.qmt import qmt_quote
5+
from zvt.consts import IMPORTANT_INDEX
6+
from zvt.contract import IntervalLevel
7+
from zvt.contract.api import df_to_db
8+
from zvt.contract.recorder import FixedCycleDataRecorder
9+
from zvt.contract.utils import evaluate_size_from_timestamp
10+
from zvt.domain import Index, Index1mKdata
11+
from zvt.utils.pd_utils import pd_is_not_null
12+
from zvt.utils.time_utils import TIME_FORMAT_DAY, TIME_FORMAT_MINUTE, current_date, to_time_str
13+
14+
15+
class QmtIndexRecorder(FixedCycleDataRecorder):
    """Record index k-line data fetched through ``qmt_quote.get_kdata``.

    Entities are ``Index`` rows read from the ``em`` provider; the fetched
    bars are written with ``df_to_db`` under the ``qmt`` provider. The
    concrete ``data_schema`` is resolved per instance from the requested
    level (the class-level ``Index1mKdata`` is only the default).
    """

    provider = "qmt"
    data_schema = Index1mKdata
    entity_provider = "em"
    entity_schema = Index
    # Class-level default; every instance overwrites it in ``__init__``.
    download_history_data = False

    def __init__(
        self,
        force_update=True,
        sleeping_time=10,
        exchanges=None,
        entity_id=None,
        entity_ids=None,
        code=None,
        codes=None,
        day_data=False,
        entity_filters=None,
        ignore_failed=True,
        real_time=False,
        fix_duplicate_way="ignore",
        start_timestamp=None,
        end_timestamp=None,
        level=IntervalLevel.LEVEL_1DAY,
        kdata_use_begin_time=False,
        one_day_trading_minutes=24 * 60,
        return_unfinished=False,
        download_history_data=False
    ) -> None:
        """Create the recorder.

        All arguments except ``download_history_data`` are forwarded
        positionally to ``FixedCycleDataRecorder.__init__``.

        :param level: anything accepted by ``IntervalLevel(...)``; it also
            selects the kdata schema via ``get_kdata_schema``.
        :param download_history_data: forwarded to ``qmt_quote.get_kdata`` as
            ``download_history`` on every ``record`` call, and enables the
            completeness check in ``evaluate_start_end_size_timestamps``.
        """
        level = IntervalLevel(level)
        self.entity_type = "index"
        self.download_history_data = download_history_data

        # Must be set before the base initializer runs, which receives no schema argument.
        self.data_schema = get_kdata_schema(entity_type=self.entity_type, level=level, adjust_type=None)

        super().__init__(
            force_update,
            sleeping_time,
            exchanges,
            entity_id,
            entity_ids,
            code,
            codes,
            day_data,
            entity_filters,
            ignore_failed,
            real_time,
            fix_duplicate_way,
            start_timestamp,
            end_timestamp,
            level,
            kdata_use_begin_time,
            one_day_trading_minutes,
            return_unfinished,
        )
        # NOTE(review): this unconditionally replaces whatever was passed as
        # ``one_day_trading_minutes`` with 240 -- confirm that is intended.
        self.one_day_trading_minutes = 240

    def record(self, entity, start, end, size, timestamps):
        """Fetch k-lines for ``entity`` between ``start`` and ``end`` and persist them.

        :param entity: the index entity; its ``id``, ``code`` and ``name`` are used.
        :param start: range start; falls back to ``"2005-01-01"`` when falsy.
            For daily level only its ``date()`` part is used.
        :param end: range end; falls back to ``current_date()`` when falsy.
        :param size: unused here.
        :param timestamps: unused here.
        """
        if start and (self.level == IntervalLevel.LEVEL_1DAY):
            start = start.date()
        if not start:
            start = "2005-01-01"
        if not end:
            end = current_date()

        # Follow the high-frequency data convention and reduce the number of
        # updates: for minute k-lines read one extra bar so that a range such as
        # start_timestamp=9:30, end_timestamp=15:00 is covered.
        if self.level == IntervalLevel.LEVEL_1MIN:
            end += pd.Timedelta(seconds=1)

        df = qmt_quote.get_kdata(
            entity_id=entity.id,
            start_timestamp=start,
            end_timestamp=end,
            adjust_type=None,
            level=self.level,
            download_history=self.download_history_data,
        )
        time_str_fmt = TIME_FORMAT_DAY if self.level == IntervalLevel.LEVEL_1DAY else TIME_FORMAT_MINUTE
        if pd_is_not_null(df):
            df["entity_id"] = entity.id
            df["timestamp"] = pd.to_datetime(df.index)
            # One pass over the timestamp column instead of a row-wise DataFrame.apply.
            df["id"] = [f"{entity.id}_{to_time_str(ts, fmt=time_str_fmt)}" for ts in df["timestamp"]]
            df["provider"] = "qmt"
            df["level"] = self.level.value
            df["code"] = entity.code
            df["name"] = entity.name
            df.rename(columns={"amount": "turnover"}, inplace=True)
            df["change_pct"] = (df["close"] - df["preClose"]) / df["preClose"]
            df_to_db(df=df, data_schema=self.data_schema, provider=self.provider, force_update=self.force_update)
        else:
            self.logger.info(f"no kdata for {entity.id}")

    def evaluate_start_end_size_timestamps(self, entity):
        """Decide which range to record next for ``entity``.

        :return: ``(start_timestamp, end_timestamp, size, timestamps)``.
        """
        if self.download_history_data and self.start_timestamp and self.end_timestamp:
            # Historical data may be fragmented, so check whether the configured
            # [start, end] range has actually been fully written.
            expected_size = evaluate_size_from_timestamp(
                start_timestamp=self.start_timestamp,
                end_timestamp=self.end_timestamp,
                level=self.level,
                one_day_trading_minutes=self.one_day_trading_minutes,
            )

            recorded_size = (
                self.session.query(self.data_schema)
                .filter(
                    self.data_schema.entity_id == entity.id,
                    self.data_schema.timestamp >= self.start_timestamp,
                    self.data_schema.timestamp <= self.end_timestamp,
                )
                .count()
            )

            if expected_size != recorded_size:
                # Incomplete range: re-record the whole configured window.
                return self.start_timestamp, self.end_timestamp, self.default_size, None

        start_timestamp, end_timestamp, size, timestamps = super().evaluate_start_end_size_timestamps(entity)
        # start_timestamp is the last updated timestamp
        if self.end_timestamp is None:
            return start_timestamp, end_timestamp, size, timestamps

        if start_timestamp is None:
            # No known start yet; comparing None with a timestamp would raise
            # TypeError. Keep the size from the base class and let record()
            # apply its default start.
            return start_timestamp, self.end_timestamp, size, timestamps

        if start_timestamp >= self.end_timestamp:
            # Already recorded up to (or past) the requested end: nothing to do.
            return start_timestamp, end_timestamp, 0, None

        size = evaluate_size_from_timestamp(
            start_timestamp=start_timestamp,
            level=self.level,
            one_day_trading_minutes=self.one_day_trading_minutes,
            end_timestamp=self.end_timestamp,
        )
        return start_timestamp, self.end_timestamp, size, timestamps
141+
142+
# # 中证,上海
143+
# def record_cs_index(self, index_type):
144+
# df = cs_index_api.get_cs_index(index_type=index_type)
145+
# df_to_db(data_schema=self.data_schema, df=df, provider=self.provider, force_update=True)
146+
# self.logger.info(f"finish record {index_type} index")
147+
#
148+
# # 国证,深圳
149+
# def record_cn_index(self, index_type):
150+
# if index_type == "cni":
151+
# category_map_url = cn_index_api.cni_category_map_url
152+
# elif index_type == "sz":
153+
# category_map_url = cn_index_api.sz_category_map_url
154+
# else:
155+
# self.logger.error(f"not support index_type: {index_type}")
156+
# assert False
157+
#
158+
# for category, _ in category_map_url.items():
159+
# df = cn_index_api.get_cn_index(index_type=index_type, category=category)
160+
# df_to_db(data_schema=self.data_schema, df=df, provider=self.provider, force_update=True)
161+
# self.logger.info(f"finish record {index_type} index:{category.value}")
162+
163+
164+
if __name__ == "__main__":
    # init_log('china_stock_category.log')
    # Manual run: record 1-minute bars for the important indexes over a fixed
    # window, downloading history first.
    recorder = QmtIndexRecorder(
        codes=IMPORTANT_INDEX,
        level=IntervalLevel.LEVEL_1MIN,
        sleeping_time=0,
        start_timestamp=pd.Timestamp("2024-12-01"),
        end_timestamp=pd.Timestamp("2024-12-03"),
        download_history_data=True,
    )
    recorder.run()
171+
172+
# the __all__ is generated
173+
__all__ = ["QmtIndexRecorder"]

src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from zvt.api.kdata import get_kdata_schema, get_kdata
55
from zvt.broker.qmt import qmt_quote
66
from zvt.contract import IntervalLevel, AdjustType
7-
from zvt.contract.api import df_to_db
7+
from zvt.contract.api import df_to_db, get_db_session, get_entities
88
from zvt.contract.recorder import FixedCycleDataRecorder
99
from zvt.domain import (
1010
Stock,
1111
StockKdataCommon,
1212
)
1313
from zvt.utils.pd_utils import pd_is_not_null
14-
from zvt.utils.time_utils import current_date, to_time_str
14+
from zvt.utils.time_utils import current_date, to_time_str, now_time_str
1515

1616

1717
class BaseQmtKdataRecorder(FixedCycleDataRecorder):
@@ -69,6 +69,40 @@ def __init__(
6969
return_unfinished,
7070
)
7171

72+
def init_entities(self):
    """
    Build ``self.entities``, the list of entities this recorder will process.

    Sets ``self.entity_session`` (reusing ``self.session`` when the entity
    provider/schema are the same as the data provider/schema) and, when
    ``self.day_data`` is set, excludes entities that already have rows in
    ``self.data_schema`` with a timestamp at or after ``now_time_str()``.
    """
    shares_session = self.entity_provider == self.provider and self.entity_schema == self.data_schema
    if shares_session:
        self.entity_session = self.session
    else:
        self.entity_session = get_db_session(provider=self.entity_provider, data_schema=self.entity_schema)

    if self.day_data:
        already_recorded = self.data_schema.query_data(
            start_timestamp=now_time_str(), columns=["entity_id", "timestamp"], provider=self.provider
        )
        if pd_is_not_null(already_recorded):
            recorded_ids = already_recorded["entity_id"].tolist()
            self.logger.info(f"ignore entity_ids:{recorded_ids}")
            exclude_recorded = self.entity_schema.entity_id.notin_(recorded_ids)
            if not self.entity_filters:
                self.entity_filters = [exclude_recorded]
            else:
                self.entity_filters.append(exclude_recorded)

    #: init the entity list
    self.entities = get_entities(
        session=self.entity_session,
        entity_schema=self.entity_schema,
        exchanges=self.exchanges,
        entity_ids=self.entity_ids,
        codes=self.codes,
        return_type="domain",
        provider=self.entity_provider,
        filters=self.entity_filters,
    )
105+
72106
def record(self, entity, start, end, size, timestamps):
73107
if start and (self.level == IntervalLevel.LEVEL_1DAY):
74108
start = start.date()

src/zvt/tasks/qmt_data_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from xtquant import xtdata
77

88
from zvt import init_log
9+
from zvt.broker.qmt.qmt_quote import get_qmt_stocks
910
from zvt.contract import AdjustType
1011
from zvt.recorders.qmt.meta import QMTStockRecorder
1112
from zvt.recorders.qmt.quotes import QMTStockKdataRecorder
@@ -16,7 +17,7 @@
1617
def download_data(download_tick=False):
1718
period = "1d"
1819
xtdata.download_sector_data()
19-
stock_codes = xtdata.get_stock_list_in_sector("沪深A股")
20+
stock_codes = get_qmt_stocks()
2021
stock_codes = sorted(stock_codes)
2122
count = len(stock_codes)
2223
download_status = {"ok": False}

0 commit comments

Comments
 (0)