Skip to content

Added live trading support #1009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added livetrading/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions livetrading/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from decimal import Decimal
from typing import Any, Dict, Optional

from livetrading.event import KLinesEventSource, Pair, PairInfo, TickersEventSource
from livetrading.rest_cli import RestClient
from livetrading.websocket_client import WSClient


class Broker:
"""A client for crypto currency exchange.
:param dispatcher: The event dispatcher.
:param config: Config settings for exchange.
"""
def __init__(
self, dispatcher, config
):
self.dispatcher = dispatcher
self.config = config
self.api_cli = RestClient(self.config)
self.cli: Optional[Any] = None # external libs as ccxt
self.ws_cli = WSClient(config)
self._cached_pairs: Dict[Pair] = {}

def subscribe_to_ticker_events(
self, pair: Pair, interval: str, event_handler
):
"""Registers a callable that will be called every ticker.
:param bar_duration: The bar duration. One of 1s, 1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w, 1M.
:param pair: The trading pair.
:param event_handler: A callable that receives an TickerEvent.
"""

event_source = TickersEventSource(pair, interval, self.ws_cli)
channel = "ticker"

self._subscribe_to_ws_channel_events(
channel,
event_handler,
event_source
)

def subscribe_to_bar_events(
self, pair: Pair, event_handler, interval
):
"""Registers a callable that will be called every bar.
:param pair: The trading pair.
:param event_handler: A callable that receives an BarEvent.
"""
event_source = KLinesEventSource(pair, self.ws_cli)
channel = event_source.ws_channel(interval)

self._subscribe_to_ws_channel_events(
channel,
event_handler,
event_source
)

def get_pair_info(self, pair: Pair) -> PairInfo:
"""Returns information about a trading pair.
:param pair: The trading pair.
"""
ret = self._cached_pairs.get(pair)
api_path = '/'.join(['products', pair])
if not ret:
pair_info = self.api_cli.call(method='GET', apipath=api_path)
self._cached_pairs[pair] = PairInfo(Decimal(pair_info['base_increment']),
Decimal(pair_info['quote_increment']))
return self._cached_pairs

def get_data_df(self, event_source):
data_source = self.ws_cli.event_sources[event_source]
return list(data_source.events)

def _subscribe_to_ws_channel_events(
self, channel: str, event_handler, event_source
):
# Set the event source for the channel.
self.ws_cli.set_channel_event_source(channel, event_source)

# Subscribe the event handler to the event source.
self.dispatcher.subscribe(event_source, event_handler)
4 changes: 4 additions & 0 deletions livetrading/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from configloader import ConfigLoader

config = ConfigLoader()
config.update_from_json_file('path_to_json_file')
17 changes: 17 additions & 0 deletions livetrading/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pandas as pd

DEFAULT_DATAFRAME_COLUMNS = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']

def ohlcv_to_dataframe(historical_data: list) -> pd.DataFrame:
"""
Converts historical data to a Dataframe
:param historical_data: list with candle (OHLCV) data
:return: DataFrame
"""
df = pd.DataFrame(
[{fn: getattr(f, fn) for fn in DEFAULT_DATAFRAME_COLUMNS} for f in historical_data]
)
df['Date'] = pd.to_datetime(df['Date'], unit='ms', utc=True, )
df = df.set_index('Date')
df = df.sort_index(ascending=True)
return df.head()
5 changes: 5 additions & 0 deletions livetrading/env
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"ws_url": "wss://ws-feed.exchange.coinbase.com",
"api_url": "https://api.exchange.coinbase.com/",
"ws_timeout": 5
}
236 changes: 236 additions & 0 deletions livetrading/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import abc
import dataclasses
import datetime

from collections import deque
from dateutil.parser import isoparse
from typing import Optional


intervals = {
"1s": 1,
"1m": 60,
"3m": 3 * 60,
"5m": 5 * 60,
"15m": 15 * 60,
"30m": 30 * 60,
"1h": 3600,
"2h": 2 * 3600,
"4h": 4 * 3600,
"6h": 6 * 3600,
"8h": 8 * 3600,
"12h": 12 * 3600,
"1d": 86400,
"3d": 3 * 86400,
"1w": 7 * 86400,
"1M": 31 * 86400
}


@dataclasses.dataclass
class Bar:
"""A Bar, aka candlestick, is the summary of the trading activity in a given period.
:param date: The beginning of the period. It must have timezone information set.
:param pair: The trading pair.
:param open: The opening price.
:param high: The highest traded price.
:param low: The lowest traded price.
:param close: The closing price.
:param volume: The volume traded.
"""
date: datetime
pair: str
Open: float
High: float
Low: float
Close: float
Volume: float


