| 일 | 월 | 화 | 수 | 목 | 금 | 토 |
|---|---|---|---|---|---|---|
| 1 | 2 | 3 | 4 | 5 | 6 | |
| 7 | 8 | 9 | 10 | 11 | 12 | 13 |
| 14 | 15 | 16 | 17 | 18 | 19 | 20 |
| 21 | 22 | 23 | 24 | 25 | 26 | 27 |
| 28 | 29 | 30 | 31 |
- dfs
- object detection
- 이분 탐색
- 2023
- 회고록
- dropout
- 우선 순위 큐
- 다익스트라
- Overfitting
- DP
- 문자열
- 조합론
- 미래는_현재와_과거로
- 세그먼트 트리
- BFS
- 분할 정복
- NEXT
- lazy propagation
- pytorch
- 자바스크립트
- back propagation
- 알고리즘
- 플로이드 와샬
- 너비 우선 탐색
- c++
- 가끔은_말로
- 백트래킹
- 가끔은 말로
- 크루스칼
- tensorflow
- Today
- Total
Doby's Lab
Neural ODE에서 말하는 메모리 문제는 무엇인가요? 본문
요즘 글을 너무 안 올려서... 제 연구노트인데요, 요거라도 올려봅니다. Neural ODE 논문 리뷰를 하는 블로그에서 이 문제를 심층적으로 다루는 글들은 많이 못 봐서요. 제 연구 노트에는 요런 내용들이 좀 많습니다. 괜히 제 딴에는 보다 깊게 이해하려고 하는 것이지만, 시간을 많이 소비한답니다.. 호호
아무튼 요즘 근황 겸 요런 글을 어떠신지요. 올 한 해 의도치 않게 Generative Model에 빠져있었는데, GAN, VAE, Normalizing Flow, DDPM 등에 대해서도 요런 고군분투스러운 글들이 많답니다. (조금의 반응이라도 있다면, 앞으로는 조금씩 정제해서 올려볼까 합니다. 제 생각이 맞는지 의심스러울 때가 많거든요..ㅎ 여기를 discussion의 장으로 만들어보고 싶다는 생각도 있네요. Generative Models는 할 얘기가 너무 많답니다!)
당연히 이 글을 저의 주관적인 생각이므로, [이상하다, 이건 이렇게 접근해야 하지 않냐, 이거 좀 더 설명해달라] 싶다면, 언제든 댓글 남겨주시기 바랍니다.
그리고, 조만간 또 <한 해 회고록>의 시기가 다가오는데요. 올해도 정말 일들이 많았습니다. 저의 예상보다 더 많은 일들이 있었더군요.. 아무튼 오랜만에 쓴 거라 중구난방으로 남겨두고 싶은 말들이 많은데요. 2025년이 끝나는 그 즈음에 저의 회고록으로 뵙겠습니다.
- Reverse-mode Differentiation은 Wikipedia를 참조했을 때 Backpropagation과 같다는 걸 알았다. (Chain rule인데, 어디서부터 계산해야 되냐 이런 거)
- 자 그럼 논문에서 말한 “but incurs a high memory and introduces additional numerical error”라고 적은 이유는 무엇인가? (numerical error는 아직 패스)
다음과 같은 상황을 가정해보자.
$$ \mathbf{h}_{t+1}=\mathbf{h}_t+f(\mathbf{h}_t, t, \theta_t)\:\:(\text{where}\:t\in\{1,2,\dots,N\}) $$
위 변환에서 \(\theta_t\)는 Neural network의 파라미터를 의미한다. 그런데, 이 때 \(N\)이 엄청 크면 backpropagation 시에 메모리를 많이 잡아먹는 문제가 incur한다.
왜 그런지 \(N=1\) case를 통해 알아보자. (초기 조건 \(\mathbf{h}_0\))
$$ \mathbf{h}_{2}=\mathbf{h}_1+f(\mathbf{h}_1, t=1, \theta_1) $$
$$ \mathbf{h}_{1}=\mathbf{h}_0+f(\mathbf{h}_0, t=0, \theta_0) $$
2번에 거친 변환 \(\mathbf{h}_2\)를 Loss function에 입력하고, \(\theta_0\)에 대한 update를 위해 chain rule을 이용하여 \(\frac{\partial L}{\partial \theta_0}\)을 전개할 것이다. 이는 아래와 같다.
$$ \frac{\partial L}{\partial\theta_0} = \frac{\partial L}{\partial \mathbf{h}_2}\cdot\frac{\partial \mathbf{h}_2}{\partial \mathbf{h}_1}\cdot\frac{\partial \mathbf{h}_1}{\partial f_0}\cdot\frac{\partial f_0}{\partial \theta_0} $$
여기서 \(\frac{\partial\mathbf{h}_2}{\partial\mathbf{h}_1}\)를 역전파를 위해 실제로 구해보자. (\(\sigma=\text{sigmoid}\))
$$ \begin{align} \frac{\partial\mathbf{h}_2}{\partial\mathbf{h}_1}&=\frac{\partial\mathbf{h}_1}{\partial\mathbf{h}_1}+\frac{\partial f(\mathbf{h}_1, t=1, \theta_1)}{\partial\mathbf{h}_1}\:\:(\text{where}\:\:f=\sigma(W\mathbf{h}_1+b), \theta_1=\{W, b\})\\ &=I+\frac{\partial\sigma(\mathbf{z})}{\partial \mathbf{z}}\vert_{\mathbf{z}=W\mathbf{h}_1+b}\cdot\frac{\partial(W\mathbf{h}_1+b)}{\partial \mathbf{h}_1} \\ &=I+\{\sigma(\mathbf{z})(1-\sigma(\mathbf{z}))\}\vert_{\mathbf{z}=W\mathbf{h}_1+b}\cdot W \\ &=I+\sigma(W\mathbf{h}_1+b)(1-\sigma(W\mathbf{h}_1+b))\cdot W \end{align} $$
위 \(\frac{\partial\mathbf{h}_2}{\partial\mathbf{h}_1}\)를 구하기 위해서는 \(\mathbf{h}_1\)을 알아야 한다!
hidden state가 무엇인지 값을 알고 있어야 한다. → 메모리 상에 올라가 있어야 한다.
hidden state 하나 메모리에 올려둔다고 해서 크게 문제가 되진 않을 것으로 예상된다. 하지만, continuous-depth neural network라 하지 않았는가. 무수히 많은 변환의 case라면 어떻겠는가?
\(N=K-1\) case를 생각해보자. \(\frac{\partial L}{\partial\theta_0}\)을 동일하게 구해보자.
$$ \frac{\partial L}{\partial\theta_0} = \frac{\partial L}{\partial \mathbf{h}_K}\cdot\frac{\partial \mathbf{h}K}{\partial \mathbf{h}_{K-1}}\cdot\dots\cdot\frac{\partial \mathbf{h}_2}{\partial \mathbf{h}_{1}}\cdot\frac{\partial \mathbf{h}_1}{\partial f_0}\cdot\frac{\partial f_0}{\partial \theta_0} $$
\(\frac{\partial \mathbf{h}_K}{\partial \mathbf{h}_{K-1}}\)을 알려면, \(\mathbf{h}_{k-1}\)를 어딘가에 저장해두고 있어야 한다.
즉, \(K-1\)개의 hidden state를 메모리에 계속 올려두고 있어야 한다는 의미가 된다.
즉, 정리하자면 Residual Networks를 Euler’s Method로 보는 관점(ODE의 수치적 해)으로부터 이는 ODE Solver를 통해 해를 구하는 과정과 유사한 것이라 보았고, 변화량만 Neural network를 그대로 써서 ODE의 수치적 해를 구하는 게 Residual Network의 output을 출력하는 것이라 보고 ODE Solver로 forward를 주려고 한다.
다만, 이 과정에서 ODE Solver의 변화량을 담당하는 Neural network의 학습을 해야 하는데, 모든 hidden state를 메모리에 저장하고 있어야 하는 문제가 생긴다.
그렇다면, Neural ODE 논문은 학습 시 hidden state를 저장하지 않아도 되는 방법을 도출해야 한다. 그게 adjoint sensitivity method라 하는데, 그건 좀 더 리뷰를 해봐야겠다.