AI 공부 저장소

파이토치를 사용한 선형 회귀 (Linear Regression) 모델링 본문

Artificial Intelligence/ML&DL

파이토치를 사용한 선형 회귀 (Linear Regression) 모델링

aiclaudev 2022. 2. 8. 15:44

 

본 글은 Sung Kim 교수님의 PyTorchZeroToAll 강의를 토대로 저의 지식을 아주 조금 덧붙여 작성하였습니다.
글 작성에 대한 허락을 받아, 개인 공부용으로 작성합니다.
문제가 발생할 시 비공개로 전환함을 알립니다.
https://www.youtube.com/channel/UCML9R2ol-l0Ab9OXoNnr7Lw

 

 

 

 

먼저, 파이토치를 이용해 모델링하는 전체적인 절차는 위와 같고, 정리하자면 아래와 같습니다.

 

Model을 디자인한다.

loss와 optimizer를 구성한다. 이 때, loss와 optimizer도 파이토치 내에 여러 종류가 있기에 자신의 상황에 맞는, 적절한 API를 채택하는 과정이 필요합니다.

Forward, Backward, Update 과정을 통해 모델을 Training합니다. 이 때, Update란 Forward와 Backward 과정에서 나온 결과를 토대로 변수를 갱신하는 것을 의미합니다.

 

이제, 각 단계에 따라 어떻게 코드를 구성할 수 있는지 알아보도록 하겠습니다.

 

0) Data definition

학습에 사용할 데이터를 정의하는 부분입니다. 

 

 

 

1) Design your model using class with Variables

이제, 클래스를 정의할 것입니다. 위 사진에서 클래스의 이름인 Model은 임의로 설정한 것이므로, 어떤 이름으로 하여도 상관 없습니다. 또한. 이 클래스는 torch.nn.Module를 부모 클래스로 갖는, 파생 클래스입니다.

 

① def __init__(self)

클래스의 생성자를 정의하는 부분입니다. 먼저, super를 통해 부모 클래스의 생성자를 사용합니다. 

그리고, 멤버인 linear 객체를 위와 같이 생성해주면 됩니다. (1, 1)은 이 Linear 모델은 한 개의 Input을 입력받아, 한 개의 Output을 추출할 것이라 명시하는 부분이라 생각하시면 될 것 같습니다.

 

② def forward(self, x)

Back propagation을 사용하여 모델을 최적화하기 위해선, forward라는 단계가 필요하였습니다. 이를 정의하기 위한 메소드입니다. 보시다시피, y_pred를 return하는 것을 확인할 수 있습니다.

 

 

 

2) Construct loss and optimizer (select from Pytorch API)

Loss를 계산하기 위한 criterion과, optimizer를 생성합니다. optimizer 생성 시 parameter로 입력되어있는 

model.parameters( )는 어떤 파라미터를 변경시킬 것인지에 대해 입력한 것이라고 생각하면 됩니다.

 

 

 

3) Training Cycle (forward, backward, update)

모델링의 마지막 단계인 Training입니다. 

forward :  y_pred를 계산한 후, criterion을 이용하여 loss를 계산합니다.

backward : optimizer를 통해 변수를 계속 갱신해나가며 최적의 값을 찾아나갑니다. 이 때, step이란 update라고 생각하면 될 것 같습니다. 

 

 

 

4) Test

x에 4를 넣었을 때, y값이 7.996 정도 나오는 것을 확인할 수 있습니다.

8에 굉장히 근접한 값이므로 성공적으로 모델링하였다고 생각할 수 있습니다.

 

 

 

5) 전체 코드 

 

 

 

 

6) 관련 Exercise

 

 

 

 

 

 

 

 

 

 

본 글은 Sung Kim 교수님의 PyTorchZeroToAll 강의를 토대로 저의 지식을 아주 조금 덧붙여 작성하였습니다.
글 작성에 대한 허락을 받아, 개인 공부용으로 작성합니다.
문제가 발생할 시 비공개로 전환함을 알립니다.
https://www.youtube.com/channel/UCML9R2ol-l0Ab9OXoNnr7Lw