Skip to content

Commit 066a58a

Browse files
committed
Added live trading support
1 parent e63d3a0 commit 066a58a

File tree

5 files changed

+71
-28
lines changed

5 files changed

+71
-28
lines changed

livetrading/broker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ def __init__(
2323
self._cached_pairs: Dict[Pair] = {}
2424

2525
def subscribe_to_ticker_events(
26-
self, pair: Pair, event_handler
26+
self, pair: Pair, interval: str, event_handler
2727
):
2828
"""Registers a callable that will be called every ticker.
2929
30+
:param bar_duration: The bar duration. One of 1s, 1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w, 1M.
3031
:param pair: The trading pair.
3132
:param event_handler: A callable that receives an TickerEvent.
3233
"""
33-
event_source = TickersEventSource(pair, self.ws_cli)
34+
35+
event_source = TickersEventSource(pair, interval, self.ws_cli)
3436
channel = "ticker"
3537

3638
self._subscribe_to_ws_channel_events(
@@ -78,7 +80,6 @@ def _subscribe_to_ws_channel_events(
7880
):
7981
# Set the event source for the channel.
8082
self.ws_cli.set_channel_event_source(channel, event_source)
81-
# self.ws_cli.subscribe_to_channels()
8283

8384
# Subscribe the event handler to the event source.
8485
self.dispatcher.subscribe(event_source, event_handler)

livetrading/env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
22
"ws_url": "wss://ws-feed.exchange.coinbase.com",
33
"api_url": "https://api.exchange.coinbase.com/",
4-
"ws_timeout": 5}
4+
"ws_timeout": 5
55
}

livetrading/event.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,26 @@
77
from typing import Optional
88

99

10+
intervals = {
11+
"1s": 1,
12+
"1m": 60,
13+
"3m": 3 * 60,
14+
"5m": 5 * 60,
15+
"15m": 15 * 60,
16+
"30m": 30 * 60,
17+
"1h": 3600,
18+
"2h": 2 * 3600,
19+
"4h": 4 * 3600,
20+
"6h": 6 * 3600,
21+
"8h": 8 * 3600,
22+
"12h": 12 * 3600,
23+
"1d": 86400,
24+
"3d": 3 * 86400,
25+
"1w": 7 * 86400,
26+
"1M": 31 * 86400
27+
}
28+
29+
1030
@dataclasses.dataclass
1131
class Bar:
1232
"""A Bar, aka candlestick, is the summary of the trading activity in a given period.
@@ -150,14 +170,16 @@ class TickersEventSource(ChannelEventSource):
150170
151171
:param pair: The trading pair.
152172
"""
153-
def __init__(self, pair: Pair, producer: EventProducer):
173+
def __init__(self, pair: Pair, when: datetime, producer: EventProducer):
154174
super().__init__(producer=producer)
155175
self.pair: Pair = pair
176+
self.when = intervals.get(when)
156177

157178
def push_to_queue(self, message: dict):
158179
timestamp = message["time"]
180+
dt = isoparse(timestamp) + datetime.timedelta(seconds=self.when)
159181
self.events.append(TickerEvent(
160-
isoparse(timestamp),
182+
dt,
161183
Ticker(self.pair, message)))
162184

163185

@@ -199,7 +221,7 @@ class BarEvent(Event):
199221
def __init__(self, when, bar: Bar):
200222
super().__init__(when)
201223

202-
self.bar = bar
224+
self.data = bar
203225

204226

205227
class TickerEvent(Event):
@@ -211,4 +233,4 @@ class TickerEvent(Event):
211233
def __init__(self, when, ticker: Ticker):
212234
super().__init__(when)
213235

214-
self.ticker = ticker
236+
self.data = ticker

livetrading/executor.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import time
22
import datetime
33
import logging
4-
4+
from functools import partial
55
from typing import Any, Dict, List, Set, Optional
66

7+
from backtesting import Backtest
8+
from .converter import ohlcv_to_dataframe
79
from .event import Event, EventSource, EventProducer
810

