[python] PyTorch에서 왜 zero_grad ()를 호출해야합니까?

이 메서드 zero_grad()는 훈련 중에 호출해야합니다. 그러나 문서 는별로 도움이되지 않습니다.

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

이 메서드를 호출해야하는 이유는 무엇입니까?



답변

에서는 PyTorchPyTorch 가 후속 역방향 패스 에서 그라디언트축적 하기 때문에 역 전파를 시작하기 전에 그라디언트를 0으로 설정해야 합니다. 이것은 RNN을 훈련하는 동안 편리합니다. 따라서 기본 작업은 모든 호출 에서 그라디언트누적 (즉, 합계)하는 것loss.backward() 입니다.

따라서 훈련 루프를 시작할 때 이상적으로 zero out the gradients는 매개 변수 업데이트를 올바르게 수행 해야 합니다. 그렇지 않으면 기울기가 의도 한 방향이 아닌 최소 방향 (또는 최대화 목표의 경우 최대 )을 가리 킵니다 .

다음은 간단한 예입니다.

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

또는 바닐라 경사 하강 법을 수행하는 경우 :

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

참고 : 텐서 에서가 호출 될 때 그래디언트 의 누적 (즉, 합계 )이 발생합니다 ..backward()loss


답변

zero_grad ()는 오류 (또는 손실)를 줄이기 위해 그래디언트 메서드를 사용하는 경우 마지막 단계에서 손실없이 루프를 다시 시작합니다.

zero_grad ()를 사용하지 않으면 필요에 따라 손실이 증가하지 않고 감소합니다.

예를 들어 zero_grad ()를 사용하면 다음 출력을 찾을 수 있습니다.

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

zero_grad ()를 사용하지 않으면 다음 출력을 찾을 수 있습니다.

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5


답변