티스토리 뷰

MeanShift 클러스터링 알고리즘은 scikit-learn에서 제공되는 비지도 학습 알고리즘 중 하나로, 클러스터의 중심을 식별하고 데이터 포인트를 중심 주변의 밀도가 높은 영역에 그룹화하려고 시도합니다. 이 알고리즘은 데이터의 확률 밀도 함수를 추정하고, 클러스터의 중심을 고밀도 영역으로 반복적으로 이동시킴으로써 작동합니다. 이 예제에서는 MeanShift 알고리즘을 사용하여 합성 데이터셋을 클러스터링하는 방법을 살펴보겠습니다.

우선 필요한 라이브러리를 가져옵니다.

 

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

다음으로, make_blobs 함수를 사용하여 합성 데이터셋을 생성합니다. 이 함수는 주어진 중심과 표준 편차를 가진 포인트 집합을 만듭니다. 

# 3개의 중심을 가진 합성 데이터셋 생성
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=1000, centers=centers, cluster_std=0.6)

이제, MeanShift 알고리즘의 대역폭을 추정해야합니다. 대역폭은 확률 밀도 함수를 추정하는 데 사용되는 창의 크기를 제어하는 하이퍼파라미터입니다. estimate_bandwidth 함수를 사용하여 대역폭을 추정할 수 있습니다.

# MeanShift 알고리즘의 대역폭 추정
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

이제, MeanShift 알고리즘의 인스턴스를 만들고 데이터에 맞출 수 있습니다.

# MeanShift 알고리즘 인스턴스 생성
ms = MeanShift(bandwidth=bandwidth)

# 데이터에 알고리즘 적용
ms.fit(X)

마지막으로, 클러스터와 알고리즘에서 찾은 중심을 그래프로 그릴 수 있습니다.

 

# 클러스터 및 알고리즘이 찾은 중심 그래프에 표시
labels = ms.labels_
cluster_centers = ms.cluster_centers_

n_clusters_ = len(np.unique(labels))

plt.figure(1)
plt.clf()

colors = list('bgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', marker

728x90
반응형
댓글
250x250
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
글 보관함
공지사항