911
logger = logging.getLogger(__name__)
@@ -13,17 +15,29 @@ class EventDispatcher:
1315
"""Responsible for connecting event sources to event handlers and dispatching events
1416
in the right order.
1517
"""
16-
def __init__(self):
18+
def __init__(self, strategy):
1719
self._event_handlers: Dict[EventSource, List[Any]] = {}
1820
self._prefetched_events: Dict[EventSource, Optional[Event]] = {}
1921
self._prev_events: Dict[EventSource, datetime.datetime] = {}
2022
self._producers: Set[EventProducer] = set()
2123
self._running = False
2224
self._stopped = False
2325
self._current_event_dt = None
26+
self.strategy = strategy
27+
self.backtesting = None
2428

2529
def set_strategy(self, strategy):
26-
self.strategy = strategy
30+
self._strategy = strategy
31+
32+
def set_backtesting_partial(self, cash: float = 10_000,
33+
commission: float = .0,
34+
margin: float = 1.,
35+
trade_on_close=False,
36+
hedging=False,
37+
exclusive_orders=False):
38+
self.backtesting = partial(Backtest, strategy=self.strategy, cash=cash, commission=commission,
39+
margin=margin, trade_on_close=trade_on_close,
40+
hedging=hedging, exclusive_orders=exclusive_orders)
2741

2842
def subscribe(self, source: EventSource, event_handler: Any):
2943
"""Registers an callable that will be called when an event source has new events.
@@ -64,6 +78,10 @@ def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]):
6478
]
6579
for source in sources_to_pop:
6680
if source.events:
81+
df = ohlcv_to_dataframe([event.data for event in source.events])
82+
bt = self.backtesting(data=df)
83+
bt.run()
84+
6785
event = source.events.pop()
6886
# Check that events from the same source are returned in order.
6987
prev_event = self._prev_events.get(source)
@@ -92,8 +110,6 @@ def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]):
92110
self._prefetched_events[source] = None
93111

94112
self._current_event_dt = None
95-
self.strategy.next()
96-
self.strategy.init()
97113
return next_dt
98114

99115
def stop(self):

livetrading/live_trading.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,34 @@
1+
import pandas as pd
12
import websocket
23

34
from backtesting import Strategy
4-
from backtesting._util import _Data
55
from livetrading import executor
66
from livetrading.broker import Broker, Pair
77
from livetrading.config import config
8-
from livetrading.converter import ohlcv_to_dataframe
8+
9+
10+
def SMA(arr: pd.Series, n: int) -> pd.Series:
11+
"""
12+
Returns `n`-period simple moving average of array `arr`.
13+
"""
14+
return pd.Series(arr).rolling(n).mean()
915

1016

1117
class LiveStrategy(Strategy):
18+
n1 = 10
19+
n2 = 20
1220

13-
def __init__(self, broker):
14-
super().__init__(broker=broker, data=[], params={})
15-
self.event_data = []
21+
def __init__(self, broker, data, params):
22+
super().__init__(broker=broker, data=data, params=params)
1623

1724
def init(self):
18-
super().init()
19-
self.set_atr_periods()
25+
sma1 = self.I(SMA, self.data.Close, self.n1)
26+
sma2 = self.I(SMA, self.data.Close, self.n2)
2027

2128
def set_atr_periods(self):
2229
if len(self.data) > 1:
2330
print(self.data.High, self.data.Low)
2431

25-
def on_bar_event(self, event):
26-
self.event_data.append(event.ticker)
27-
event_df = ohlcv_to_dataframe(self.event_data)
28-
self._data = _Data(event_df.copy(deep=False))
29-
3032
def next(self):
3133
print(self.data)
3234

@@ -46,19 +48,21 @@ def on_event(self, bar_event):
4648

4749
websocket.enableTrace(False)
4850

49-
event_dis = executor.EventDispatcher()
51+
event_dis = executor.EventDispatcher(LiveStrategy)
5052

5153
exchange = Broker(event_dis, config=config)
5254

5355
pair_info = exchange.get_pair_info('BTC-USD')
5456

5557
position_mgr = PositionManager(exchange, 0.8)
5658

57-
strategy = LiveStrategy(exchange)
59+
strategy = LiveStrategy(exchange, [], {})
5860

5961
exchange.subscribe_to_ticker_events(Pair(base_symbol="UTC", quote_symbol="SDT"),
60-
strategy.on_bar_event)
62+
'3m', position_mgr.on_event)
6163

6264
event_dis.set_strategy(strategy)
6365

66+
event_dis.set_backtesting_partial(cash=100000)
67+
6468
event_dis.run()

0 commit comments

Comments
 (0)