Python 回帰モデルの精度確認のため、その評価指標を出力する「sklearn.metrics 」

 本記事では、作成した回帰モデルの精度検証のための雛形コードを載せました。
下図は、それをするための読み込みデータ例です。N列の「PRICE」が指標の生値で、O列の「Label」が回帰モデルによる指標の予測値です。この2列を比較することによって、回帰モデルの精度検証を行います。

■本プログラム

evalute_regression_func という関数で作成しています。

import pandas as pd
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

def evalute_regression_func(pred_df, target, file):
    test_data = pred_df[target] 
    pred_data = pred_df['Label'] # Labelは回帰モデルによる予測値の列名。PyCaretで出力した場合の標準カラム名である。
    
    # 平均絶対誤差 (MAE, Mean Absolute Error) 
    mae = mean_absolute_error(test_data, pred_data)
    
    # 平均二乗誤差 (MSE, Mean Squared Error) 
    mse = mean_squared_error(test_data, pred_data)
    
    # 二乗平均平方根誤差 (RMSE: Root Mean Squared Error) 
    #rmse = np.sqrt(mean_squared_error(test_data, pred_data)
    
    # 決定係数 (R2)
    r2 = r2_score(test_data, pred_data)
    
    # 対数平均平方二乗誤差 (RMSLE, Root Mean Squared Logarithmic Error)
    #rmsle = mean_squared_log_error(test_data, pred_data)
    
    _df = pd.DataFrame(
          data = {'target': [f'{target}_{file}'],
                  'MAE': [mae],
                  'MSE': [mse],
                  #'RMSE': [rmse],
                  'R2': [r2],
                  #'RMSLE': [rmsle],
                  }
    )
    print(_df)
    return _df
    
def main():
    # csvファイルをpandasデータフレームで読み出し
    df = pd.read_csv(
        file_path, # ファイルパス
        #names = column_names, # 列名を指定
        #na_values ='?', # ?は欠損値として読み込む
        #comment = '\t', # TAB以降右はスキップ 
        #sep = ',', # 空白行を区切りとする
        #skipinitialspace = True, # カンマの後の空白をスキップ  
        #skiprows = [0] # 飛ばしたい行をリストで指定 
        #header = 0, # ヘッダー行を指定
        #nrows = 5, # 読み込む行数
        encoding = 'utf-8', # 'utf-8' 'shift-jis' 'cp932'
    ) 
    print(df)

    # csvファイルの拡張子を除いたファイル名の取得
    file_name = file_path[:-4]
    
    # 関数を実行
    mydf = evalute_regression_func(df, target_name, file_name)

    # csvファイルに保存
    mydf.to_csv(f'evalute_regression_{file_name}.csv', index = False)


if __name__ == '__main__':
    file_path = '201102_predict.csv'
    target_name = 'PRICE'
    
    main()

以上

<広告>