[머신러닝] K-최근접 이웃(K-NN) 알고리즘 및 예제

[머신러닝] K-최근접 이웃(K-NN) 알고리즘 및 예제


K-최근접 이웃 알고리즘 개념

k-최근접 이웃 알고리즘(K-Nearest Neighbors)은 가장 간단한 머신러닝 알고리즘으로, 분류(Classification)알고리즘이다.
K-최근접 이웃 알고리즘은 어떤 데이터가 무엇 인가를 판별 하고 싶을 때 인접한 데이터 중 다수를 차지하는 것을 정답으로 한다.
인접한 데이터를 알기 위해서는 거리를 측정 해야 하는데 이때 유클리드 거리(Euclidean distance) 계산법을 사용한다.


예를 들어 아래 그림처럼 주황색 원과 파란색 원들이 5개씩 뭉쳐있다.
이때 색상에 따라 주황색 그룹과 파란색 그룹으로 부를 수 있다.



여기서 알 수 없는 데이터인 미지의 색(초록색) 원이 추가 되었다.
여기서 k-최근접 이웃 알고리즘을 적용하면 초록색 원은 주황색 원들과 인접해 있으므로 주황색 그룹으로 분류된다.




k-최근접 알고리즘은 분류 하고 싶은 데이터와 인접한 k개의 데이터를 찾는다.
아래 그림처럼 k = 3일 때 주황색 2개, 초록색 1개 이므로 초록색 원은 주황색 그룹으로 분류된다.
이 처럼 주변에 가장 가까운 k개의 데이터를 보고 데이터가 속하는 그룹을 판단하는 것을 k-최근접 이웃 알고리즘이다.



여기서 주의 할 점이 있다. k의 개수는 홀수로 하는 것이 좋다.
이유는 k의 개수를 짝수로 했을 때 아래 그림과 같은 상황이 발생 할 수 있기 때문이다.
그림을 보면 주황색 원과 파란색 원이 각각 2:2 동점인 상황이 발생하여 초록색 원을 분류 할 수 없게 된다.



k-최근접 이웃 알고리즘 장점

1. 단순하고 효율적이다.
2. 훈련 단계가 매우 빠르다.


k-최근접 이웃 알고리즘 단점

1. 모델을 생성하지 않아 특징과 클래스간 관계를 이해하는 데 제한적이다.
2. 적절한 k 선택이 필요하다.
3. 데이터가 늘어나면 분류 단계가 느려진다.




K-최근접 이웃 알고리즘 파이썬 예제

지금부터 아래의 내용들은 K-최근접 이웃 알고리즘의 예제다. 
차례대로 따라 해보면 이해가 될 것이다.


matplotlib 패키지 import 하기

K-최근접 이웃 알고리즘을 실습 하기 위해서는 matplotlib 패키지sklearn 패키지 2개를 사용해야 한다. 
먼저 matplotlib 패키지 사용법을 알아보자.


우선 matplotlib 패키지를 import 해주자.
matplotlib 패키지는 파이썬에서 과학 계산용 그래프를 그려주는 대표적인 패키지다.
이 패키지로 K-최근접 이웃 알고리즘의 결과를 확인 할 거다.
프로그램을 실행 시키면 아래와 같은 그래프를 알아서 그려줘서 편리하다.
코드를 읽어보면 대충 사용법을 알 수 있다.

혹시 파이참을 사용한다면 아래 링크 참조.
import matplotlib.pyplot as plt

x = [45,47,48,40,50]
y = [300,350,300,400,450]

plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y')
plt.show()
결과 : 






데이터 준비하기

K-최근접 이웃 알고리즘을 사용하기 전에 어린이 10명과 어른 5명 키, 몸무게 데이터를 준비 하자.
프로그램을 실행시키면 파란색 원이 어린이, 주황색 원이 어른의 데이터 인 것을 확인 할 수 있다.
import matplotlib.pyplot as plt

