Python 株価を予想する。時系列予測ライブラリ「Prophet」

 Prophetは、Facebookが開発した時系列予測ライブラリです。この予測手法は、時系列データy(t)をトレンド(t)+季節性(t)+イベント(t)+誤差(t)の合成として分析して予測モデルを構築します。つまり、それらをハイパーパラメータとして予測モデルを調整できます。これがニューラルネットワーク(RNN)と比較して扱い易いところです。論文は次のリンク先のREAD THE PAPERから飛ぶことが出来ます。https://facebook.github.io/prophet/

 インストールは、Anaconda環境下で次のようにcondaで出来ます。

conda install -c conda-forge fbprophet

 本記事では、日経平均株価225の2017年12月から2020年12月までの3年間のデータを元に、下図のように2021年3月までの株価を予測する雛形コードを載せました。2月くらいまでは上昇あるいは横ばいで、3月には下落する予想です。ハイパーパラメータの調整によって予測結果は異なります。f:id:HK29:20201206224730p:plain

 株価データは、下図のようなcsvファイルを用います。例えば次の記事のようにヤフーファイナンスから取得できます。Python 株価データの欠損値をその前後の値で補完後、単純移動平均を算出する「pandas」 - PythonとVBAで世の中を便利にする

f:id:HK29:20201206224811p:plain

Prophetでは、下図のように一行でトレンドを可視化出来るなど操作が容易です。

f:id:HK29:20201206225924p:plain

■本プログラム
使用環境はAnaconda PromptのようなCUIで実行できます。JupyterLabがおススメではあります。

#!/usr/bin/env python
# coding: utf-8

# In[1]:


# csvファイルを読み込む import pandas as pd from fbprophet import Prophet df = pd.read_csv('N225_2017-12_2020_12.csv') print(df) # In[2]:
# 列名を日時はds、目的変数をyに変更する。Prophetの仕様のため。 new_df = df.rename(columns={'Date': 'ds', 'Adj Close': 'y'}) print(new_df) # In[3]: # 予測したい数 predict_num = 15 # 訓練データの作成する。全データから予測したい数を引いた数 cnt = len(new_df) - predict_num train_df = new_df[:cnt] print(train_df) # In[4]: # モデルの作成 model = Prophet( growth='linear', # 傾向変動の関数.非線形は'logistic' yearly_seasonality = True, # 年次の季節変動を考慮有無 weekly_seasonality = False, # 週次の季節変動を考慮有無 daily_seasonality = False, # 日次の季節変動を考慮有無 changepoints = None, # 傾向変化点のリスト changepoint_range = 0.85, # 傾向変化点の候補の幅で先頭からの割合。 changepoint_prior_scale = 0.5, # 傾向変化点の事前分布のスケール値。パラメータの柔軟性 n_changepoints = 5, # 傾向変化点の数 ) model.fit(train_df) # In[5]: # 学習データに予測したい期間を追加する future = model.make_future_dataframe(periods = predict_num) #,freq='M') print(future) # In[6]: # 予測する forecast = model.predict(future) print(forecast) # In[7]: # 可視化する import matplotlib.pyplot as plt from fbprophet.plot import add_changepoints_to_plot forecast['cap'] = 23000 #forecast['floor'] = 21000 fig1 = model.plot(forecast) a = add_changepoints_to_plot(fig1.gca(), model, forecast) plt.show() # In[8]: # 規則性を可視化 fig2 = model.plot_components(forecast) plt.show() # In[9]: import seaborn as sns # 変化率を追記 df3 = train_df.loc[model.changepoints.index] df3['delta'] = model.params['delta'].ravel() # 変化点を取得 df3['ds'] = df3['ds'].astype(str) df3['delta'] = df3['delta'].round(3) df4 = df3[df3['delta'] != 0] print(df4) # In[10]: # 変化点をグラフ化 import matplotlib.pyplot as plt ax = sns.factorplot(x = 'ds', y = 'delta', data = df4, color='magenta') ax.set_xticklabels(rotation=90) plt.grid() plt.show() # 変化点をリストで抽出 ds_list = df4['ds'].tolist() print(ds_list) # In[11]: # 変化点を指定してモデルの作成 model2 = Prophet( growth='linear', # 傾向変動の関数.非線形は'logistic' yearly_seasonality = True, # 年次の季節変動を考慮有無 weekly_seasonality = False, # 週次の季節変動を考慮有無 daily_seasonality = False, # 日次の季節変動を考慮有無 changepoints = ds_list, # 傾向変化点のリスト #changepoint_range = 0.85, # 傾向変化点の候補の幅で先頭からの割合。 #changepoint_prior_scale = 0.5, # 傾向変化点の事前分布のスケール値。パラメータの柔軟性 #n_changepoints = 5, # 傾向変化点の数 ) model2.fit(train_df) # In[12]: # 学習データに予測したい期間を追加する future2 = model2.make_future_dataframe(periods = predict_num) #,freq='M') print(future2) # In[13]: # 予測2 forecast2 = model2.predict(future2) print(forecast2) # In[14]: # 可視化 forecast2['cap'] = 23000 #forecast['floor'] = 21000 fig2_1 = model2.plot(forecast2) b = add_changepoints_to_plot(fig2_1.gca(), model2, forecast) plt.show() # In[15]: # 規則性を可視化 fig2_2 = model2.plot_components(forecast2) plt.show() # In[16]: # モデルを保存する import json from fbprophet.serialize import model_to_json, model_from_json with open('serialized_model.json', 'w') as fout: json.dump(model_to_json(model2), fout) # In[17]: # モデルをロードする import json from fbprophet.serialize import model_to_json, model_from_json with open('serialized_model.json', 'r') as fin: model3 = model_from_json(json.load(fin)) # In[18]: # (更に先の2021年3月頃の)未来を予測する import fbprophet.plot as fp future3 = model3.make_future_dataframe(periods=120) forecast3 = model3.predict(future3) fig3 = model3.plot(forecast3) fp.add_changepoints_to_plot(fig3.gca(), model3, forecast3); plt.show() # In[ ]:

(参考)再帰型ニューラルネットワーク(RNN)による株価予想は次のリンク参照

hk29.hatenablog.jp

以上

<広告>