4 changes: 2 additions & 2 deletions scripts/data_collector/README.md
@@ -4,7 +4,7 @@

Scripts for data collection

- yahoo: get *US/CN* stock data from *Yahoo Finance*
- yahoo: get *US/CN/IN/BR/JP* stock data from *Yahoo Finance*
- fund: get fund data from *http://fund.eastmoney.com*
- cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100*
- us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400*
@@ -57,4 +57,4 @@ Scripts for data collection
| Component | required data |
|---------------------------------------------------|--------------------------------|
| Data retrieval | Features, Calendar, Instrument |
| Backtest | **Features[Price/Volume]**, Calendar, Instruments |
130 changes: 123 additions & 7 deletions scripts/data_collector/utils.py
@@ -9,8 +9,9 @@
import pickle
import requests
import functools
from io import BytesIO
from pathlib import Path
from typing import Iterable, Tuple, List
from typing import Iterable, Tuple, List, Optional

import numpy as np
import pandas as pd
@@ -36,28 +37,39 @@
"US_ALL": "^GSPC",
"IN_ALL": "^NSEI",
"BR_ALL": "^BVSP",
"JP_ALL": "^N225",
}

JPX_LISTED_COMPANIES_URL = "https://www.jpx.co.jp/markets/statistics-equities/misc/tvdivq0000001vg2-att/data_j.xls"

_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_US_SYMBOLS = None
_IN_SYMBOLS = None
_BR_SYMBOLS = None
_JP_SYMBOLS = None
_EN_FUND_SYMBOLS = None
_CALENDAR_MAP = {}

# NOTE: Until 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900


def _normalize_calendar_timestamp(value) -> pd.Timestamp:
    # strip timezone info (yahooquery may return tz-aware dates) and truncate
    # to midnight, so calendar values from different sources compare equal
    ts = pd.Timestamp(value)
    if ts.tzinfo is not None:
        ts = ts.tz_localize(None)
    return ts.normalize()

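# Illustrative check of the helper above (sample timestamps are made up):
# tz-aware dates, as yahooquery can return for the JP/US benchmarks, and
# naive dates both normalize to the same midnight value, so the set
# comprehension in get_calendar_list() deduplicates them cleanly.
assert _normalize_calendar_timestamp(pd.Timestamp("2021-04-01 15:00:00+09:00")) == pd.Timestamp("2021-04-01")
assert _normalize_calendar_timestamp(pd.Timestamp("2021-04-01")) == pd.Timestamp("2021-04-01")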

def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
    """get benchmark history calendar list

Parameters
----------
bench_code: str
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
value from ["CSI300", "CSI500", "ALL", "US_ALL", "IN_ALL", "BR_ALL", "JP_ALL"]

Returns
-------
@@ -72,11 +84,15 @@ def _get_calendar(url):

calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"):
print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]))
print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max"))
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
if (
bench_code.startswith("US_")
or bench_code.startswith("IN_")
or bench_code.startswith("BR_")
or bench_code.startswith("JP_")
):
_ticker = Ticker(CALENDAR_BENCH_URL_MAP[bench_code])
df = _ticker.history(interval="1d", period="max")
calendar = sorted({_normalize_calendar_timestamp(v) for v in df.index.get_level_values(level="date")})
else:
if bench_code.upper() == "ALL":
import akshare as ak # pylint: disable=C0415
@@ -448,6 +464,106 @@ def _format(s_):
return _BR_SYMBOLS


def _normalize_jpx_column_name(col_name: str) -> str:
return str(col_name).replace(" ", "").replace("\u3000", "").replace("\n", "").strip().lower()


def _find_jpx_column(columns: list, exact_candidates: list, keyword_candidates: list) -> Optional[str]:
normalized_map = {col: _normalize_jpx_column_name(col) for col in columns}
exact_candidates = {_normalize_jpx_column_name(col) for col in exact_candidates}
keyword_candidates = [_normalize_jpx_column_name(col) for col in keyword_candidates]

for _col, _normalized_col in normalized_map.items():
if _normalized_col in exact_candidates:
return _col

for _col, _normalized_col in normalized_map.items():
if all(_keyword in _normalized_col for _keyword in keyword_candidates):
return _col

return None

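# Illustrative matches for the finder above (header spellings are hypothetical
# variants of the JPX file's Japanese/English headers):
# the exact stage misses "銘柄コード", but the keyword stage finds the substring
assert _find_jpx_column(["銘柄コード", "市場・商品区分"], ["コード"], ["コード"]) == "銘柄コード"
# spaces are stripped and case is folded before the exact comparison
assert _find_jpx_column(["Security Code"], ["securitycode"], ["code"]) == "Security Code"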

def _extract_jp_prime_symbols(df: pd.DataFrame) -> list:
if df is None or df.empty:
raise ValueError("JPX listed companies file is empty")

code_col = _find_jpx_column(
columns=df.columns.tolist(),
exact_candidates=["コード", "銘柄コード", "code", "securitycode"],
keyword_candidates=["コード"],
)
if code_col is None:
raise ValueError("Unable to find stock code column in JPX listed companies file")

market_col = _find_jpx_column(
columns=df.columns.tolist(),
exact_candidates=["市場・商品区分", "市場商品区分", "市場区分", "marketsegment"],
keyword_candidates=["市場", "区分"],
)
if market_col is None:
raise ValueError("Unable to find market classification column in JPX listed companies file")

domestic_col = _find_jpx_column(
columns=df.columns.tolist(),
exact_candidates=["内外株式区分", "内外区分", "domesticforeign"],
keyword_candidates=["内外", "区分"],
)

market_series = df[market_col].astype(str)
prime_mask = market_series.str.contains("プライム", na=False)

if market_series.str.contains("内国株式", na=False).any():
domestic_mask = market_series.str.contains("内国株式", na=False)
elif domestic_col is not None:
domestic_mask = df[domestic_col].astype(str).str.contains("内国株式", na=False)
else:
domestic_mask = market_series.str.contains("内国株式", na=False)

target_df = df.loc[prime_mask & domestic_mask, [code_col]].copy()
if target_df.empty:
raise ValueError("No JPX Prime domestic stocks found in listed companies file")

symbols = (
target_df[code_col]
.astype(str)
.str.extract(r"(\d{4})", expand=False)
.dropna()
.apply(lambda code: f"{code}.T")
.drop_duplicates()
.sort_values()
.tolist()
)
if not symbols:
raise ValueError("No valid JP stock symbols extracted from JPX listed companies file")
return symbols

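# Minimal end-to-end sketch with a hypothetical two-row listing: only the
# domestic Prime row survives, and its 4-digit code becomes Yahoo's "<code>.T"
_demo_df = pd.DataFrame(
    {
        "コード": ["7203", "1305"],
        "市場・商品区分": ["プライム（内国株式）", "ETF・ETN"],
    }
)
assert _extract_jp_prime_symbols(_demo_df) == ["7203.T"]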

def get_jp_stock_symbols() -> list:
"""get JP Prime (domestic stock) symbols"""

global _JP_SYMBOLS # pylint: disable=W0603

@deco_retry
def _get_jpx_listed_companies_df():
resp = requests.get(JPX_LISTED_COMPANIES_URL, timeout=None)
if resp.status_code != 200:
raise ValueError(f"request error, status_code={resp.status_code}")
try:
return pd.read_excel(BytesIO(resp.content), dtype=str)
except Exception as excel_error:
try:
return pd.read_html(BytesIO(resp.content))[0].astype(str)
except Exception as html_error:
raise ValueError(
f"failed to parse JPX listed companies file: excel_error={excel_error}, html_error={html_error}"
) from html_error

if _JP_SYMBOLS is None:
_JP_SYMBOLS = _extract_jp_prime_symbols(_get_jpx_listed_companies_df())
return _JP_SYMBOLS

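# Expected call site, mirroring get_us_stock_symbols()/get_br_stock_symbols();
# the first call downloads JPX_LISTED_COMPANIES_URL, later calls reuse the
# module-level _JP_SYMBOLS cache:
#   symbols = get_jp_stock_symbols()  # e.g. ["1301.T", ...] (illustrative)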

def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get en fund symbols

14 changes: 11 additions & 3 deletions scripts/data_collector/yahoo/README.md
@@ -55,6 +55,7 @@ pip install -r requirements.txt
### Collect *YahooFinance* data to qlib
> collect *YahooFinance* data and *dump* it into `qlib` format.
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
> For `region=JP`, the symbol universe is **TSE Prime (domestic stocks)**, taken from the JPX listed companies file.
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`

This will download raw data such as the high, low, open, close, and adjclose prices from Yahoo Finance to a local directory, one file per symbol.
@@ -63,7 +64,8 @@
- `source_dir`: the directory where the data is saved
- `interval`: `1d` or `1min`, by default `1d`
> **due to the limitation of the *YahooFinance API*, only the last month's data is available in `1min`**
- `region`: `CN` or `US` or `IN` or `BR`, by default `CN`
- `region`: `CN` or `US` or `IN` or `BR` or `JP`, by default `CN`
> `JP` supports `1d` only
- `delay`: `time.sleep(delay)`, by default *0.5*
- `start`: start datetime, by default *"2000-01-01"*; *closed interval(including start)*
- `end`: end datetime, by default `pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`; *open interval(excluding end)*
@@ -92,6 +94,9 @@ pip install -r requirements.txt
python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data --start 2003-01-03 --end 2022-03-01 --delay 1 --interval 1d --region BR
# br 1min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/br_data_1min --delay 1 --interval 1min --region BR

# jp 1d data (TSE Prime domestic stocks)
python collector.py download_data --source_dir ~/.qlib/stock_data/source/jp_data --start 2020-01-01 --end 2020-12-31 --delay 1 --interval 1d --region JP
```
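
For `region=JP`, the downloaded universe comes from `get_jp_stock_symbols` in `scripts/data_collector/utils.py`. A minimal sketch for inspecting it before a long crawl, assuming `scripts` is on `sys.path` (the same import style `collector.py` uses):

```python
from data_collector.utils import get_jp_stock_symbols

symbols = get_jp_stock_symbols()  # TSE Prime domestic stocks as "<code>.T"
print(len(symbols), symbols[:5])
```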
2. normalize data: `python scripts/data_collector/yahoo/collector.py normalize_data`

@@ -105,7 +110,8 @@ pip install -r requirements.txt
- `max_workers`: number of concurrent workers, by default *1*
- `interval`: `1d` or `1min`, by default `1d`
> if **`interval == 1min`**, `qlib_data_1d_dir` cannot be `None`
- `region`: `CN` or `US` or `IN`, by default `CN`
- `region`: `CN` or `US` or `IN` or `BR` or `JP`, by default `CN`
> `JP` supports `1d` only
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `end_date`: if not `None`, normalize data up to the last saved date (*including end_date*); if `None`, this parameter is ignored; by default `None`
@@ -133,6 +139,9 @@ pip install -r requirements.txt

# normalize 1min br
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/br_data --source_dir ~/.qlib/stock_data/source/br_data_1min --normalize_dir ~/.qlib/stock_data/source/br_1min_nor --region BR --interval 1min

# normalize 1d jp
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/jp_data --normalize_dir ~/.qlib/stock_data/source/jp_1d_nor --region JP --interval 1d
```
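
After normalization, each JP csv is aligned to the `JP_ALL` calendar (derived from `^N225`). A quick sanity check of one output file; the file name is illustrative, and the columns shown are what the 1d normalizer is expected to add:

```python
import pandas as pd

df = pd.read_csv("~/.qlib/stock_data/source/jp_1d_nor/7203.T.csv")
print(df.columns.tolist())  # expect date/open/high/low/close/... plus factor and change
print(df["date"].min(), df["date"].max())
```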
3. dump data: `python scripts/dump_bin.py dump_all`

@@ -222,4 +231,3 @@ pip install -r requirements.txt
# get all symbol data
# df = D.features(D.instruments("all"), ["$close"], freq="1min")
```

68 changes: 66 additions & 2 deletions scripts/data_collector/yahoo/collector.py
@@ -38,6 +38,7 @@
get_us_stock_symbols,
get_in_stock_symbols,
get_br_stock_symbols,
get_jp_stock_symbols,
generate_minutes_calendar_from_daily,
calc_adjusted_price,
)
@@ -364,6 +365,33 @@ class YahooCollectorBR1min(YahooCollectorBR):
retry = 2


class YahooCollectorJP(YahooCollector, ABC):
def get_instrument_list(self):
logger.info("get JP Prime (domestic stock) symbols......")
symbols = get_jp_stock_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols

    def download_index_data(self):
        # no JP index data is dumped here; the JP trading calendar is derived
        # from ^N225 via get_calendar_list("JP_ALL") at normalization time
        pass

def normalize_symbol(self, symbol):
return code_to_fname(symbol).upper()

@property
def _timezone(self):
return "Asia/Tokyo"


class YahooCollectorJP1d(YahooCollectorJP):
pass


class YahooCollectorJP1min(YahooCollectorJP):
def __init__(self, *args, **kwargs):
raise ValueError("JP region does not support 1min data collection")


class YahooNormalize(BaseNormalize):
COLUMNS = ["open", "close", "high", "low", "volume"]
DAILY_FORMAT = "%Y-%m-%d"
@@ -720,6 +748,27 @@ def symbol_to_yahoo(self, symbol):
return fname_to_code(symbol)


class YahooNormalizeJP:
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return get_calendar_list("JP_ALL")


class YahooNormalizeJP1d(YahooNormalizeJP, YahooNormalize1d):
pass


class YahooNormalizeJP1dExtend(YahooNormalizeJP, YahooNormalize1dExtend):
pass


class YahooNormalizeJP1min(YahooNormalizeJP, YahooNormalize1min):
def __init__(self, *args, **kwargs):
raise ValueError("JP region does not support 1min normalization")

def symbol_to_yahoo(self, symbol):
return fname_to_code(symbol)

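# The JP normalizers compose the shared pipeline through mixin order: the JP
# trading calendar comes from YahooNormalizeJP, the OHLCV logic from
# YahooNormalize1d. A sketch of the resulting resolution order:
assert YahooNormalizeJP1d.__mro__[:3] == (YahooNormalizeJP1d, YahooNormalizeJP, YahooNormalize1d)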

class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
@@ -735,11 +784,15 @@ def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval=
interval: str
freq, value from [1min, 1d], default 1d
region: str
region, value from ["CN", "US", "BR"], default "CN"
region, value from ["CN", "US", "IN", "BR", "JP"], default "CN"
"""
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region

def _validate_region_interval(self):
if self.region.upper() == "JP" and self.interval.lower() == "1min":
raise ValueError("JP region does not support 1min data")

@property
def collector_class_name(self):
return f"YahooCollector{self.region.upper()}{self.interval}"
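    # Illustrative mapping for the new region (class lookup by this name is
    # assumed to follow the existing BaseRun machinery):
    #   region="JP", interval="1d"   -> "YahooCollectorJP1d"
    #   region="JP", interval="1min" -> "YahooCollectorJP1min", whose __init__
    #   raises ValueError before any download starts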
@@ -792,6 +845,7 @@ def download_data(
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
self._validate_region_interval()
if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")):
raise ValueError(f"end_date: {end} is greater than the current date.")

@@ -828,6 +882,7 @@ def normalize_data(
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
self._validate_region_interval()
if self.interval.lower() == "1min":
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
@@ -937,6 +992,7 @@ def update_data_to_bin(
check_data_length: int = None,
delay: float = 1,
exists_skip: bool = False,
limit_nums: int = None,
):
"""update yahoo data to bin

Expand All @@ -953,6 +1009,8 @@ def update_data_to_bin(
time.sleep(delay), default 1
exists_skip: bool
exists skip, by default False
limit_nums: int
used for debugging, by default None
Notes
-----
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
@@ -981,7 +1039,13 @@

# download data from yahoo
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
self.download_data(
delay=delay,
start=trading_date,
end=end_date,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
# NOTE: a larger max_workers setting here would be faster
self.max_workers = (
max(multiprocessing.cpu_count() - 2, 1)