-
[딥러닝] Pytorch로 구현한 Linear Regression[Study] Data Science/머신러닝&딥러닝 2022. 6. 25. 13:20
Pytorch를 통해 Linear Regression 문제를 풀어보고자 한다. 본 글은 '모두를 위한 딥러닝 시즌2-Pytorch' 편을 참고하였다.
1. 가설
너무나도 익숙한 선형회귀의 방정식은 학습데이터를 가장 잘 맞는 직선을 찾는 일이다. "가장 잘 맞는 직선을 찾는다"는 것은 곧 직선을 잘 표현할 수 있게끔하 하는 Weight(직선의 기울기)와 bias(절편)을 찾는 것과 같다고 할 수 있다.
<모두를 위한 딥러닝2 Pytorch - lab03 강의 중> 2. 비용함수
예측 값이 실제 값과 차이가 크지 않을 때, W와 b가 적합하다고 할 수 있다. 따라서, 이를 측정하기 위해 비용함수(cost function)을 정의하게 된다. 주로 Linear Regression에서 쓰이는 cost function은 MSE(Mean Squared Error)이다. 예측값과 실제값의 차이를 제곱한 값의 평균을 구한다고 볼 수 있다.
<모두를 위한 딥러닝2 Pytorch - lab03 강의 중> 결국 학습을 시킨다는 것은 비용함수를 최소화하는 일이며, 최소화하기 위해서는 이 비용함수의 미분값을 0에 가깝게 업데이트시키는 것으로 이해할 수 있다. 이러한 과정을 'Gradient Descent'라고 한다.
3. Pytorch로 구현하기
먼저 학습데이터를 x_train, y_train이라는 학습데이터를 임의로 설정해보자. 예측해야할 W와 b는 zero로 초기화한다. 이 때, 'requires_grad = True'라는 파라미터를 붙여주는데, 이는 W와 b가 학습할 것임을 암시해주는 역할을 한다.
import torch x_train = torch.FloatTensor([[1],[2],[3]]) #입력 y_train = torch.FloatTensor([[2],[4],[6]]) #출력 W= torch.zeros(1,requires_grad=True) b= torch.zeros(1,requires_grad=True)
다음으로 optimizer를 설정한다. 앞서 설명한 경사하강법의 일종인 'SGD'를 사용하고, 얼마만큼 W, b값을 업데이트할 것인지를 결정해주는 learning rate을 설정한다.
import torch.optim as optim optimizer = optim.SGD([W,b],lr=0.01)
학습데이터가 주어졌고 W,b를 초기화한 후, optimizer까지 결정했다면 학습을 진행한다.
- cost : 예측값과 실제값 차이의 제곱에 대해 torch.mean을 통해 평균한다.
- optimizer.zero_grad() : gradient를 0으로 초기화하는 역할이다. pytorch의 경우, 미분을 통해 얻은 기울기를 이전에 계산된 값에 누적시키는 특성이 있기 때문에 초기화를 해주어야 한다.
- cost.backward() : cost function을 미분하여 gradient를 계산해주는 역할이다.
- optimizer.step() : 전 단계에서 계산한 값을 통해 W와 b를 업데이트한다.
이 과정을 cost가 아주 작아질 때까지 n번 반복한 후 학습을 마무리한다.
nb_epochs =1000 for epoch in range(1,nb_epochs+1): hypothesis = x_train*W + b cost = torch.mean((hypothesis-y_train)**2) optimizer.zero_grad() cost.backward() optimizer.step() if epoch%100 == 0: print('Epoch {:4d}/{} W: {:.3f}, Cost: {:.6f}'.format(epoch, nb_epochs, W.item(),cost.item()))
'[Study] Data Science > 머신러닝&딥러닝' 카테고리의 다른 글
[딥러닝] pytorch로 CNN 구현하기 (0) 2022.07.30 [딥러닝] minibatch를 통한 다변량 선형회귀 (0) 2022.06.26 [강화학습] Policy Gradient와 Actor-Critic (0) 2022.01.25 [강화학습] Model Free Prediction: Monte-Carlo, Temporal Difference, SARSA (0) 2022.01.05 [강화학습] Model-based Planning by Dynamic Programming (0) 2021.12.31