Python 3次元データを2次元散布図で表記する

 本記事では、下図のように3次元データを2次元散布図で表現する雛形コードを載せました。

f:id:HK29:20211017233019p:plain

例題に使用したデータは機械学習でお馴染みのボストンデータセットです。図例では、X軸にRM(部屋数)、Y軸にLSTAT(低所得者の割合)、(Z軸)カラーにPRICE(部屋の価格)をとりました。カラーの赤色は部屋の価格が高く、逆に青色は安いことを意味します。赤色が多くプロットされてる箇所は、部屋数が多くて、低所得者の割合が低い地域であり、その場合に部屋の価格が高いとわかります。

■本プログラム

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd

df = pd.read_csv('boston_dataset.csv')
df


# In[2]:


import matplotlib.pyplot as plt
plt.rcParams['font.size']=16

x_name = 'RM'
y_name = 'LSTAT'
z_name = 'PRICE'

fig, ax = plt.subplots()
 
sc = ax.scatter(df[x_name], df[y_name],
                c = df[z_name],
                cmap = 'bwr')

fig.colorbar(sc)
ax.set_xlabel(x_name)
ax.set_ylabel(y_name)
ax.grid()
plt.title(z_name)
plt.tight_layout()
plt.show()

以上

<広告>