child_height = [120,122,124,128,130,132,134,136,138,140]
child_weight = [45,47,48,40,50,45,47,48,40,50]

adult_height = [170,175,180,185,190]
adult_weight = [300,350,300,400,450]

plt.scatter(child_height, child_weight)
plt.scatter(adult_height, adult_weight)

plt.xlabel('height')
plt.ylabel('weight')
plt.show()
결과 : 



그래프 그리는 코드는 지우고 2개의 리스트 child와 adult를 하나의 데이터로 만들기 위해서 어린이와 어른의 리스트를 합치자.
import matplotlib.pyplot as plt

child_height = [120,122,124,128,130,132,134,136,138,140]
child_weight = [45,47,48,40,50,45,47,48,40,50]

adult_height = [170,175,180,185,190]
adult_weight = [300,350,300,400,450]

height = child_height + adult_height #리스트 합치기
weight = child_weight + adult_weight #리스트 합치기



사이킷런은 파이썬 머신러닝에서 가장 많이 사용되는 라이브러리다.
사이킷 런 패키지를 사용하려면 리스트를 세로 방향으로 늘어뜨린 2차원 리스트로 만들어야 한다.

zip() 함수는 리스트에 원소를 처음부터  하나씩 꺼내주는 함수다.
Zip() 함수를 사용하여 height와 weight의 원소를 각각 h,w에 할당한다.
이러면 height 리스트와 weight 리스트로 구성된 2차원 리스트 person_data가 만들어진다. 총 15개의 키와 몸무게로 이루어진 데이터를 준비했다.
person_data = [[h,w] for  h, w in zip(height, weight)]


마지막으로 준비 할 데이터는 정답 데이터다.
0 = 어린이, 1 = 어른으로 컴퓨터에게 정답을 알려주기 위한 데이터를 만든다.
person_target 리스트에 0 이 10개, 1 이 5개인 총 15개의 원소를 만들었다.
이제  데이터 준비는 모두 끝났다.
데이터의 정보를 가진 person_data와 정답 정보를 가진 person_target 을 사용하여 k-최근접 이웃 알고리즘을 사용 할 수 있다.
person_target = [0] * 10 + [1] * 5




K-최근접 이웃 알고리즘 사용하기

우선 사이킷런 패키지 중 K-최근접 이웃 알고리즘을 구현한 클래스 KNeighborsClassifie를 import 해주자.

*주의 : 사이킷런 패키지 이름은 sklearn이 아닌 scikit-learn 으로 install 해야 한다.
from sklearn.neighbors import KNeighborsClassifier


KNeighborsClassifier 클래스의 객체 kn을 만들자.
kn = KNeighborsClassifier()


kn 객체에 person_data person_target 데이터를 전달하여 학습을 시켜야 한다. 이때 이러한 과정을 머신러닝에서 훈련이라고 부른다.
사이킷런에서는 fit()메서드가 훈련을 해주는 역할을 한다.
kn.fit(person_data, person_target)


이제 얼마나 잘 훈련 되었는지 평가해보자.
사이킷런에서 score()가 모델을 평가하는 메서드다.
score() 메서드는 0.0 ~ 1.0 사이의 값을 반환한다. 1에 가까울 수록 데이터를 정확히 맞췄다는 뜻이다.
프로그램 실행 후 결과 창에 1.0이 나오면 100% 정확히 어린이와 어른을 분류 했다는 뜻이며 이 값을 정확도(accuracy)라고 부른다.
data = kn.score(person_data,person_target)
print(data)
결과 : 1.0




새로운 데이터 분류하기

위에 예제로 어린이와 어른을 분류하는 훈련은 끝났다.
새로운 데이터를 추가해 보자.
(170,350)에 데이터를 추가 했다.
그래프를 보면 새로운 데이터(주황색)는 상위 오른쪽에 있는 5개 어른 데이터 쪽에 있다.
만약 K-최근접 이웃 알고리즘을 사용해서 분류를 한다면 어른 그룹으로 분류 될 것이다.
한번 확인해 보자.
new_height = [170]
new_weight = [350]

