์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- pytorch
- ๋ถํ ์ ๋ณต
- Overfitting
- dfs
- ํ๋ก์ด๋ ์์ฌ
- dropout
- NEXT
- ๊ฐ๋์ ๋ง๋ก
- ์๊ณ ๋ฆฌ์ฆ
- ๊ฐ๋์_๋ง๋ก
- ๋ฐฑํธ๋ํน
- object detection
- 2023
- ์ฐ์ ์์ ํ
- ๋ฌธ์์ด
- ์กฐํฉ๋ก
- ๋๋น ์ฐ์ ํ์
- ํ๊ณ ๋ก
- c++
- ๋ค์ต์คํธ๋ผ
- back propagation
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- ์๋ฐ์คํฌ๋ฆฝํธ
- ํฌ๋ฃจ์ค์นผ
- BFS
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- ์ด๋ถ ํ์
- lazy propagation
- tensorflow
- DP
- Today
- Total
Doby's Lab
self.register_buffer(), ํ์ตํ์ง ์์ ํ๋ผ๋ฏธํฐ๋ผ๋ฉด? (tensor์ ๋ช ๋ฐฑํ๊ฒ ๋ค๋ฅธ ์ 2) ๋ณธ๋ฌธ
self.register_buffer(), ํ์ตํ์ง ์์ ํ๋ผ๋ฏธํฐ๋ผ๋ฉด? (tensor์ ๋ช ๋ฐฑํ๊ฒ ๋ค๋ฅธ ์ 2)
๋๋น(Doby) 2025. 1. 14. 22:56๐ค Problem
์ค๋๋ง์ PyTorch ๊ด๋ จ ๊ธ์
๋๋ค. ์ต๊ทผ์๋ Generative Model ์ชฝ์ ๊ณต๋ถํ๋ฉด์ DDPM์ ๊ตฌํํ๋ค๊ฐ PyTorch์ ์๋ก์ด ๊ธฐ๋ฅ์ ๋ฐ๊ฒฌํ๋๋ฐ์. ๋ฐ๋ก ์ค๋ ๊ธ์ ์ฃผ์ ๊ฐ ๋๋ self.register_buffer()
์
๋๋ค. ๋ณธ ํฌ์คํธ๋ ์์ ์ ์์ฑํ ํฌ์คํธ๋ค ์ค์ 'nn.Parameter(), ์ด๊ฑธ ์จ์ผ ํ๋ ์ด์ ๊ฐ ๋ญ๊น? (tensor์ ๋ช
๋ฐฑํ๊ฒ ๋ค๋ฅธ ์ )'๋ผ๋ ํฌ์คํธ์ ํ์ ํธ์ด ๋๊ธฐ๋ ํฉ๋๋ค.
์ด์ ํฌ์คํธ์ ๋ด์ฉ์ ๊ฐ๋ตํ๊ฒ ๋ฆฌ๋ทฐํด ๋ณด๋ฉด '๋ชจ๋ธ ๋ด์์ ๋จ์ํ torch.tensor()
๋ฅผ ํตํด ์ ์ธํ ํ
์๋ ํ์ต์ ๋์์ด ๋์ง ๋ชปํ๊ณ , ์ด๋ฅผ ๋ช
ํํ๊ฒ ๋ชจ๋ธ ๋ด ํ์ต์ ํ๋ ํ๋ผ๋ฏธํฐ๋ก ์ ์ํ๊ธฐ ์ํด์๋ nn.Parameter()
๋ก ๊ฐ์ธ์ ์ถ๊ฐ์ ์ผ๋ก ์ ์ธํด์ผ ํ๋ค.'๋ผ๋ ๋ด์ฉ์ด์์ต๋๋ค.
์ ๋ด์ฉ์ ๋ง๋ถ์ฌ์ 'torch.tensor()
๋ ํ์ต์ ํ์ง๋ ์์ง๋ง, nn.Parameter()
์ฒ๋ผ ๋ชจ๋ธ ๋ด์ ์ํด ์๋ค๊ณ ๋ณผ ์๋ ์๋ค.'๋ผ๋ ๋ด์ฉ์ ๋จผ์ ์ ๋ฌ๋๋ฆฌ๊ณ ์ถ์ต๋๋ค. '๋ชจ๋ธ ์ฝ๋ ์์ ์ ์ธ์ ํด๋์๋๋ฐ ๋ชจ๋ธ์ ์ํด์์ง ์๋ค๋ ๊ฒ ๋ฌด์จ ๋ง์ด์ง?'๋ผ๊ณ ๋ณด์ค ์ ์์ต๋๋ค๋ง, ์ ๊ฐ ์ ๋ฌํ๊ณ ์ ํ๋ '๋ชจ๋ธ ๋ด์ ์ํด์๋ค'๋ผ๋ ๋ง์ ์ ์๋ model.state_dict()
์ ํฐ ์ฐ๊ด์ด ์์ต๋๋ค.
model.state_dict()
๋ ๋ชจ๋ธ์ weight๋ฅผ ์ ์ฅํ ๋ ์ฌ์ฉํ๋ ๋ฉ์๋์
๋๋ค. ์๋์ ๊ฐ์ด torch.tensor()
๋ฅผ ํตํด ์ ์ํ weight๊ฐ ํฌํจ๋ ๋ชจ๋ธ์ ์ ์ฅํ๋ ค๊ณ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ ๋๋ ํด๋น weight๊ฐ ๋ณด์ด์ง ์์ต๋๋ค. ์ฆ, ๋ชจ๋ธ์ ๊ตฌ์ฑ์์๋ก ์ธ์ ํ๊ณ ์์ง ์์ต๋๋ค.
class Model(nn.Module):
def __init__(self):
super().__init__()
self.tensor = torch.randn(3, 3)
self.param = nn.Parameter(torch.randn(3, 3))
...
model = Model()
print(model.state_dict())
[Output]
OrderedDict({'param': tensor([[ 0.3108, 0.5101, -0.8290],
[ 0.3511, -0.4658, 0.2131],
[ 0.6602, -1.0786, -0.3299]])})
ํ์ง๋ง, ํ๋ก์ ํธ๋ฅผ ์งํํ๊ฑฐ๋ ๋ ผ๋ฌธ์ ๋ณด๋ค ๋ณด๋ฉด, ํ์ต์ ํ์ง ์์์๋ ๋ถ๊ตฌํ๊ณ ๋ชจ๋ธ ๋ด ๊ตฌ์ฑ ์์๋ก ์ทจ๊ธ๋๋ ํ ์, ํ๋ ฌ, ์ค์นผ๋ผ๋ฅผ ๋ณด๊ฒ ๋ฉ๋๋ค.(ex: DDPM์ alpha, beta)
๊ทธ์ ์ ๊ณ์ ์ฝ๊ธฐ ๋ถํธํ์ง ์๊ฒ 'ํ์ต์ ํ์ง ์๋ ๋ชจ๋ธ์ ๊ตฌ์ฑ ์์'๋ผ๋ ๊ฒ์ ๋ํ๋ผ ๋จ์ด๊ฐ ํ์ํฉ๋๋ค. ํ๊ณ์์๋ ์ด๋ฅผ ๋จ์ํ Non-trainable Parameter, Non-learnable Parameter, Fixed Parameter๋ผ๋ ๋จ์ด๋ก ์ ์๋ฅผ ํฉ๋๋ค. ๋ณธ ํฌ์คํธ์์๋ ํธํ๊ฒ Fixed Parameter๋ผ๋ ๋จ์ด๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
๋ค์ ๋์์์ ์ด Fixed Parameter๋ ๋จ์ํ torch.tensor()
๋ก ์ ์๋์ง๋ ๋ชป ํ๋ค๋ ์ฌ์ค์ ์๊ฒ ๋์์ต๋๋ค. ๊ทธ๋ ๋ค๋ฉด, PyTorch์์๋ ์ด์ ๋ํด ์ด๋ ํ ๋ฐฉ๋ฒ์ ์ ์ํ๊ณ ์์๊น์?
๐ Solution
PyTorch์์๋ Fixed Parameter๋ฅผ Buffer๋ผ๊ณ ์ ์ํฉ๋๋ค. ์ต๊ทผ ๋
ผ๋ฌธ๋ค์์๋ Buffer๋ผ๋ ๋จ์ด๋ฅผ ์ผ์ปซ๋ ๊ฒฝ์ฐ๋ค์ด ์ข
์ข
์๊ธฐ๋ ํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ , ์ด Buffer๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์ PyTorch๋ self.register_buffer()
๋ผ๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ๊ณ ์์ต๋๋ค. ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ ๋ง ๊ฐ๋จํฉ๋๋ค.
์๋์ ๊ฐ์ด torch.tensor()
๋ก ์ ์ธํ ๋ค์์ self.register_buffer()
์ ํ ๋นํ์ฌ ์ ์ธํ ์ ์์ต๋๋ค. ์ฒซ ๋ฒ์งธ ์ธ์์๋ Buffer๊ฐ ์ฌ์ฉํ ์ด๋ฆ์ด ๋๊ณ , ๋ ๋ฒ์งธ ์ธ์์๋ Buffer๋ก ์ฌ์ฉํ ํ
์๊ฐ ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ , forward()
ํจ์์์ ๋ค๋ฅธ Parameter์ฒ๋ผ ๋๊ฐ์ ์ ๊ทผ์ด ๊ฐ๋ฅํ ๊ฑธ ๋ณผ ์ ์์ต๋๋ค.
class Model(nn.Module):
def __init__(self):
super().__init__()
buff = torch.randn(3, 3)
self.register_buffer('buff', buff)
def forward(self, x):
print(self.buff)
return x
self.register_buffer()
์ ์ฅ์ ์ ์ด๊ฒ๋ฟ๋ง์ด ์๋๋๋ค. ๋๋ฐ์ด์ค ๊ฐ ์ด๋์ ์์จ์ฑ์ด ๋๋ค๋ ์ฅ์ ์ด ํ๋ ๋ ์์ต๋๋ค. torch.tensor()
๋ก ์ ์ธํ ๊ฒฝ์ฐ์๋ model = Model.to(device)
๋ฅผ ํ๋๋ผ๋ ๋ค๋ฅธ ๋๋ฐ์ด์ค๋ก ๋์ด๊ฐ์ง ์๋ ๋ถํธํจ์ด ์์์ต๋๋ค๋ง, self.register_buffer()
์์๋ ์ด๋ฅผ ํด๊ฒฐํด์ฃผ๊ณ ์์ต๋๋ค.
์ฆ, Buffer๋ก ์ ์ํ๋ ๋ฐฉ๋ฒ์ 2๊ฐ์ง ์ฅ์ ์ด ์๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
- ํ์ต์ ํ์ง ์์ง๋ง, ๋ชจ๋ธ์ ๊ตฌ์ฑ์์๋ก ์ธ์ ํ ์ ์๋ค. (๊ฐ์ค์น๋ก ์ ์ฅํ ์ ์๋ค.)
- ๋ชจ๋ธ์ ๋๋ฐ์ด์ค ์ด๋ ์์ ์ด ์๊ธธ ๋ ๋ฒ๊ฑฐ๋ก์ด ์์ ์ด ๋ถํ์ํด์ก๋ค.
๊ทธ๋ฆฌ๊ณ , model.parameters()
๋ model.name_parameters()
๋ก ํ๋ผ๋ฏธํฐ๋ฅผ ํ์ธํ ์ ์๋ ๊ฒ์ฒ๋ผ model.buffers()
๋ model.named_buffers()
์ ๊ฐ์ ๋ฉ์๋๋ก Buffer๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
๋ํ, Buffer๋ผ๋ ํน์ฑ์ ํตํด Distribution(torch.distributions
)๋ ๊ด๋ฆฌ๊ฐ ๊ฐ๋ฅํ์ง ์คํํด ๋ณด์์ผ๋, Buffer๋ Tensor type๋ง ๋ด์ ์ ์๋ ๊ฒ์ผ๋ก ํ์ธํ์์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก, 2๊ฐ์ง ์ฅ์ ์ ํ์ธํ ์ ์๋ ์ฝ๋์ Output์ ๊ฐ์ด ์ฌ๋ฆฌ๋ฉด์ ๊ธ์ ๋ง๋ฌด๋ฆฌํ๊ฒ ์ต๋๋ค.
# Expr 1
import torch
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.tensor = torch.randn(3, 3)
self.param = nn.Parameter(torch.randn(3, 3))
buff = torch.randn(3, 3)
self.register_buffer('buff', buff)
def forward(self, x):
print(self.buff)
return x
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model().to(device)
print("===========================================")
print(f"tensor: \n{model.tensor}")
print(f"{model.tensor.device}")
print("===========================================")
for name, parameter in model.named_parameters():
print(f"{name}: \n{parameter}")
print(f"{parameter.device}")
print("===========================================")
for name, buff in model.named_buffers():
print(f"{name}: \n{buff}")
print(f"{buff.device}")
print("===========================================")
print(f"model state_dict(): param O, buff O, tensor X")
print(model.state_dict())
# Expr 2
# Distribution์ buffer์ register ์ ๋๋ค.
from torch.distributions.normal import Normal
class Model2(nn.Module):
def __init__(self):
super().__init__()
self.gaussian = Normal(loc=torch.zeros(3),
scale=torch.ones(3))
gaussian2 = Normal(loc=torch.zeros(3),
scale=torch.ones(3))
self.register_buffer('gaussian2', gaussian2)
def forward(self, x):
sample = self.gaussian.sample()
sample2 = self.gaussian2.sample()
print(f"sample 1 device: {sample.device}")
print(f"sample 2 device: {sample2.device}")
return x
model2 = Model2()
model2(torch.randn(3, 3))
'''
(.venv) E:\DDPM>python register_buffer_test.py
===========================================
tensor:
tensor([[ 0.5325, 1.3698, -1.2790],
[-0.5546, 0.3236, 0.6196],
[ 2.1521, 1.8287, 1.0600]])
cpu
===========================================
param:
Parameter containing:
tensor([[-1.5823, -1.0639, -0.9007],
[-0.8665, -0.0151, 0.8802],
[ 0.3128, 2.1903, -0.2867]], requires_grad=True)
cpu
===========================================
buff:
tensor([[-0.0590, 0.4823, 2.1716],
[-0.6110, 0.1420, 1.5730],
[ 1.2040, -0.0654, -0.4525]])
cpu
===========================================
model state_dict(): param O, buff O, tensor X
OrderedDict({'param': tensor([[-1.5823, -1.0639, -0.9007],
[-0.8665, -0.0151, 0.8802],
[ 0.3128, 2.1903, -0.2867]]), 'buff': tensor([[-0.0590, 0.4823, 2.1716],
[-0.6110, 0.1420, 1.5730],
[ 1.2040, -0.0654, -0.4525]])})
Traceback (most recent call last):
File "E:\DDPM\register_buffer_test.py", line 64, in <module>
model2 = Model2()
^^^^^^^^
File "E:\DDPM\register_buffer_test.py", line 52, in __init__
self.register_buffer('gaussian2', gaussian2)
File "E:\DDPM\.venv\Lib\site-packages\torch\nn\modules\module.py", line 566, in register_buffer
raise TypeError(
TypeError: cannot assign 'torch.distributions.normal.Normal' object to buffer 'gaussian2' (torch Tensor or None required)
'''
๐ Reference
https://pytorch.org/docs/stable/generated/torch.nn.Module.html
Module — PyTorch 2.5 documentation
Shortcuts
pytorch.org
https://www.ai-bio.info/pytorch-register-buffer
PyTorch์์ register_buffer๋ฅผ ์จ์ผํ๋ ์ด์
PyTorch model์์ register_buffer๋ฅผ ์ฌ์ฉํ๋ ์ด์ ์ ๋ํด ์์๋ด ๋๋ค.
www.ai-bio.info