'21/03/03更新:浮動小数点以外の型の場合の制約条件「カテゴリ変数、整数、対数、離散値」の設定例を追記しました。
本記事では、Optunaを使った多目的最適化の雛形コードを載せました。 Optunaは、オープンソースの機械学習モデルのハイパーパラメータを自動最適化するフレームワークです。しかし、現在の最新ver2.3.0では多目的最適化にも対応しています。本雛形コードでは、パレート解の可視化のためにグラフ化、パレート解の目的変数と説明変数をcsvファイルに保存する仕様にしました。
インストール方法は2通りあります。pipの場合は次のようにします。
pip install optuna
Anaconda環境ではcondaで出来ます。
conda install -c conda-forge optuna
本コードを実行すると、下図のようにパレート解を図示するグラフを作成して画像ファイルで保存します。
更に、下図のようにパレート解の結果をcsvファイルに保存します。
ここで、変数の制約条件は変数の型によって次のようにして設定します。上図は浮動小数型の場合です。
lstat = trial.suggest_float('LSTAT', 1.7, 38)
rm = trial.suggest_int('RM', 3, 8)
chas = trial.suggest_categorical('CHAS', ['0', '1'])
density = trial.suggest_loguniform('Density', 1e-15, 1e-7)
rad = trial.suggest_discrete_uniform('RAD', 1,2,3,4,5,6,7,8,24)
■本プログラム
import optuna
import matplotlib.pyplot as plt
def objective(trial):
x = trial.suggest_float("x", 0, 5)
y = trial.suggest_float("y", 0, 3)
v0 = 4 * x ** 2 + 4 * y ** 2
v1 = (x - 5) ** 2 + (y - 5) ** 2
return v0, v1
study = optuna.multi_objective.create_study(
directions=["minimize", "minimize"],
sampler=optuna.multi_objective.samplers.NSGAIIMultiObjectiveSampler(seed = 1)
)
study.optimize(objective, n_trials=200)
trials = {str(trial.values): trial for trial in study.get_trials()}
trials = list(trials.values())
y1_all_list = []
y2_all_list = []
for i, trial in enumerate(trials, start=1):
y1_all_list.append(trial.values[0])
y2_all_list.append(trial.values[1])
trials = {str(trial.values): trial for trial in study.get_pareto_front_trials()}
trials = list(trials.values())
trials.sort(key=lambda t: t.values)
y1_list = []
y2_list = []
with open('pareto_data.csv', 'w') as f:
for i, trial in enumerate(trials, start=1):
if i == 1:
columns_name_str = 'trial_no,y1,y2'
data_list = []
data_list.append(trial.number)
y1_value = trial.values[0]
y2_value = trial.values[1]
y1_list.append(y1_value)
y2_list.append(y2_value)
data_list.append(y1_value)
data_list.append(y2_value)
for key, value in trial.params.items():
data_list.append(value)
if i == 1:
columns_name_str += ',' + key
if i == 1:
f.write(columns_name_str + '\n')
data_list = list(map(str, data_list))
data_list_str = ','.join(data_list)
f.write(data_list_str + '\n')
plt.rcParams["font.size"] = 16
plt.figure(dpi=120)
plt.title("multiobjective optimization")
plt.xlabel("Y1")
plt.ylabel("Y2")
plt.grid()
plt.scatter(y1_all_list, y2_all_list, c='blue', label='all trials')
plt.scatter(y1_list, y2_list, c='red', label='pareto front')
plt.legend()
plt.tight_layout()
plt.savefig("pareto_graph.png")
plt.close()
(参考資料)マニュアルは次のリンク先pdfです。https://optuna.readthedocs.io/_/downloads/en/latest/pdf/。
以上
<広告>
リンク