삽질Log

Pytorch| RuntimeError: Trying to backward through the graph a second time

rrojin 2022. 11. 10. 15:01
목차

1. Error

2. Problem Case

3. Solution & MFA

 1.  Error   

RuntimeError: 
Trying to backward through the graph a second time

(or directly access saved tensors after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). 
Specify retain_graph=True if you need to backward through the graph a second time or 
if you need to access saved tensors after calling backward.

 

 2.  Problem Case   

MAML Finetune버전 구현 중에 마주친 에러입니다.
Goal은 MINI-Imagenet데이터셋을 meta-learning으로 학습 시키기입니다.
이미지 학습에는 CNN을 사용하였고, 구현한 모델의 CNN Architecture은 크게 Feature extraction(conv2d, relu, batch norm, maxpool) 과 Classification layer(linear)로 나눌 수 있습니다.   

MAML 모델은 inner loop와 outer loop로 구성되어있습니다. inner loop에서는 fast update만을 수행하고, outer loop에서는 meta update(loss.backward(retain_graph=True) → self.optimizer.step() )가 이루어집니다. 이때 inner loop에서도 Classification layer에 대해서만 meta update하기 위해서 loss.backward(미분)를 사용하고자 합니다.  
따라서 loss.backward()를 2번 사용해야 하는데, 여기서 발생한 문제입니다. 

 

 3.  Solution   

 loss.backward()를 2번 사용했을 때의 문제점
"Saved intermediate values of the graph are freed when you call .backward() or autograd.grad()."  
→".backward() 또는 autograd.grad()를 호출하면 그래프의 저장된 중간 값이 해제됩니다."

해결 방법은,"Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

→backward를 2번 수행하거나, backward 이후에 다시 저장된 값에 접근하려면 'retain_graph=True' 설정을 해야합니다. 

일반적으로, gradient를 구한뒤에는 다시 필요가 없기 때문에, loss.backward() 이후에는 다시 접근할 필요가 없기 때문에 메모리 최적화를 위해 backward() 함수가 끝나면 그래프(intermediate 텐서 버퍼 값)가 삭제됩니다. 따라서 그래프를 계속 유지해야 다시 그래프에 접근할 수 있습니다.

loss1.backward(retain_graph=True)  # 그래프 유지
optimizer.step()
loss2.backward() 
optimizer.step()

References

https://stackoverflow.com/questions/48274929/pytorch-runtimeerror-trying-to-backward-through-the-graph-a-second-time-but

'삽질Log' 카테고리의 다른 글

Python | 2차원 list shuffle  (0) 2022.11.14