공부/A.I

복잡도와 일반화 사이의 관계

래울 2021. 5. 10. 19:12

books.google.co.kr/books?id=tPaTDwAAQBAJ&lpg=PP1&dq=%ED%8C%8C%EC%9D%B4%EC%8D%AC%EB%9D%BC%EC%9D%B4%EB%B8%8C%EB%9F%AC%EB%A6%AC%EB%A5%BC%20%ED%99%9C%EC%9A%A9%ED%95%9C%20%EB%A8%B8%EC%8B%A0%EB%9F%AC%EB%8B%9D&hl=ko&pg=PA1#v=onepage&q=%ED%8C%8C%EC%9D%B4%EC%8D%AC%EB%9D%BC%EC%9D%B4%EB%B8%8C%EB%9F%AC%EB%A6%AC%EB%A5%BC%20%ED%99%9C%EC%9A%A9%ED%95%9C%20%EB%A8%B8%EC%8B%A0%EB%9F%AC%EB%8B%9D&f=false

 

파이썬 라이브러리를 활용한 머신러닝(번역개정판)

사이킷런 핵심 개발자에게 배우는 머신러닝 이론과 구현   현업에서 머신러닝을 연구하고 인공지능 서비스를 개발하기 위해 꼭 학위를 받을 필요는 없습니다. 사이킷런(scikit-learn)과 같은 훌륭

books.google.co.jp

 

- scikit-learn의 유방암 데이터 셋

실제 데이터 셋으로, 유방암 종양이 음성인지 양성인지를 예측할 수 있도록 학습한다.


위의 유방암 데이터 셋을 활용해 모델 복잡도와 일반화 사이의 관계에 대해 알아보자

 

이웃개수(1~20), random_state=0

 

이웃개수(1~10), random_state=5
이웃개수(1~10), random_state=10
이웃개수(1~15), random_state=20

 

이웃이 하나일 때는 훈련데이터에 대한 예측이 완벽하고, 이웃이 늘어날 수록 훈련데이터의 정확도는 감소함을 볼 수 있다.

테스트데이터에 대한 정확도는 특정 이웃까지는 증가하다 특정 이웃 수가 넘어가면 감소함을 볼 수 있다.

 


 

from IPython.display import display
import numpy as np
import pandas as pd
import mglearn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import KNeighborsClassifier
cancer = load_breast_cancer()
print("종양 데이터 key : ", cancer.keys())
print("데이터셋개수와 특성개수 : ", cancer.data.shape)
####################################################
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=0)
training_accuracy = []
test_accuracy = []
neighbors_settings = range(1,21)
for n_neighbors in neighbors_settings:
    #모델 생성
    clf = KNeighborsClassifier(n_neighbors = n_neighbors)
    clf.fit(X_train, y_train)
    training_accuracy.append(clf.score(X_train, y_train))
    test_accuracy.append(clf.score(X_test, y_test))

plt.plot(neighbors_settings, training_accuracy, label="training accuracy")
plt.plot(neighbors_settings, test_accuracy, label="test accuracy")
plt.ylabel("accuracy")
plt.xlabel("n_neighbors")
plt.legend()