Skip to content

Commit

Permalink
Added code to notebook and restructured it.
Browse files Browse the repository at this point in the history
  • Loading branch information
IanLKaplan committed Aug 18, 2022
1 parent c51667d commit d95ad6c
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 32 deletions.
147 changes: 119 additions & 28 deletions pairs_trading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,32 +203,27 @@
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"1 Failed download:\n",
"- SJM: No data found for this date range, symbol may be delisted\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"from datetime import datetime\n",
"from datetime import timedelta\n",
"from typing import List, Tuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import yfinance as yf\n",
"from matplotlib import pyplot as plt\n",
"from numpy import log\n",
"from tabulate import tabulate\n",
"\n",
"s_and_p_file = 's_and_p_sector_components/sp_stocks.csv'\n",
"s_and_p_data = 's_and_p_data'\n",
"start_date_str = '2007-01-03'\n",
"start_date: datetime = datetime.fromisoformat(start_date_str)\n",
"\n",
"trading_days = 252\n",
"\n",
"\n",
"def convert_date(some_date):\n",
" if type(some_date) == str:\n",
Expand Down Expand Up @@ -337,12 +332,13 @@
" last_date = convert_date(last_row.index[0])\n",
" if last_date.date() < self.end_date.date():\n",
" sym_start_date = last_date + timedelta(days=1)\n",
" new_data_df = self.get_market_data(symbol, sym_start_date, self.end_date)\n",
" symbol_df = pd.concat([symbol_df, new_data_df], axis=0)\n",
" ix = symbol_df.index\n",
" ix = pd.to_datetime(ix)\n",
" symbol_df.index = ix\n",
" symbol_df.to_csv(file_path)\n",
" new_data_df = self.get_market_data(symbol, sym_start_date, datetime.today())\n",
" if new_data_df.shape[0] > 0:\n",
" symbol_df = pd.concat([symbol_df, new_data_df], axis=0)\n",
" ix = symbol_df.index\n",
" ix = pd.to_datetime(ix)\n",
" symbol_df.index = ix\n",
" symbol_df.to_csv(file_path)\n",
" else:\n",
" symbol_df = self.get_market_data(symbol, self.start_date, self.end_date)\n",
" if symbol_df.shape[0] > 0:\n",
Expand Down Expand Up @@ -371,19 +367,12 @@
"market_data = MarketData(start_date, s_and_p_data)\n",
"stock_l: list = list(set(stock_info_df['Symbol']))\n",
"stock_l.sort()\n",
"# The stocks in close_prices_df may not include the entire set of stocks in stock_l since there\n",
"# may be stocks that went public after the start date in the backtest.\n",
"# The columns of close_prices_df are the stock symbols, the rows are the close prices\n",
"# A AAL AAP AAPL ABC ... XRAY YUM ZBH ZBRA ZION\n",
"# Date ...\n",
"# 2007-01-03 24.54 56.3 35.58 2.99 23.06 ... 29.99 21.16 74.85 34.88 82.91\n",
"close_prices_df = market_data.get_close_data(stock_l)\n",
"final_stock_list = list(close_prices_df.columns)\n",
"mask = stock_info_df['Symbol'].isin(final_stock_list)\n",
"final_stock_info_df = stock_info_df[mask]\n",
"# sectors is a dictionary where the keys are the sector names (e.g., 'energies') The values are the stock symbols in that\n",
"# sector.\n",
"sectors: dict = extract_sectors(final_stock_info_df)\n",
"\n",
"sectors = extract_sectors(final_stock_info_df)\n",
"pairs_info_df = calc_pair_counts(sectors)\n"
]
},
Expand Down Expand Up @@ -458,15 +447,117 @@
"source": [
"<h3>\n",
"Correlation\n",
"</h3>\n"
"</h3>\n",
"<p>\n",
"\n",
"</p>\n"
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"source": [
"\n",
"\n",
"def get_pairs(sector_info: dict) -> List[Tuple]:\n",
" \"\"\"\n",
" Return all of the stock pairs, where the pairs are selected from the S&P 500 sector.\n",
"\n",
" :param sector_info: A dictionary containing the sector info. For example:\n",
" energies': ['APA', 'BKR', 'COP', ...]\n",
" Here 'energies' is the dictionary key for the list of S&P 500 stocks in that sector.\n",
" :return: A list of Tuples, where each tuple contains the symbols for the stock pair and the sector.\n",
" For example:\n",
" [('AAPL', 'ACN', 'information-technology'),\n",
" ('AAPL', 'ADBE', 'information-technology'),\n",
" ('AAPL', 'ADI', 'information-technology'),\n",
" ('AAPL', 'ADP', 'information-technology'),\n",
" ('AAPL', 'ADSK', 'information-technology')]\n",
" \"\"\"\n",
" pairs_list = list()\n",
" sectors = list(sector_info.keys())\n",
" for sector in sectors:\n",
" stocks = sector_info[sector]\n",
" num_stocks = len(stocks)\n",
" for i in range(num_stocks):\n",
" stock_a = stocks[i]\n",
" for j in range(i + 1, num_stocks):\n",
" stock_b = stocks[j]\n",
" pairs_list.append((stock_a, stock_b, sector))\n",
" return pairs_list\n",
"\n",
"\n",
"def calc_pairs_correlation(stock_close_df: pd.DataFrame, pair: Tuple, window: int, all_cor_v: np.array) -> np.array:\n",
" cor_v = np.zeros(0)\n",
" stock_a = pair[0]\n",
" stock_b = pair[1]\n",
" a_close = stock_close_df[stock_a]\n",
" b_close = stock_close_df[stock_b]\n",
" a_log_close = log(a_close)\n",
" b_log_close = log(b_close)\n",
" assert len(a_log_close) == len(b_log_close)\n",
" for i in range(0, len(a_log_close), window):\n",
" sec_a = a_log_close[i:i + window]\n",
" sec_b = b_log_close[i:i + window]\n",
" c = np.corrcoef(sec_a, sec_b)\n",
" cor_v = np.append(cor_v, c[0, 1])\n",
" return cor_v\n",
"\n",
"\n",
"def calc_yearly_correlation(stock_close_df: pd.DataFrame, pairs_list: List[Tuple]) -> np.array:\n",
" all_cor_v = np.zeros(0)\n",
" for pair in pairs_list:\n",
" cor_v: np.array = calc_pairs_correlation(stock_close_df, pair, trading_days, all_cor_v)\n",
" all_cor_v = np.append(all_cor_v, cor_v)\n",
" return all_cor_v\n",
"\n",
"\n",
"def display_histogram(data_v: np.array, x_label: str, y_label: str) -> None:\n",
" num_bins = int(np.sqrt(data_v.shape[0])) * 4\n",
" fix, ax = plt.subplots(figsize=(10, 8))\n",
" ax.set_xlabel(x_label)\n",
" ax.set_ylabel(y_label)\n",
" ax.grid(True)\n",
" ax.hist(data_v, bins=num_bins, facecolor='b')\n",
" ax.axvline(x=np.mean(data_v), color='black')\n",
" plt.show()\n",
"\n",
"\n",
"pairs_list = get_pairs(sectors)\n",
"yearly_cor_a = calc_yearly_correlation(close_prices_df, pairs_list)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"<p>\n",
"The histogram below shows the distribution of the yearly correlation between the pairs.\n",
"</p>"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"\n",
"\n",
"display_histogram(yearly_cor_a, 'Correlation between pairs', 'Count')\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
Expand Down
18 changes: 14 additions & 4 deletions pairs_trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ def get_close_data(self, stock_list: list) -> pd.DataFrame:


def get_pairs(sector_info: dict) -> List[Tuple]:
"""
Return all of the stock pairs, where the pairs are selected from the S&P 500 sector.
:param sector_info: A dictionary containing the sector info. For example:
energies': ['APA', 'BKR', 'COP', ...]
Here 'energies' is the dictionary key for the list of S&P 500 stocks in that sector.
:return: A list of Tuples, where each tuple contains the symbols for the stock pair and the sector.
For example:
[('AAPL', 'ACN', 'information-technology'),
('AAPL', 'ADBE', 'information-technology'),
('AAPL', 'ADI', 'information-technology'),
('AAPL', 'ADP', 'information-technology'),
('AAPL', 'ADSK', 'information-technology')]
"""
pairs_list = list()
sectors = list(sector_info.keys())
for sector in sectors:
Expand Down Expand Up @@ -210,10 +224,6 @@ def calc_yearly_correlation(stock_close_df: pd.DataFrame, pairs_list: List[Tuple
return all_cor_v


def parallel_yearly_correlation(stock_close_df: pd.DataFrame, pairs_list: List[Tuple]) -> np.array:
pass


def display_histogram(data_v: np.array, x_label: str, y_label: str) -> None:
num_bins = int(np.sqrt(data_v.shape[0])) * 4
fix, ax = plt.subplots(figsize=(10, 8))
Expand Down

0 comments on commit d95ad6c

Please sign in to comment.