티스토리 뷰

머신러닝

머신러닝 / 로지스틱 회귀

삼전동해커 2023. 1. 13. 15:32

이번엔 로지스틱 회귀에 대해 공부한다.

로지스틱 회귀는 앞에서 공부한 입력값에 대한 결과 값을 예측하는 단순 선형 회귀와 달리 2가지의 클래스 중 어디에 속하는지 분류하는 이진 분류에 대한 알고리즘이다.

처음에 공부하면서 로지스틱 '회귀'인데 왜 분류하는 알고리즘인지 이해가 안됐다. 로지스틱 회귀는 시그모이드 함수0~1 사이의 예측값을 이용해 결과를 분류하는 회귀가 포함된 분류 알고리즘이다.

 

이번엔 성적에 따라 불/합 여부를 판단하는 문제를 가정하자.

45 : 불

50 : 불

55 : 불

60 : 합

65 : 합

70 : 합

이에 대한 그래프는 다음과 같다.

55점과 60점 사이에서 합격과 불합격이 갈린다. 그 사이의 점수들은 합격일 가능성이 높은 부분과 낮은 부분으로 갈린다.

이 사이의 점수들이 합격할 확률은 0~1사이의  값으로 표현된다. 이를 직선으로 표현하기엔 무리가 있어서 시그모이드라는 함수를 이용한다.

 

시그모이드 함수에 대한 표현식은 다음과 같다.

e는 자연상수 2.71이고, 모델이 구해야할 값은 w와 b이다.

 

기본적인 시그모이드 함수는 다음과 같이 구현한다.

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    return 1/(1+np.exp(-x))

x = np.arange(-5.0, 5.0, 0.1)
y = sigmoid(x)

plt.plot(x,y,'g')
plt.plot([0,0],[1.0,0.0],':')
plt.title('Sigmoid function')
plt.show()

 

이제 w와 b 값을 조정하면 어떻게 그래프가 변하는지 확인해보자.

#w값이 커질수록 경사가 가파라짐
x = np.arange(-5.0, 5.0, 0.1)
y = sigmoid(x)
y1 = sigmoid(0.5*x)
y2 = sigmoid(2*x)

plt.plot(x, y, 'r', linestyle='--') # w의 값이 1일때
plt.plot(x, y1, 'g') # w의 값이 0.5일때
plt.plot(x, y2, 'b', linestyle='--') # w의 값이 2일때
plt.plot([0,0],[1.0,0.0], ':') # 가운데 점선 추가
plt.title('Sigmoid Function')
plt.show()

#b값이 커질수록 1에 수렴
x = np.arange(-5.0, 5.0, 0.1)
y = sigmoid(x + 0.5)
y1 = sigmoid(x + 1)
y2 = sigmoid(x + 1.5)

plt.plot(x, y, 'r', linestyle='--') # w의 값이 1일때
plt.plot(x, y1, 'g') # w의 값이 0.5일때
plt.plot(x, y2, 'b', linestyle='--') # w의 값이 2일때
plt.plot([0,0],[1.0,0.0], ':') # 가운데 점선 추가
plt.title('Sigmoid Function')
plt.show()

로지스틱회귀의 비용함수

로지스틱 회귀의 시그모이드 함수의 도함수 그래프는 직선 방정식과 다르게 극값이 여러개이다.

그 중 최소값은 글로범 미니멈이라고 하고, 특정 구역에서의 최소값은 로컬 미니멈이라고 한다.

당연히 로컬 미니멈을 비용함수로 사용하는 것은 좋지 못한 선택이다. 이보다 더 최소값이 존재하기 때문에..

MSE를 사용하면 로컬 미니멈을 선택하게될 가능성이 매우 높다.

 

로지스틱회귀에서 사용하는 적절한 비용함수는 가중치를 최소화하는 목적 함수가 존재한다.

위 식은 n개의 데이터가 있고, 실제값 y와 예측값H(x)의 오차를 어떤 목적 함수 f가 나타낸다.

이 f를 어떻게 정의하느냐에 따라 가중치를 최소화할 목적 함수가 완성된다.

 

시그모이드 함수는 0~1사이의 y값을 반환한다. 이는 실제값이 0일 때 y값이 1에 가까워지면 오차가 커지고

실제 값이 1일 때 y값이 0에 가까워지면 오차가 커짐을 의미한다.

이를 로그함수로 표현할 수 있다.

 

실제값 y가 1일 때의 그래프는 파란색, 실제값 y가 0일 때 의 그래프는 빨간색이다.

실제값이 1일 때(파란색), 예측값 H(x)의 값이 1이면 오차가 0이므로 cost는 0이 되고,

실제값이 1일 때(파란색), 예측값이 0이면 cost는 무한히 커진다.

 

결과적으로 두개의 그래프를 합친 다음과 같은 공식이 목적 함수이다.

이렇게 찾아낸 비용함수를 크로스 엔트로피라고한다.

 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함