@dataclasses.dataclass
class Pair:
"""A trading pair.
:param base_symbol: The base symbol.
:param quote_symbol: The quote symbol.
"""
base_symbol: str
quote_symbol: str

def __str__(self):
# change format here to reflect corresponding exchange
return "{}-{}".format(self.base_symbol, self.quote_symbol)


@dataclasses.dataclass
class PairInfo:
"""Information about a trading pair.
:param base_increment: The increment for the base symbol.
:param quote_increment: The increment for the quote symbol.
"""
base_increment: float
quote_increment: float


class Ticker:
"""A Ticker constantly updating stream of information about a stock.
:param datetime: The beginning of the period. It must have timezone information set.
:param pair: The trading pair.
:param open: The opening price.
:param high: The highest traded price.
:param low: The lowest traded price.
:param price: The price.
:param volume: The volume traded.
"""
def __init__(self, pair: Pair, json: dict):
self.pair: Pair = pair
self.json: dict = json
self.Date = isoparse(json['time'])
self.Volume = float(json["volume_24h"])
self.Open = float(json["open_24h"])
self.High = float(json["high_24h"])
self.Low = float(json["low_24h"])
self.Close = float(json["price"])


class KlineBar(Bar):
"""
K-line, aka candlestick, is a chart marked with the opening price, closing price,
highest price, and lowest price to reflect price changes.
:param pair: The trading pair.
:param json: Message json.
"""
def __init__(self, pair: Pair, json: dict):
super().__init__(
datetime.utcfromtimestamp(
int(json["t"] / 1e3).replace(tzinfo=datetime.timezone.utc)),
pair, float(json["o"]), float(json["h"]),
float(json["l"]), float(json["c"]), float(json["v"])
)
self.pair: Pair = pair
self.json: dict = json


class EventProducer:
"""Base class for event producers.
.. note::
Main method is for main functions that should be performed for an event producer.
Finalize method is called on error or stop.
"""
def main(self):
"""Override to run the loop that produces events."""
pass

def finalize(self):
"""Override to perform task and transaction cancellation."""
pass


class Event:
"""Base class for events.
:param when: The datetime when the event occurred.
Used to calculate the datetime for the next event.
It must have timezone information set.
"""

def __init__(self, when: datetime.datetime):
self.when: datetime.datetime = when


class EventSource(metaclass=abc.ABCMeta):
"""Base class for events storage.
:param producer: EventProducer.
"""

def __init__(self, producer: Optional[EventProducer] = None):
self.producer = producer
self.events = deque()


class ChannelEventSource(EventSource):
"""Base class for websockets channels.
:param producer: EventProducer.
"""
def __init__(self, producer: EventProducer):
super().__init__(producer=producer)

@abc.abstractmethod
def push_to_queue(self, message: dict):
raise NotImplementedError()


class TickersEventSource(ChannelEventSource):
"""An event source for :class:`Ticker` instances.
:param pair: The trading pair.
"""
def __init__(self, pair: Pair, when: datetime, producer: EventProducer):
super().__init__(producer=producer)
self.pair: Pair = pair
self.when = intervals.get(when)

def push_to_queue(self, message: dict):
timestamp = message["time"]
dt = isoparse(timestamp) + datetime.timedelta(seconds=self.when)
self.events.append(TickerEvent(
dt,
Ticker(self.pair, message)))


class KLinesEventSource(EventSource):
"""An event source for :class:`KLineBar` instances.
:param pair: The trading pair..
"""
def __init__(self, pair: Pair, producer: EventProducer):
super().__init__(producer=producer)
self.pair: Pair = pair

def push_to_queue(self, message: dict):
kline_event = message["data"]
kline = kline_event["k"]
# Wait for the last update to the kline.
if kline["x"] is False:
return
self.events.append(BarEvent(
datetime.utcfromtimestamp(
int(kline_event["E"] / 1e3).replace(tzinfo=datetime.timezone.utc)),
KlineBar(self.pair, kline)))

def ws_channel(self, interval: str) -> str:
"""
Generate websocket channel
"""
return "{}@kline_{}".format(
"{}{}".format(self.pair.base_symbol.upper(), self.pair.quote_symbol.upper()).lower(),
interval)


