'22/07/02更新:クラスタリング前の散布図を冒頭に追加
本記事では、クラス数を指定してクラスター分析(クラスタリング)する雛形コードを載せました。分析結果は、グラフ化してcsvファイルに出力する仕様です。
例題データには、siciki-learnにあるワインデータセットを使用しました。下図はそのデータの散布図です。
そして、下図は本プログラムを実行した例です。クラス数(分類数)に3を指定して3つに分類した場合です。
次に、下図は分類数に4を指定して4つに分類した場合です。グラフのプロットで工夫しています。分類数が多い場合は、凡例を右外に出したり、マーカーの塗りつぶしをなくすことで見易くなる場合もあります。
ちなみに、下図のようにエルボー法によって、分類数を推定するコードも記載しています。変化率が小さくなった時点で判断します。この場合は3から4で変化率が小さいため3が妥当かと推定します。
■本プログラム
from sklearn import datasets
from sklearn import preprocessing
from sklearn.cluster import KMeans
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.rcParams['font.size'] = 18
wine_data = datasets.load_wine()
df = pd.DataFrame(wine_data.data, columns=wine_data.feature_names)
df
column_list = ["alcohol","color_intensity"]
X = df[column_list]
X.hist()
plt.tight_layout()
scaler = preprocessing.StandardScaler()
scaler.fit(X)
scaled_X = scaler.transform(X)
sse_list = []
for i in range(1,11):
km = KMeans(n_clusters=i,
random_state=0)
km.fit(scaled_X)
sse_list.append(km.inertia_)
plt.plot(range(1,11), sse_list, marker='o')
plt.xlabel('Number of clusters')
plt.ylabel('SSE')
plt.grid(True)
cls = KMeans(n_clusters=4)
result = cls.fit(scaled_X)
result
inversed_X = scaler.inverse_transform(scaled_X)
inversed_X
df = pd.DataFrame(inversed_X, columns=column_list)
df
labels = result.labels_
labels
df2 = pd.DataFrame(labels, columns=['label'])
df2
DF = pd.concat([df, df2], axis=1)
DF.to_csv('k-means.csv')
DF
print(result.cluster_centers_)
inversed_center = scaler.inverse_transform(result.cluster_centers_)
inversed_center
markers = ['o', '^', ',', 'v']
colors = cm.rainbow(np.linspace(0, 1, len(DF['label'].unique())))
for i, p in enumerate(DF['label'].unique()):
plt.scatter(DF.loc[DF.label == p, column_list[0]],
DF.loc[DF.label == p, column_list[1]],
marker = markers[i],
facecolor = 'None',
edgecolors = colors[i],
label = 'label_' + str(p),
)
plt.scatter(inversed_center[:,0],
inversed_center[:,1],
s=250,
marker='*',
edgecolors="black",
c='yellow')
plt.xlabel(column_list[0])
plt.ylabel(column_list[1])
plt.legend(bbox_to_anchor=(1, 0.95))
plt.tick_params()
plt.grid()
以上
<広告>
リンク
リンク