plt.scatter(height, weight)
plt.scatter(new_height, new_weight)

plt.xlabel('height')
plt.ylabel('weight')
plt.show()



predict() 함수는 새로운 데이터의 정답을 예측한다.
우리는 어린이 = 0, 어른은 = 1로 정했다.
따라서 새로운 데이터(주황색 원) 결과 값이 [1]이 나온다.
a = kn.predict([[170, 350]])
print(a)
결과 : [1]



k개수 바꾸기

KNeighborsClassifier클래스의 fit() 함수를 사용하여 데이터를 저장을 했다.
그리고 predict() 함수를 사용하여 새로운 데이터의 가장 가까운 데이터들을 참고 하여 어린이인지 어른인지 구분했다.
그렇다면 predict() 함수가 주변에 있는 몇 개의 데이터를 참조할까?
우선 기본값은 5개다.
만약 참고 데이터를 바꾸고 싶다면 아래와 같이 n_neighbors 매개변수의 값을 바꾸면 된다.
kn15 = KNeighborsClassifier(n_neighbors=15)


만약 15개 참고를 한 후 score 함수를 사용해 모델을 평가하면 0.666666... 값이 나온다.
이유는 어린이는 10개, 어른은 5개라서 무조건 어린이만 참고하기 때문이다.
아까 score함수에서 설명 했듯 1에 가까울 수록 정확한 데이터다.
즉, K-최근접 이웃 알고리즘을 사용 할 때는 참고 데이터 개수(k 개수)를 알맞게 정해 줘야 한다.
kn15.fit(person_data, person_target)
a = kn15.score(person_data, person_target)
print(a)
결과 : 0.6666666666666666


전체 소스 코드

K-최근접 이웃 알고리즘 예제다.
import matplotlib.pyplot as plt #그래프 그리는 패키지 import
from sklearn.neighbors import KNeighborsClassifier #사이킷런 Kneighclass import

#데이터 준비하가ㅣ -----------------------------------------
child_height = [120,122,124,128,130,132,134,136,138,140] #어린이 리스트
child_weight = [45,47,48,40,50,45,47,48,40,50]

adult_height = [170,175,180,185,190] #어른 리스트
adult_weight = [300,350,300,400,450]

height = child_height + adult_height #리스트 합치기
weight = child_weight + adult_weight #리스트 합치기

person_data = [[h,w] for h, w in zip(height, weight)] #정보 데이터 만들기(2차원 리스트 만들기)

person_target = [0] * 10 + [1] * 5 #정답 데이터 만들기

kn = KNeighborsClassifier() #k-최근접 이웃 알고리즘 사용 객체 만들기
kn.fit(person_data, person_target) #데이터 보내주기

data = kn.score(person_data,person_target) #데이터 평가하기
print(data)

#k 개수 정해주기 -----------------------------------------
kn15 = KNeighborsClassifier(n_neighbors=15)

kn15.fit(person_data, person_target)
a = kn15.score(person_data, person_target)
print(a)

## 그래프 그리기 -----------------------------------------
# child_height = [120,122,124,128,130,132,134,136,138,140]
# child_weight = [45,47,48,40,50,45,47,48,40,50]
#
# adult_height = [170,175,180,185,190]
# adult_weight = [300,350,300,400,450]
#
# plt.scatter(child_height, child_weight)
# plt.scatter(adult_height, adult_weight)
#
# plt.xlabel('height')
# plt.ylabel('weight')
# plt.show()


참조


혼자 공부하는 머신러닝 + 딥러닝(깃 허브) : https://github.com/rickiepark/hg-mldl


댓글

이 블로그의 인기 게시물

[Arduino] 아두이노 초음파 센서(HC-SR04) 사용하기

[Arduino] 아두이노 조이스틱 사용하기

[자연 환경] 농약의 장단점 농약이 환경과 인간에게 미치는 영향