class BarEvent(Event):
"""An event for :class:`Bar` instances.
:param when: The datetime when the event occurred. It must have timezone information set.
:param bar: The bar.
"""
def __init__(self, when, bar: Bar):
super().__init__(when)

self.data = bar


class TickerEvent(Event):
"""An event for :class:`Ticker` instances.
:param when: The datetime when the event occurred. It must have timezone information set.
:param ticker: The Ticker.
"""
def __init__(self, when, ticker: Ticker):
super().__init__(when)

self.data = ticker
130 changes: 130 additions & 0 deletions livetrading/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import time
import datetime
import logging
from functools import partial
from typing import Any, Dict, List, Set, Optional

from backtesting import Backtest
from .converter import ohlcv_to_dataframe
from .event import Event, EventSource, EventProducer

logger = logging.getLogger(__name__)


class EventDispatcher:
"""Responsible for connecting event sources to event handlers and dispatching events
in the right order.
"""
def __init__(self, strategy):
self._event_handlers: Dict[EventSource, List[Any]] = {}
self._prefetched_events: Dict[EventSource, Optional[Event]] = {}
self._prev_events: Dict[EventSource, datetime.datetime] = {}
self._producers: Set[EventProducer] = set()
self._running = False
self._stopped = False
self._current_event_dt = None
self.strategy = strategy
self.backtesting = None

def set_strategy(self, strategy):
self._strategy = strategy

def set_backtesting_partial(self, cash: float = 10_000,
commission: float = .0,
margin: float = 1.,
trade_on_close=False,
hedging=False,
exclusive_orders=False):
self.backtesting = partial(Backtest, strategy=self.strategy, cash=cash, commission=commission,
margin=margin, trade_on_close=trade_on_close,
hedging=hedging, exclusive_orders=exclusive_orders)

def subscribe(self, source: EventSource, event_handler: Any):
"""Registers an callable that will be called when an event source has new events.
:param source: An event source.
:param event_handler: An callable that receives an event.
"""
assert not self._running
handlers = self._event_handlers.setdefault(source, [])
if event_handler not in handlers:
handlers.append(event_handler)
if source.producer:
self._producers.add(source.producer)

def run(self):
assert not self._running, "Running or already ran"

self._running = True
try:
# Run producers and dispatch loop.
for producer in self._producers:
producer.main()
self._dispatch_loop()
except Exception as error:
logger.error(error)
finally:
for producer in self._producers:
producer.finalize()

def on_error(self, error: Any):
logger.error(error)

def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]):
# Pre-fetch events from all sources.
sources_to_pop = [
source for source in self._event_handlers.keys() if
self._prefetched_events.get(source) is None
]
for source in sources_to_pop:
if source.events:
df = ohlcv_to_dataframe([event.data for event in source.events])
bt = self.backtesting(data=df)
bt.run()

event = source.events.pop()
# Check that events from the same source are returned in order.
prev_event = self._prev_events.get(source)
if prev_event is not None and event.when < prev_event.when:
continue

self._prev_events[source] = event
self._prefetched_events[source] = event

# Calculate the datetime for the next event using the prefetched events.
next_dt = None
prefetched_events = [e for e in self._prefetched_events.values() if e]
if prefetched_events:
next_dt = min(map(lambda e: e.when, prefetched_events))
assert ge_or_assert is None or next_dt is None or next_dt >= ge_or_assert, \
f"{next_dt} can't be dispatched after {ge_or_assert}"

# Dispatch events matching the desired datetime.
event_handlers = []
for source, e in self._prefetched_events.items():
if e is not None and e.when == next_dt:
# Collect event handlers for the event source.
event_handlers += [event_handler(e) for event_handler in
self._event_handlers.get(source, [])]
# Consume the event.
self._prefetched_events[source] = None

self._current_event_dt = None
return next_dt

def stop(self):
"""Requests the event dispatcher to stop the event processing loop."""
self._stopped = True

for producer in self._producers:
producer.finalize()

def _dispatch_loop(self):
last_dt = None

while not self._stopped:
dispatched_dt = self._dispatch_next(last_dt)
if dispatched_dt is None:
time.sleep(0.01)
else:
last_dt = dispatched_dt
68 changes: 68 additions & 0 deletions livetrading/live_trading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pandas as pd
import websocket

from backtesting import Strategy
from livetrading import executor
from livetrading.broker import Broker, Pair
from livetrading.config import config


def SMA(arr: pd.Series, n: int) -> pd.Series:
"""
Returns `n`-period simple moving average of array `arr`.
"""
return pd.Series(arr).rolling(n).mean()


