
New in nnetsauce v0.16.3:

DeepMTS in nnetsauce v0.16.3, for multivariate time series (MTS)

Contents

  • 1 - Install
  • 2 - DeepMTS

1 - Install

!pip uninstall nnetsauce --yes
Found existing installation: nnetsauce 0.16.3
Uninstalling nnetsauce-0.16.3:
  Successfully uninstalled nnetsauce-0.16.3
!pip install git+https://github.com/Techtonique/nnetsauce.git --upgrade --no-cache-dir
Collecting git+https://github.com/Techtonique/nnetsauce.git
  Cloning https://github.com/Techtonique/nnetsauce.git to /tmp/pip-req-build-2fy08xrz
  Running command git clone --filter=blob:none --quiet https://github.com/Techtonique/nnetsauce.git /tmp/pip-req-build-2fy08xrz
  Resolved https://github.com/Techtonique/nnetsauce.git to commit e99ea1404604dc282576abc610b44c490cd8b598
  Preparing metadata (setup.py) ... done
Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (1.3.2)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (3.7.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (1.23.5)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (1.5.3)
Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (1.11.4)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (1.2.2)
Requirement already satisfied: threadpoolctl in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (3.2.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (4.66.1)
Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (0.4.23)
Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from nnetsauce==0.16.3) (0.4.23+cuda12.cudnn89)
Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->nnetsauce==0.16.3) (0.2.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->nnetsauce==0.16.3) (3.3.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (4.47.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (23.2)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (3.1.1)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->nnetsauce==0.16.3) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->nnetsauce==0.16.3) (2023.3.post1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->nnetsauce==0.16.3) (1.16.0)
Building wheels for collected packages: nnetsauce
  Building wheel for nnetsauce (setup.py) ... done
  Created wheel for nnetsauce: filename=nnetsauce-0.16.3-py2.py3-none-any.whl size=152402 sha256=10d081174d14ad5b6af07273a895e85fa0ff28527ec2a27db90aff43102e47f5
  Stored in directory: /tmp/pip-ephem-wheel-cache-a9o14nt9/wheels/18/d7/31/2518e2b1957d1fbc99b30e79e99976579d956e031b45f61794
Successfully built nnetsauce
Installing collected packages: nnetsauce
Successfully installed nnetsauce-0.16.3
#!pip install nnetsauce==0.16.2 --upgrade --no-cache-dir
!pip install statsmodels
Requirement already satisfied: statsmodels in /usr/local/lib/python3.10/dist-packages (0.14.1)
Requirement already satisfied: numpy<2,>=1.18 in /usr/local/lib/python3.10/dist-packages (from statsmodels) (1.23.5)
Requirement already satisfied: scipy!=1.9.2,>=1.4 in /usr/local/lib/python3.10/dist-packages (from statsmodels) (1.11.4)
Requirement already satisfied: pandas!=2.1.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from statsmodels) (1.5.3)
Requirement already satisfied: patsy>=0.5.4 in /usr/local/lib/python3.10/dist-packages (from statsmodels) (0.5.6)
Requirement already satisfied: packaging>=21.3 in /usr/local/lib/python3.10/dist-packages (from statsmodels) (23.2)
Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas!=2.1.0,>=1.0->statsmodels) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas!=2.1.0,>=1.0->statsmodels) (2023.3.post1)
Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from patsy>=0.5.4->statsmodels) (1.16.0)

import nnetsauce as ns
import numpy as np
import pandas as pd
from sklearn.linear_model import ElasticNetCV, LassoCV
from sklearn.metrics import mean_squared_error
import statsmodels.api as sm
from statsmodels.tsa.base.datetools import dates_from_str

2 - DeepMTS

Macro data

# some example data
mdata = sm.datasets.macrodata.load_pandas().data

# prepare the dates index
dates = mdata[['year', 'quarter']].astype(int).astype(str)

quarterly = dates["year"] + "Q" + dates["quarter"]

quarterly = dates_from_str(quarterly)

print(mdata.head())

#mdata = mdata[['realgdp','realcons','realinv', 'realgovt',
#               'realdpi', 'cpi', 'm1', 'tbilrate', 'unemp',
#               'pop']]

mdata = mdata[['realgovt', 'tbilrate', 'cpi']]

mdata.index = pd.DatetimeIndex(quarterly)

data = np.log(mdata).diff().dropna()

#data = mdata

display(data)
     year  quarter  realgdp  realcons  realinv  realgovt  realdpi   cpi  \
0 1959.00     1.00  2710.35   1707.40   286.90    470.05  1886.90 28.98   
1 1959.00     2.00  2778.80   1733.70   310.86    481.30  1919.70 29.15   
2 1959.00     3.00  2775.49   1751.80   289.23    491.26  1916.40 29.35   
3 1959.00     4.00  2785.20   1753.70   299.36    484.05  1931.30 29.37   
4 1960.00     1.00  2847.70   1770.50   331.72    462.20  1955.50 29.54   

      m1  tbilrate  unemp    pop  infl  realint  
0 139.70      2.82   5.80 177.15  0.00     0.00  
1 141.70      3.08   5.10 177.83  2.34     0.74  
2 140.50      3.82   5.30 178.66  2.74     1.09  
3 140.00      4.33   5.60 179.39  0.27     4.06  
4 139.60      3.50   5.20 180.01  2.31     1.19  
            realgovt  tbilrate    cpi
1959-06-30      0.02      0.09   0.01
1959-09-30      0.02      0.22   0.01
1959-12-31     -0.01      0.13   0.00
1960-03-31     -0.05     -0.21   0.01
1960-06-30     -0.00     -0.27   0.00
...              ...       ...    ...
2008-09-30      0.03     -0.40  -0.01
2008-12-31      0.02     -2.28  -0.02
2009-03-31     -0.01      0.61   0.00
2009-06-30      0.03     -0.20   0.01
2009-09-30      0.02     -0.41   0.01

202 rows × 3 columns
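The transformation `np.log(mdata).diff().dropna()` turns each level series into log-differences, log(x_t) - log(x_{t-1}), which are approximately period-on-period growth rates. The first quarter has no predecessor and is dropped, hence 202 rows instead of the 203 quarters (1959Q1 to 2009Q3) in `macrodata`. A minimal sketch on made-up values, chosen so that each quarter grows by exactly 1%:

```python
import numpy as np
import pandas as pd

# Made-up quarterly levels, each 1% above the previous quarter (illustration only)
idx = pd.DatetimeIndex(["2000-03-31", "2000-06-30", "2000-09-30", "2000-12-31"])
toy = pd.DataFrame({"cpi": [100.0, 101.0, 102.01, 103.0301]}, index=idx)

# Same transformation as in the notebook: log levels, first difference, drop the NaN row
growth = np.log(toy).diff().dropna()

# Each remaining value is log(x_t / x_{t-1}) = log(1.01), close to a 1% growth rate
print(growth)
```

Four levels give three growth rates, which is why the transformed data is one row shorter than the original.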

n = data.shape[0]
# integer cut-off: the first 80% of the quarters are used for training
max_idx_train = int(np.floor(n*0.8))
training_index = np.arange(0, max_idx_train)
testing_index = np.arange(max_idx_train, n)
df_train = data.iloc[training_index,:]
df_test = data.iloc[testing_index,:]
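This split is chronological: the first 80% of the quarters train the model and the remaining 20% are held out, without shuffling, so the test period strictly follows the training period. A minimal sketch with a random stand-in for `data` (same shape as above, made-up values) shows the resulting sizes; slicing with `iloc[:k]` selects the same rows as indexing with `np.arange(0, k)`.

```python
import numpy as np
import pandas as pd

# Stand-in for `data`: 202 rows x 3 series of made-up values
data = pd.DataFrame(np.random.default_rng(0).normal(size=(202, 3)),
                    columns=["realgovt", "tbilrate", "cpi"])

n = data.shape[0]
max_idx_train = int(np.floor(n * 0.8))  # 202 * 0.8 = 161.6, floored to 161

# Chronological split: no shuffling, the test rows come after the training rows
df_train = data.iloc[:max_idx_train, :]
df_test = data.iloc[max_idx_train:, :]

print(df_train.shape, df_test.shape)  # (161, 3) (41, 3)
```

With the real data, the test set therefore covers the last 41 quarters, which is the horizon `h` passed to `predict` below.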

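# Fit DeepMTS with ElasticNetCV as the base regressor
regr7 = ElasticNetCV()
obj_MTS2 = ns.DeepMTS(regr7,
                     n_layers=3,           # number of stacked layers
                     lags=4,               # number of lags used for each series
                     n_hidden_features=5,  # number of nodes in each hidden layer
                     replications=10,      # number of simulated paths for predictive simulation
                     kernel='gaussian',    # kernel for the residuals' density estimation
                     verbose=1)
obj_MTS2.fit(df_train)
# forecast over the whole test horizon
res4 = obj_MTS2.predict(h=len(testing_index))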
 Adjusting DeepRegressor to multivariate time series... 

100%|██████████| 3/3 [00:02<00:00,  1.02it/s]

 Simulate residuals using gaussian kernel... 

 Best parameters for gaussian kernel: {'bandwidth': 0.022335377063851233} 

100%|██████████| 10/10 [00:00<00:00, 1582.64it/s]
100%|██████████| 10/10 [00:00<00:00, 3664.43it/s]
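To make `lags`, `n_hidden_features` and `n_layers` more concrete, here is a conceptual sketch with NumPy and scikit-learn. It is not nnetsauce's implementation: the helpers `make_lagged` and `add_layers` are hypothetical, the hidden weights are plain Gaussian draws and the inputs are not scaled, whereas nnetsauce handles these details its own way. It only illustrates the general idea: a lagged design matrix, augmented by stacked layers that each append `tanh` hidden features to their input, with one `ElasticNetCV` fitted per series.

```python
import numpy as np
from sklearn.linear_model import ElasticNetCV

def make_lagged(y, lags):
    """Hypothetical helper: each row of X holds the `lags` previous observations
    of all series (lag 1 first), flattened; Y holds the observation that follows."""
    n = y.shape[0]
    X = np.hstack([y[lags - k - 1 : n - k - 1] for k in range(lags)])
    return X, y[lags:]

def add_layers(X, weights):
    """Hypothetical stacked layers: each keeps its input and appends tanh(input @ W)."""
    Z = X
    for W in weights:
        Z = np.hstack([Z, np.tanh(Z @ W)])
    return Z

rng = np.random.default_rng(123)
y = 0.01 * rng.normal(size=(160, 3))  # made-up stand-in for df_train.values
lags, n_layers, n_hidden = 4, 3, 5    # same values as in the DeepMTS call above

X, Y = make_lagged(y, lags)           # X: (156, 12), Y: (156, 3)
# Random input-to-hidden weights; the feature width grows by n_hidden per layer
weights, width = [], X.shape[1]
for _ in range(n_layers):
    weights.append(rng.normal(size=(width, n_hidden)))
    width += n_hidden
Z = add_layers(X, weights)            # 12 + 3 * 5 = 27 features

# One base regressor per series, all sharing the same features
models = [ElasticNetCV(cv=3).fit(Z, Y[:, j]) for j in range(Y.shape[1])]

# One-step-ahead forecast from the last `lags` observations (lag 1 first)
x_last = np.hstack([y[-k - 1] for k in range(lags)]).reshape(1, -1)
forecast = np.array([m.predict(add_layers(x_last, weights))[0] for m in models])
print(Z.shape, forecast.shape)  # (156, 27) (3,)
```

Keeping the original lagged inputs alongside the hidden features lets the linear base regressor fall back on a plain autoregression when the nonlinear features are not useful.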
obj_MTS2.plot("realgovt", type_plot="pi")
obj_MTS2.plot("tbilrate", type_plot="pi")
obj_MTS2.plot("cpi", type_plot="pi")

(Figures: prediction-interval plots for realgovt, tbilrate and cpi)
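The verbose output reports a tuned bandwidth for the Gaussian kernel, which suggests that the density of the residuals is estimated by kernel density estimation before `replications` future paths are simulated; the "pi" plots then show intervals computed from those paths. The sketch below reproduces this idea with scikit-learn's `KernelDensity` and `GridSearchCV`. The bandwidth grid, the stand-in residuals and point forecast, and the 95% level are assumptions for illustration, not nnetsauce's internals.

```python
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(42)
residuals = 0.02 * rng.standard_normal(size=(150, 3))  # stand-in in-sample residuals, 3 series
point_forecast = np.zeros((41, 3))                      # stand-in mean forecast, h = 41

# Choose the bandwidth by cross-validated log-likelihood (assumed grid)
grid = GridSearchCV(KernelDensity(kernel="gaussian"),
                    {"bandwidth": np.logspace(-4, 0, 20)}, cv=3)
grid.fit(residuals)
kde = grid.best_estimator_
print(grid.best_params_)

# replications=10: draw 10 residual paths jointly across series, add them to the forecast
replications = 10
sims = np.stack([point_forecast + kde.sample(n_samples=point_forecast.shape[0], random_state=r)
                 for r in range(replications)])  # shape (10, 41, 3)

# Assumed 95% prediction interval: empirical quantiles across the replications
lower = np.quantile(sims, 0.025, axis=0)
upper = np.quantile(sims, 0.975, axis=0)
```

Sampling each residual vector from a joint 3-dimensional density keeps the contemporaneous correlation between the series in the simulated paths. With only 10 replications, the extreme quantiles are rough.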

obj_MTS2.plot("realgovt", type_plot = "spaghetti")
obj_MTS2.plot("tbilrate", type_plot = "spaghetti")
obj_MTS2.plot("cpi", type_plot = "spaghetti")

(Figures: spaghetti plots of the simulated paths for realgovt, tbilrate and cpi)
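A spaghetti plot draws every simulated path instead of summarizing them with quantiles, so with `replications=10` there is one line per replication. A minimal matplotlib sketch with made-up paths for a single series (not `obj_MTS2`'s simulations):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in a notebook
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
h, replications = 41, 10
steps = np.arange(1, h + 1)
# Made-up simulated paths: a zero point forecast plus i.i.d. noise, shape (replications, h)
sims = 0.01 * rng.standard_normal(size=(replications, h))

fig, ax = plt.subplots()
for path in sims:  # one thin line per replication: the "spaghetti"
    ax.plot(steps, path, lw=0.8, alpha=0.7)
ax.plot(steps, sims.mean(axis=0), color="black", lw=2, label="mean of simulations")
ax.set_xlabel("horizon (steps ahead)")
ax.legend()
```

The spread of the lines at a given horizon is what the quantile bands of the "pi" plot summarize.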
