
Doby's Lab

self.register_buffer(): what if a parameter should not be trained? (How it clearly differs from a tensor, part 2)

Code about AI/PyTorch


도비(Doby) 2025. 1. 14. 22:56

🤔 Problem

It has been a while since my last PyTorch post. While studying generative models and implementing DDPM recently, I came across a PyTorch feature that was new to me: self.register_buffer(), the topic of today's post. This post is also a follow-up to an earlier one of mine, 'nn.Parameter(), why do we need it? (How it clearly differs from a tensor)'.

 

To briefly recap the previous post: a tensor declared inside a model simply via torch.tensor() is not a target of training. To define it explicitly as a parameter that the model learns, you have to wrap it in nn.Parameter().

 

Building on that, I first want to make one more point: a torch.tensor() is not only untrained; unlike an nn.Parameter(), it also cannot be considered part of the model. You may wonder what it means for something declared inside the model code not to belong to the model. What I mean by 'belonging to the model' is closely tied to model.state_dict().

 

model.state_dict() is the method used when saving a model's weights. As shown below, when you call it on a model that contains a weight defined via torch.tensor(), that weight does not appear. In other words, it is not recognized as a component of the model.

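import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Plain tensor attribute: not registered with the module
        self.tensor = torch.randn(3, 3)
        
        # nn.Parameter: registered as a trainable parameter
        self.param = nn.Parameter(torch.randn(3, 3))
    ...

model = Model()
print(model.state_dict())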

[Output]
OrderedDict({'param': tensor([[ 0.3108,  0.5101, -0.8290],
        [ 0.3511, -0.4658,  0.2131],
        [ 0.6602, -1.0786, -0.3299]])})
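
To see what this means in practice, here is a minimal sketch (the class name TwoKinds and the variable names are made up for illustration) that copies one model's state_dict() into a second instance. Only the parameter should carry over; the plain tensor attribute keeps whatever random value the second instance was created with.

```python
import torch
from torch import nn


class TwoKinds(nn.Module):  # hypothetical name, mirrors Model above
    def __init__(self):
        super().__init__()
        self.tensor = torch.randn(3, 3)               # plain attribute
        self.param = nn.Parameter(torch.randn(3, 3))  # registered parameter


source = TwoKinds()
target = TwoKinds()

# Copy everything the source model considers part of its state.
target.load_state_dict(source.state_dict())

keys = list(source.state_dict().keys())
param_restored = torch.equal(source.param, target.param)
tensor_restored = torch.equal(source.tensor, target.tensor)

print(keys)
print(f"param restored: {param_restored}")
print(f"tensor restored: {tensor_restored}")
```

If the plain tensor holds something the model needs at inference time, it would have to be saved and restored by hand.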

 

However, while working on projects or reading papers, you run into tensors, matrices, and scalars that are not trained yet are still treated as components of the model (e.g., DDPM's alpha and beta).

 

Before going further, we need a term for 'a component of the model that is not trained' so that the rest of the post reads smoothly. In the literature these are simply called Non-trainable Parameters, Non-learnable Parameters, or Fixed Parameters. In this post I will use Fixed Parameter for convenience.

 

Coming back to the point: we have seen that a Fixed Parameter cannot be defined simply with torch.tensor(). So what does PyTorch offer for this case?


😀 Solution

PyTorch calls a Fixed Parameter a Buffer. Recent papers also use the term Buffer from time to time. To work with a Buffer, PyTorch provides self.register_buffer(). Using it is very simple.

 

As shown below, you create a tensor and then pass it to self.register_buffer(). The first argument is the name the Buffer will use, and the second is the tensor to use as the Buffer. Inside forward(), you can then access it as an attribute (self.buff), exactly as you would a Parameter.

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        buff = torch.randn(3, 3)
        # Name first, tensor second; afterwards available as self.buff
        self.register_buffer('buff', buff)
    
    def forward(self, x):
        print(self.buff)
        return x
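
As a sketch of the DDPM use case mentioned earlier, a noise schedule can be stored as buffers. The class name NoiseSchedule, the attribute names, and the linear schedule from 1e-4 to 0.02 over 1000 steps are assumptions chosen for illustration, not a reference implementation.

```python
import torch
from torch import nn


class NoiseSchedule(nn.Module):  # hypothetical name
    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        betas = torch.linspace(beta_start, beta_end, num_steps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Fixed values: saved with the model, never updated by an optimizer.
        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)

    def forward(self, x0, t, noise):
        # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        a = self.alpha_bars[t].sqrt().view(-1, 1)
        b = (1.0 - self.alpha_bars[t]).sqrt().view(-1, 1)
        return a * x0 + b * noise


schedule = NoiseSchedule()
x0 = torch.randn(4, 8)
t = torch.randint(0, 1000, (4,))
xt = schedule(x0, t, torch.randn_like(x0))

num_params = len(list(schedule.parameters()))
buffer_names = [name for name, _ in schedule.named_buffers()]
print(num_params, buffer_names, xt.shape)
```

The module has no trainable parameters at all, yet its schedule still appears in state_dict().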

 

That is not the only advantage of self.register_buffer(). Buffers also move freely between devices. A tensor declared with torch.tensor() does not follow the model to another device even when you call model = Model().to(device), which is inconvenient; self.register_buffer() solves this.
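
The difference can be checked even on a machine without a GPU. The sketch below (the class name Mixed is hypothetical) relies on the fact that model.to(...) and dtype casts such as model.double() are applied to registered parameters and buffers only, so a plain tensor attribute is left alone in both cases. It uses the dtype cast as a CPU-only stand-in and also moves the model to CUDA when one is available.

```python
import torch
from torch import nn


class Mixed(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.tensor = torch.randn(3, 3)                  # plain attribute
        self.register_buffer('buff', torch.randn(3, 3))  # registered buffer


model = Mixed().double()  # casts floating-point parameters and buffers

buff_dtype = model.buff.dtype
tensor_dtype = model.tensor.dtype
print(f"buff: {buff_dtype}, tensor: {tensor_dtype}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
print(f"buff on {model.buff.device}, tensor on {model.tensor.device}")
```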

 

In short, defining a tensor as a Buffer has two advantages.

  1. It is not trained, but it is still recognized as a component of the model. (It can be saved together with the weights.)
  2. When the model is moved between devices, no extra manual work is needed.

Also, just as you can inspect parameters with model.parameters() or model.named_parameters(), you can inspect Buffers with methods such as model.buffers() or model.named_buffers().
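
A short sketch of these inspection methods follows. It also uses the persistent=False argument of register_buffer(), which, as I understand the PyTorch documentation, keeps a buffer on the module while leaving it out of state_dict(). The class and buffer names are made up for illustration.

```python
import torch
from torch import nn


class WithBuffers(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.zeros(2))
        self.register_buffer('saved', torch.ones(2))
        # persistent=False: still a buffer, but excluded from state_dict().
        self.register_buffer('scratch', torch.ones(2), persistent=False)


model = WithBuffers()

param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())

print(param_names)
print(buffer_names)
print(state_keys)
```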

 

I also tested whether a Distribution (torch.distributions) could be managed as a Buffer, but found that a Buffer can only hold a Tensor (or None).
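
One possible workaround, sketched here under my own assumptions rather than as an official recipe, is to register the distribution's tensors (loc and scale) as buffers and build the Normal object inside forward(). The class name GaussianSampler is hypothetical.

```python
import torch
from torch import nn
from torch.distributions.normal import Normal


class GaussianSampler(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # Register the tensors, not the Distribution object itself.
        self.register_buffer('loc', torch.zeros(3))
        self.register_buffer('scale', torch.ones(3))

    def forward(self):
        # Built on the fly, so it uses whatever device the buffers are on.
        gaussian = Normal(loc=self.loc, scale=self.scale)
        return gaussian.sample()


sampler = GaussianSampler()
sample = sampler()
state_keys = list(sampler.state_dict().keys())
print(sample.shape, sample.device, state_keys)
```

Because the samples are drawn from tensors that live in buffers, they should end up on the same device as the model.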

 

Finally, I will wrap up with code that checks both advantages, together with its output. (The run recorded below was on a machine where CUDA was not available, so every device prints as cpu.)

# Expr 1
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.tensor = torch.randn(3, 3)
        
        self.param = nn.Parameter(torch.randn(3, 3))
        
        buff = torch.randn(3, 3)
        self.register_buffer('buff', buff)
    
    def forward(self, x):
        print(self.buff)
        return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model().to(device)

print("===========================================")
print(f"tensor: \n{model.tensor}")
print(f"{model.tensor.device}")

print("===========================================")
for name, parameter in model.named_parameters():
    print(f"{name}: \n{parameter}")
    print(f"{parameter.device}")

print("===========================================")
for name, buff in model.named_buffers():
    print(f"{name}: \n{buff}")
    print(f"{buff.device}")

print("===========================================")
print(f"model state_dict(): param O, buff O, tensor X")
print(model.state_dict())

# Expr 2
# A Distribution cannot be registered as a buffer.
from torch.distributions.normal import Normal

class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.gaussian = Normal(loc=torch.zeros(3),
                               scale=torch.ones(3))
        
        gaussian2 = Normal(loc=torch.zeros(3),
                           scale=torch.ones(3))
        
        self.register_buffer('gaussian2', gaussian2)
    
    def forward(self, x):
        sample = self.gaussian.sample()
        
        sample2 = self.gaussian2.sample()
        
        print(f"sample 1 device: {sample.device}")
        print(f"sample 2 device: {sample2.device}")
        
        return x

model2 = Model2()
model2(torch.randn(3, 3))

'''
(.venv) E:\DDPM>python register_buffer_test.py
===========================================
tensor: 
tensor([[ 0.5325,  1.3698, -1.2790],
        [-0.5546,  0.3236,  0.6196],
        [ 2.1521,  1.8287,  1.0600]])
cpu
===========================================
param:
Parameter containing:
tensor([[-1.5823, -1.0639, -0.9007],
        [-0.8665, -0.0151,  0.8802],
        [ 0.3128,  2.1903, -0.2867]], requires_grad=True)
cpu
===========================================
buff:
tensor([[-0.0590,  0.4823,  2.1716],
        [-0.6110,  0.1420,  1.5730],
        [ 1.2040, -0.0654, -0.4525]])
cpu
===========================================
model state_dict(): param O, buff O, tensor X
OrderedDict({'param': tensor([[-1.5823, -1.0639, -0.9007],
        [-0.8665, -0.0151,  0.8802],
        [ 0.3128,  2.1903, -0.2867]]), 'buff': tensor([[-0.0590,  0.4823,  2.1716],
        [-0.6110,  0.1420,  1.5730],
        [ 1.2040, -0.0654, -0.4525]])})
Traceback (most recent call last):
  File "E:\DDPM\register_buffer_test.py", line 64, in <module>
    model2 = Model2()
             ^^^^^^^^
  File "E:\DDPM\register_buffer_test.py", line 52, in __init__
    self.register_buffer('gaussian2', gaussian2)
  File "E:\DDPM\.venv\Lib\site-packages\torch\nn\modules\module.py", line 566, in register_buffer
    raise TypeError(
TypeError: cannot assign 'torch.distributions.normal.Normal' object to buffer 'gaussian2' (torch Tensor or None required)
'''

📂 Reference

https://pytorch.org/docs/stable/generated/torch.nn.Module.html

 


https://www.ai-bio.info/pytorch-register-buffer

 
