-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New working instrument factory for stock instrument.
- Loading branch information
1 parent
93a7cd0
commit 07291ba
Showing
6 changed files
with
382 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# -*- coding: utf-8 -*- | ||
import torch | ||
import pandas as pd | ||
import numpy as np | ||
|
||
from datetime import datetime | ||
from matplotlib import pyplot as plt | ||
|
||
from utils.stock_factory import StockFactory | ||
from utils.instruments import Stock | ||
from utils.config import STOCK_ADMISSIBLE_MODELS | ||
|
||
from models.ar import AR | ||
from models.garch import GARCH | ||
|
||
|
||
|
||
start = datetime.strptime("2005-05-01", r"%Y-%m-%d") | ||
end = datetime.strptime("2023-06-01", r"%Y-%m-%d") | ||
instrument_specification = ["XOM", "GS", "T"] | ||
instrument_factory = StockFactory(tickers=instrument_specification, | ||
start=start, | ||
end=end) | ||
stocks = instrument_factory.build_stocks() | ||
|
||
|
||
|
||
class Simulator: | ||
|
||
def __init__(self, instruments: list): | ||
self._instruments = instruments | ||
self._calibrated = False | ||
|
||
def calibrate(self): | ||
self._calibrate_instruments() | ||
self._calibrate_copula() | ||
self._calibrated = True | ||
|
||
def _calibrate_instruments(self): | ||
for instrument in self._instruments: | ||
if isinstance(instrument, Stock): | ||
self._calibrate_stock(instrument) | ||
|
||
def _calibrate_copula(self): | ||
... | ||
|
||
def _calibrate_stock(self, stock): | ||
current_aic = np.inf | ||
risk_factor = stock.risk_factors[0] | ||
data = risk_factor.price_history.log_returns | ||
|
||
for model_name in STOCK_ADMISSIBLE_MODELS: | ||
if model_name == "AR": | ||
model = AR(data) | ||
model.calibrate() | ||
if model.aic < current_aic: | ||
risk_factor.model = model | ||
|
||
elif model_name == "GARCH": | ||
model = GARCH(data) | ||
model.calibrate() | ||
if model.aic < current_aic: | ||
risk_factor.model = model | ||
|
||
|
||
def run_simulation(time_steps: int) -> torch.Tensor: | ||
# Check for succesful calibration, Throw an error otherwise. | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from utils.instruments import Asset, Stock, RiskFactor | ||
from utils.data_handler import YahooDataHandler | ||
|
||
class StockFactory: | ||
|
||
def __init__(self, tickers, start, end, interval=None): | ||
self._tickers = tickers | ||
self._start = start | ||
self._end = end | ||
self._ydr = YahooDataHandler() | ||
|
||
|
||
def build_stocks(self) -> list[Stock]: | ||
stocks = [] | ||
for ticker in self._tickers: | ||
price_history = self._ydr.get_price_history(ticker, self._start, self._end) | ||
identifier = f"{ticker}_RF" | ||
risk_factor = RiskFactor(identifier, price_history) | ||
stock = Stock(identifier=ticker, risk_factors=[risk_factor]) | ||
stocks.append(stock) | ||
return stocks |