본문 바로가기
IT 이야기/Python

[Python] sklearn

by Dblog 2020. 12. 8.
728x90

 

Sklearn, 머신러닝 라이브러리 중 하나입니다. 

 

scikit-learn: machine learning in Python — scikit-learn 0.23.2 documentation

Model selection Comparing, validating and choosing parameters and models. Applications: Improved accuracy via parameter tuning Algorithms: grid search, cross validation, metrics, and more...

scikit-learn.org

 

오늘은 머신러닝 라이브러리에 대해서 간단한 리뷰를 작성하겠습니다.


 

머신러닝하면 먼저 떠오르는 사진은 

google capture

이런 종류의 사진입니다.  하지만 이 사진은 머신 러닝이 아닌 AI입니다.
그렇다면 머신러닝은 어떤 형태일까요

google capture

이 사진은 머신러닝의 대표적인 예시 Nurunerual network 입니다. 
AI, 머신러닝 둘의 관계는 인공지능이 머신러닝을 포함하고 있는 관계 입니다. 

둘의 가장 큰 차이는 변화된 상황에 어떻게 대처하냐에 있습니다. 

AI는 상황이 변화되면 일정 시간이 지난후 변화된 상황에 적응하여 가장 올바른 답에 가까운 결과를 보여줍니다.
하지만 머신러닝은 상황이 변화되면 사람이 변화된 상황에 맞게 수정해주어야 올바른 결과를 나타냅니다.

서로 데이터를 활용하는 것은 같지만 상황 대처 능력이 있음과 없음의 차이가 있습니다.


Sklearn 라이브러리에서는 크게 6개의 기능을 지원합니다.

 

https://scikit-learn.org/stable/

링크에 예시 코드와 함께 설명이 함께 있습니다. 아래 예시는 예전에 테스트 해보기 위해 가장 유명한 예시인 Iris 클러스터링 입니다.

 

데이터 예시

iris의 데이터는 sklearn 라이브러리를 다운받을때 예시 데이터로 함께 다운됩니다. 다양한 예시로 사용되고 있습니다.

 

코드 예시

너무 오래전 소스라 이미지 밖에 안남아서 이미지를 소스로 옮겼습니다. 오타의 확률이 높아 그대로 복붙하시면 안될 가능성이 높습니다.

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
from sklearn import datasets

np.random.seed(5)
iris = datasets.load_iris()
X = iris.data
y = iris.target
estimators = [('k_means_iris_8', KMeans(n_clusters=8)),
              ('k_means_iris_3', KMeans(n_clusters=3)),
              ('k_means_iris_bad_init', KMeans(n_clusters=3, n_init=1, init="random"))
              ]

fignum = 1
titles = ['8 clusters', '3 clusters', '3 clusters, bad initialization']
for name, est in estimators:
    fig = plt.figure(fignum, figsize=(4,3))
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
    est.fit(X)
    labels = est.labels_
    ax.scatter(X[:,3], X[:, 0], X[:, 2],
               c=labels.astype(np.float), edgecolors='k')
    ax.w_xaxis.set_ticklabels([])
    ax.w_yaxis.set_ticklabels([])
    ax.w_zaxis.set_ticklabels([])
    ax.set_xlabel('Petal width')
    ax.set_ylabel('Sepal length')
    ax.set_zlabel('Petal length')
    ax.set_title(titles[fignum - 1])
    ax.dist = 12
    fignum = fignum + 1

fig = plt.figure(fignum, figsize=(4,3))
ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
for name, label in [('Setosa', 0),
                    ('Versicolour', 1),
                    ('Virginica', 2)]:

    ax.text3D(X[y == label, 3].mean(),
              X[y == label, 0].mean(),
              X[y == label, 2].mean() + 2, name,
              horizontalalignment="center",
              bbox=dict(alpha=.2, edgecolor='w', facecolor='w'))

y = np.choose(y, [1, 2, 0]).astype(np.float)
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y, edgecolors='k')
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
ax.dist = 12
fig.show()

 

결과 예시

 

 

 

 

728x90

'IT 이야기 > Python' 카테고리의 다른 글

[Python] Matplotlib  (0) 2020.12.07

댓글