class LiveStrategy(Strategy):
n1 = 10
n2 = 20

def __init__(self, broker, data, params):
super().__init__(broker=broker, data=data, params=params)

def init(self):
sma1 = self.I(SMA, self.data.Close, self.n1)
sma2 = self.I(SMA, self.data.Close, self.n2)

def set_atr_periods(self):
if len(self.data) > 1:
print(self.data.High, self.data.Low)

def next(self):
print(self.data)


class PositionManager:
def __init__(self, exchange, position_amount):
assert position_amount > 0
self.exchange = exchange
self.position_amount = position_amount

def on_event(self, bar_event):
# react on event from websocket
pass


if __name__ == '__main__':

websocket.enableTrace(False)

event_dis = executor.EventDispatcher(LiveStrategy)

exchange = Broker(event_dis, config=config)

pair_info = exchange.get_pair_info('BTC-USD')

position_mgr = PositionManager(exchange, 0.8)

strategy = LiveStrategy(exchange, [], {})

exchange.subscribe_to_ticker_events(Pair(base_symbol="UTC", quote_symbol="SDT"),
'3m', position_mgr.on_event)

event_dis.set_strategy(strategy)

event_dis.set_backtesting_partial(cash=100000)

event_dis.run()
37 changes: 37 additions & 0 deletions livetrading/rest_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import json
import logging
import requests

from typing import Optional
from urllib.parse import urljoin

logger = logging.getLogger(__name__)


class RestClient:
""""Class for REST API.
:param config: Config settings for exchange.
"""
def __init__(self, config):
self.url = config['api_url']
self.session = requests.Session()
self.session.auth = (config.get('username'), config.get('password'))

def call(self, method, apipath, params: Optional[dict] = None, data=None):

if str(method).upper() not in ('GET', 'POST', 'PUT', 'DELETE'):
raise ValueError(f'invalid method <{method}>')

headers = {"Accept": "application/json",
"Content-Type": "application/json"
}
url = urljoin(self.url, apipath)

try:
resp = self.session.request(method, url, headers=headers, data=json.dumps(data),
params=params)
if resp.status_code == 200:
return resp.json()
return resp.text
except ConnectionError:
logger.warning("Connection error")
63 changes: 63 additions & 0 deletions livetrading/websocket_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
import websocket, json, _thread

from typing import Dict, List, Set

from livetrading.event import EventSource, EventProducer

logger = logging.getLogger(__name__)


class WSClient(EventProducer, websocket.WebSocketApp):
""""Class for channel based web socket clients.
:param config: Config settings for exchange.
"""
def __init__(self, config):
super(WSClient, self).__init__(config['ws_url'])
self.event_sources: Dict[str, EventSource] = {}
self.pending_subscriptions: Set[str] = set()
self.timeout = config['ws_timeout']
self.on_open = lambda ws: self.subscribe_msg()
self.on_message = lambda ws, msg: self.handle_message(json.loads(msg))
self.on_error = lambda ws, e: logger.warning(f"Error: {e}")
self.on_close = self.on_close
self._running = False
self.thread = None

def set_channel_event_source(self, channel: str, event_source: EventSource):
assert channel not in self.event_sources, "channel already registered"
self.event_sources[channel] = event_source
self.pending_subscriptions.add(channel)

def subscribe_msg(self):
self.pending_subscriptions.update(self.event_sources.keys())
channels = list(self.pending_subscriptions)
self.subscribe_to_channels(channels)

def on_close(self):
self.pending_subscriptions = set()

def main(self):
if not self._running:
self.thread = _thread.start_new_thread(self.run_forever, ())
self._running = True

def subscribe_to_channels(
self, channels: List[str]
):
sub_msg = {
"type": "subscribe",
"product_ids": [
"ETH-USD",
"BTC-USD"
],
"channels": channels
}
self.send(json.dumps(sub_msg))
logger.info(f"Subscribed to channels: {channels}")

def handle_message(self, message: dict) -> None:
channel = message.get("type")
event_source = self.event_sources.get(channel)
if event_source:
event_source.push_to_queue(message)
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,10 @@
'numpy >= 1.17.0',
'pandas >= 0.25.0, != 0.25.0',
'bokeh >= 1.4.0',
],
'configloader >= 1.0.1',
'websocket-client >= 1.6.0',
'urllib3 >= 2.0.3'
],
extras_require={
'doc': [
'pdoc3',