AI 공부 저장소

로지스틱 회귀 (Logistic Regression) Wide & Deep 본문

Artificial Intelligence/ML&DL

로지스틱 회귀 (Logistic Regression) Wide & Deep

aiclaudev 2022. 2. 9. 17:02

 

본 글은 Sung Kim 교수님의 PyTorchZeroToAll 강의를 토대로 저의 지식을 아주 조금 덧붙여 작성하였습니다.
글 작성에 대한 허락을 받아, 개인 공부용으로 작성합니다.
문제가 발생할 시 비공개로 전환함을 알립니다.
https://www.youtube.com/channel/UCML9R2ol-l0Ab9OXoNnr7Lw

 

 

 

 

① 개요

위 상황이 우리가 앞서 다루었던 로지스틱 회귀입니다. GPA를 통해 Admission이 될 것인지 안 될 것인지 (Binary)를 예측하는 것인데, 여기서 Point는 결과를 예측하기 위한 변수가 GPA 하나 뿐이라는 것입니다.

하지만, 실생활에서 대부분의 경우에, 어떠한 결과에 대한 원인은 여러가지 입니다. 따라서, 위 상황처럼 Admission을 예측하기 위해서 GPA만을 사용하는 것은 조금 부족하다고 할 수 있죠.

 

위 처럼, y를 예측하기 위해 두개의 x 변수(input)를 사용하는 것이 더욱 타당한 모델링이라고 할 수 있습니다.

 

 

 

② Matrix Multiplication

두 개 이상의 input 변수를 사용하는 시점부터, 행렬이 사용됩니다.  

위 수식을 직접 행렬곱해보면, 아래와 같은 결과가 나오겠네요.

 

a와 b는 x변수, (즉, GPA와 Experience)이고, w는 각 변수에 대한 weight (결과 예측시 얼마나 반영할 것인가?) 정도로 이해할 수 있습니다. 따라서, 적절한 w를 선정해주는 것이 모델링의 핵심이 되겠습니다. 

 

 

 

③ Go Wide & Go Deep

 

 

이 강의에선 하나의 단일 input이 아닌, 여러개의 input을 사용하는 것을 Go Wide 라고 표현하고, 하나의 layer가 아닌 여러개의 layer를 사용하는 것을 Go Deep 이라고 표현하네요.

 

 

위 코드를 한번 살펴보겠습니다. 먼저 sigmoid 객체를 생성한 후, 3개의 layer를 정의하네요. 아시다시피, layer정의 시 들어가는 두개의 parameter는 각각 input의 개수, output의 개수입니다. 예를 들어, l2와 같은 경우 4개의 input을 받아 3개의 output을 출력하네요. 이때, 2가지 조건이 존재합니다.

 

① 마지막이 아닌 layer의 Output 수는, 그 다음 layer의 Input 수와 일치한다.

ex. l1의 Output수와 l2의 Input수가 일치해야하고, l2의 Output수와 l3의 Input수가 일치해야합니다. (이는 어떻게 보면 당연하죠)

 

② 가장 처음 layer의 Input 수와 가장 마지막 layer의 Output 수는, 모델링 목적과 일치해야한다.

ex. l1의 Input수 2와, l3의 Output 수 1은 우리의 목적과 일치해야한다. (GPA와 Experience를 입력하여 Admission 여부를 예측한다) 이를 제외한 layer의 Input과 Output 수는 임의대로 결정한다. (단, ①을 만족하며)

 

 

 

④ 코드 구현

8 X 9 Matrix 형태의 csv 파일

 위와 같은 csv 파일에 대해 코드를 구현해보겠습니다.

 

1) 데이터 불러오기

먼저, 변수 xy에 csv 파일을 불러옵니다.

입력변수로 사용될 x_data는 총 9개의 컬럼 중, 8개의 컬럼이고 출력변수인 y_data는 가장 마지막 column입니다. (0 또는 1의 값으로 Binary class 임을 확인할 수 있습니다.) 

 

 

2) 모델링

먼저, Model을 Design해야합니다. 총 3개의 layer를 사용할 것이고, forward 메소드를 통해 y_pred를 추출하는 것을 확인할 수 있습니다. 그리고 이외의 과정은 앞서 다룬 모델링 과정과 동일합니다.

 

 

 

 

본 글은 Sung Kim 교수님의 PyTorchZeroToAll 강의를 토대로 저의 지식을 아주 조금 덧붙여 작성하였습니다.
글 작성에 대한 허락을 받아, 개인 공부용으로 작성합니다.
문제가 발생할 시 비공개로 전환함을 알립니다.
https://www.youtube.com/channel/UCML9R2ol-l0Ab9OXoNnr7Lw