Prophetは、Facebookが開発した時系列予測ライブラリです。この予測手法は、時系列データy(t)をトレンド(t)+季節性(t)+イベント(t)+誤差(t)の合成として分析して予測モデルを構築します。つまり、それらをハイパーパラメータとして予測モデルを調整できます。これがニューラルネットワーク(RNN)と比較して扱い易いところです。論文は次のリンク先のREAD THE PAPERから飛ぶことが出来ます。https://facebook.github.io/prophet/
インストールは、Anaconda環境下で次のようにcondaで出来ます。
conda install -c conda-forge fbprophet
本記事では、日経平均株価225の2017年12月から2020年12月までの3年間のデータを元に、下図のように2021年3月までの株価を予測する雛形コードを載せました。2月くらいまでは上昇あるいは横ばいで、3月には下落する予想です。ハイパーパラメータの調整によって予測結果は異なります。
株価データは、下図のようなcsvファイルを用います。例えば次の記事のようにヤフーファイナンスから取得できます。Python 株価データの欠損値をその前後の値で補完後、単純移動平均を算出する「pandas」 - PythonとVBAで世の中を便利にする
Prophetでは、下図のように一行でトレンドを可視化出来るなど操作が容易です。
■本プログラム
使用環境はAnaconda PromptのようなCUIで実行できます。JupyterLabがおススメではあります。
import pandas as pd
from fbprophet import Prophet
df = pd.read_csv('N225_2017-12_2020_12.csv')
print(df)
new_df = df.rename(columns={'Date': 'ds', 'Adj Close': 'y'})
print(new_df)
predict_num = 15
cnt = len(new_df) - predict_num
train_df = new_df[:cnt]
print(train_df)
model = Prophet(
growth='linear',
yearly_seasonality = True,
weekly_seasonality = False,
daily_seasonality = False,
changepoints = None,
changepoint_range = 0.85,
changepoint_prior_scale = 0.5,
n_changepoints = 5,
)
model.fit(train_df)
future = model.make_future_dataframe(periods = predict_num)
print(future)
forecast = model.predict(future)
print(forecast)
import matplotlib.pyplot as plt
from fbprophet.plot import add_changepoints_to_plot
forecast['cap'] = 23000
fig1 = model.plot(forecast)
a = add_changepoints_to_plot(fig1.gca(), model, forecast)
plt.show()
fig2 = model.plot_components(forecast)
plt.show()
import seaborn as sns
df3 = train_df.loc[model.changepoints.index]
df3['delta'] = model.params['delta'].ravel()
df3['ds'] = df3['ds'].astype(str)
df3['delta'] = df3['delta'].round(3)
df4 = df3[df3['delta'] != 0]
print(df4)
import matplotlib.pyplot as plt
ax = sns.factorplot(x = 'ds', y = 'delta', data = df4, color='magenta')
ax.set_xticklabels(rotation=90)
plt.grid()
plt.show()
ds_list = df4['ds'].tolist()
print(ds_list)
model2 = Prophet(
growth='linear',
yearly_seasonality = True,
weekly_seasonality = False,
daily_seasonality = False,
changepoints = ds_list,
)
model2.fit(train_df)
future2 = model2.make_future_dataframe(periods = predict_num)
print(future2)
forecast2 = model2.predict(future2)
print(forecast2)
forecast2['cap'] = 23000
fig2_1 = model2.plot(forecast2)
b = add_changepoints_to_plot(fig2_1.gca(), model2, forecast)
plt.show()
fig2_2 = model2.plot_components(forecast2)
plt.show()
import json
from fbprophet.serialize import model_to_json, model_from_json
with open('serialized_model.json', 'w') as fout:
json.dump(model_to_json(model2), fout)
import json
from fbprophet.serialize import model_to_json, model_from_json
with open('serialized_model.json', 'r') as fin:
model3 = model_from_json(json.load(fin))
import fbprophet.plot as fp
future3 = model3.make_future_dataframe(periods=120)
forecast3 = model3.predict(future3)
fig3 = model3.plot(forecast3)
fp.add_changepoints_to_plot(fig3.gca(), model3, forecast3);
plt.show()
(参考)再帰型ニューラルネットワーク(RNN)による株価予想は次のリンク参照
hk29.hatenablog.jp
以上
<広告>
リンク