๊ด€๋ฆฌ ๋ฉ”๋‰ด

Doby's Lab

DropPath๋ž€ ๋ฌด์—‡์ด๋ฉฐ, Dropout๊ณผ ๋ฌด์Šจ ์ฐจ์ด๊ฐ€ ์žˆ์„๊นŒ? (timm ํ™œ์šฉ ๋ฐ ์˜คํ”ˆ์†Œ์Šค ๋ถ„์„) ๋ณธ๋ฌธ

Code about AI/PyTorch

DropPath๋ž€ ๋ฌด์—‡์ด๋ฉฐ, Dropout๊ณผ ๋ฌด์Šจ ์ฐจ์ด๊ฐ€ ์žˆ์„๊นŒ? (timm ํ™œ์šฉ ๋ฐ ์˜คํ”ˆ์†Œ์Šค ๋ถ„์„)

๋„๋น„(Doby) 2024. 5. 12. 12:17

๐Ÿค” Problem

๋น„์ „ ๋ชจ๋ธ ์˜คํ”ˆ์†Œ์Šค๋ฅผ ๋ณด๋‹ค ๋ณด๋ฉด, ์ข…์ข… DropPath๋ผ๋Š” ํด๋ž˜์Šค๋กœ๋ถ€ํ„ฐ ์ธ์Šคํ„ด์Šค๋ฅผ ์ƒ์„ฑํ•˜์—ฌ ๋ชจ๋ธ์—์„œ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ๋ฅผ ์ž์ฃผ ๋ด…๋‹ˆ๋‹ค. ๋˜ํ•œ, ์ด DropPath๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” timm์ด๋ผ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

 

๊ทธ๋ž˜์„œ, ์˜ค๋Š˜์€ DropPath๊ฐ€ ๋ฌด์—‡์ด๋ฉฐ, timm์ด๋ผ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋Š” ๋ฌด์—‡์ธ์ง€ ๊ทธ ๋‚ด๋ถ€์— ์–ด๋–ป๊ฒŒ ๊ตฌํ˜„๋˜์–ด ์žˆ๋Š”์ง€๋ฅผ ๊ธฐ๋กํ•ด๋‘๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.


๐Ÿ˜€ DropPath๋ž€?(= Stochastic Depth)

DropPath๋ž€ Dropout์˜ ์ด๋ฆ„๊ณผ ์œ ์‚ฌํ•˜๊ฒŒ ๊ธฐ๋Šฅ๋„ ์œ ์‚ฌํ•œ ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

์ด ๊ฐœ๋…์€ Deep Networks with Stochastic Depth์—์„œ ๋“ฑ์žฅํ•œ ๊ฐœ๋…์œผ๋กœ Residual connection์˜ ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ธฐ๋Šฅ์ž…๋‹ˆ๋‹ค. ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋ฅผ Stochastic Depth๋ผ ์–ธ๊ธ‰ํ•˜๊ธฐ๋„ ํ•ฉ๋‹ˆ๋‹ค.

Residual Connection Structure

์œ„ ๊ทธ๋ฆผ์—์„œ DropPath๊ฐ€ ์ˆ˜ํ–‰ํ•˜๊ณ ์ž ํ•˜๋Š” ๊ฒƒ์€ ๋žœ๋คํ•œ ํ™•๋ฅ ์— ๋”ฐ๋ผ \(f_l(H_{l-1})\)์„ ์ƒ๋žตํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ฆ‰, ๋ชจ๋ธ์˜ Output์ด ๋žœ๋คํ•œ ํ™•๋ฅ ์— ๋”ฐ๋ผ \(f(x)+x\)๊ฐ€ ๋  ์ˆ˜๋„ ์žˆ๊ณ , \(x\)๊ฐ€ ๋  ์ˆ˜๋„ ์žˆ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ , ์ด๊ฒƒ์€ ์šฐ๋ฆฌ๊ฐ€ ๋ฐฐ์น˜ ๋‹จ์œ„๋กœ ํ•™์Šตํ•  ๋•Œ, ๋ฐฐ์น˜ ๋‚ด์˜ ์ƒ˜ํ”Œ ๋ณ„๋กœ ๋žœ๋คํ•œ ํ™•๋ฅ ์— ์˜ํ•ด ์ ์šฉ๋ฉ๋‹ˆ๋‹ค.

 

๊ทธ๋Ÿฌ๋ฉด, ์ด๋ฅผ ํ†ตํ•ด ์ข‹์•„์ง€๋Š” ๊ฒƒ์ด ๋ฌด์—‡์ด๋ƒ?

1๏ธโƒฃ Training time savings

DropPath๊ฐ€ ์ ์šฉ์ด ๋˜๋ฉด, ๊ธฐ์กด Block์—์„œ Residual connection ๊ตฌ์กฐ๋งŒ ๋‚จ๊ฒŒ ๋˜๋Š” ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์— ๋„คํŠธ์›Œํฌ์˜ ๊ธธ์ด๊ฐ€ ์งง์•„์ง‘๋‹ˆ๋‹ค. ์ด์— ๋”ฐ๋ผ, ํ•™์Šต ์‹œ๊ฐ„์ด ์ค„์–ด๋“ค๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

Survival probability์— ๋”ฐ๋ฅธ Test Error

์œ„ ๋…ผ๋ฌธ์—์„œ ์‹คํ—˜ํ•œ ๊ฒฐ๊ณผ๋กœ, Survival probability๊ฐ€ 0.5์ผ ๋•Œ๋Š” ํ•™์Šต ์‹œ๊ฐ„์ด 25% ์ •๋„๋กœ ์ค„์—ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ , 0.5์ผ ๋•Œ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ์ œ์ผ Optimal ํ–ˆ์Šต๋‹ˆ๋‹ค.

 

๋˜ํ•œ, Survival probability๊ฐ€ 0.2์ผ ๋•Œ๋Š” ์„ฑ๋Šฅ ์ž์ฒด๋Š” DropPath๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์€ ๊ฒƒ๊ณผ ๋น„์Šทํ•œ๋ฐ, ํ•™์Šต ์‹œ๊ฐ„์€ 40%๋กœ ํ™• ์ค„์—ˆ์Šต๋‹ˆ๋‹ค.

2๏ธโƒฃ Implict model ensemble

๊ทธ๋ฆฌ๊ณ , ๋ชจ๋ธ ๋‚ด์— \(L\)๊ฐœ์˜ Residual Block์ด ์žˆ๋‹ค๊ณ  ํ–ˆ์„ ๋•Œ, DropPath๋ฅผ ์ ์šฉ์‹œํ‚ค๋ฉด ์„œ๋กœ ๋‹ค๋ฅธ \(2^L\)๊ฐœ์˜ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๋Š” ๊ฒƒ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. (์™œ๋ƒํ•˜๋ฉด, ์ ์šฉํ•˜๋Š๋ƒ ์•ˆ ํ•˜๋Š๋ƒ์˜ ๊ฒฝ์šฐ์˜ ์ˆ˜๊ฐ€ ์žˆ์–ด์„œ) ์ฆ‰, ์•™์ƒ๋ธ”์˜ ํšจ๊ณผ๋ฅผ ๊ฐ€์ ธ์˜จ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

 

์ด๋ฅผ ์ฆ๋ช…ํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฐ๊ณผ๋กœ ๋ฒค์น˜๋งˆํฌ ๋ฐ์ดํ„ฐ์…‹์„ ํ•™์Šต์‹œ์ผœ Test error๋ฅผ ์—ฌ๋Ÿฌ ์—ฐ๊ตฌ๋“ค๊ณผ ๋น„๊ตํ•œ ๊ฒฐ๊ณผ, ํ™•์‹คํžˆ ๋” ์„ฑ๋Šฅ์ด ์ข‹์Œ์„ ๋ณด์˜€์Šต๋‹ˆ๋‹ค.

Model๊ณผ Method์— ๋”ฐ๋ฅธ Test error ๋น„๊ต ํ‘œ


๐Ÿ˜€ timm ๋‚ด ์˜คํ”ˆ์†Œ์Šค ๋ถ„์„ & ๋ฐฐ์šด ์ 

timm ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋Š” Hugging face์—์„œ ์ด๋ฏธ์ง€ ๋ชจ๋ธ(pre-train๋œ ๋ชจ๋ธ ํฌํ•จ)์„ ์ œ๊ณตํ•˜๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ž…๋‹ˆ๋‹ค. ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์—ฌ๋Ÿฌ ๊ธฐ๋Šฅ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ ˆ์ด์–ด๋„ ์˜คํ”ˆ ์†Œ์Šค๋กœ ์ œ๊ณต์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. (timm.models.layers)

 

๋‚ด๋ถ€์— ๊ตฌํ˜„๋˜์–ด์žˆ๋Š” DropPath ํด๋ž˜์Šค์—์„œ ๊ตฌํ˜„์ ์œผ๋กœ ๋ฐฐ์šด ๊ฒŒ ๋งŽ์•„์„œ ์ด์— ๋Œ€ํ•ด์„œ ์ •๋ฆฌ๋ฅผ ํ•ด๋ณด๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ์ „๋ถ€๋ฅผ ์ ๊ธฐ์—๋Š” ๋‹ค์†Œ ๊ธ€์ด ๊ธธ์–ด์ง€๊ธฐ ๋•Œ๋ฌธ์— ๋‚ด์šฉ๋“ค์€ ๋ชจ๋‘ ์ฃผ์„์œผ๋กœ ์ ์–ด๋‘์—ˆ์Šต๋‹ˆ๋‹ค.

 

https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py#L150

 


์œ„ ์˜คํ”ˆ์†Œ์Šค์˜ 150๋ฒˆ ์ค„๋ถ€ํ„ฐ ์ฃผ์„์„ ์ ์–ด๋‘์—ˆ์Šต๋‹ˆ๋‹ค.

import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    r"""
    This function is applied to the residual branch, i.e. to f(x) in f(x) + x.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob

    # (1) What the shape below means
    """
    Setting aside (x.shape[0],), which handles the batch dimension,
    we only need to look at (1,) * (x.ndim - 1):
    it gives the mask tensor the same number of dimensions as the sample itself.

    In other words, the shape becomes (Batch Size, 1, 1, ..., 1),
    which is chosen deliberately so that the mask can be applied to every
    element of the tensor. = Broadcasting Semantics

    -> This is what lets each sample be dropped independently.
    """

    shape = (x.shape[0],) + (1,) * (x.ndim - 1)

    # (2) Tensor.new_empty(size)
    """
    https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html

    Creates a tensor of the given size filled with uninitialized data.
    It gets the same torch.dtype and torch.device as the original tensor,
    so I presume the purpose of this method is to keep the mask in the
    same state (dtype, device) as the original tensor.

    new_empty is used because the values will be overwritten by bernoulli_()
    anyway, so initializing them to any particular value would be wasteful.
    """

    # (3) Tensor.bernoulli_()
    """
    https://pytorch.org/docs/stable/generated/torch.Tensor.bernoulli_.html

    Each element of the tensor is the outcome of a Bernoulli trial.

    Since the result is assigned to random_tensor anyway, did it really have
    to be an in-place operation? I suspect bernoulli_() was chosen over
    bernoulli() because the in-place version exploits the tensor's mutability
    and avoids allocating new memory, which is more efficient.
    """

    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)

    if keep_prob > 0.0 and scale_by_keep:
        """
        (My guess)
        I could not find the purpose of scale_by_keep in the paper, but
        looking at the math, the samples dropped according to keep_prob mean
        that the layers inside the model effectively see fewer training
        examples.

        To compensate, from the layer's point of view, the samples that are
        NOT dropped are up-weighted so that the effective amount of data
        is preserved.

        -> This is guaranteed by the reciprocal of keep_prob.

        By giving the surviving samples the extra weight of the dropped ones,
        the expected amount of training signal stays the same.
        """
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
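On the `scale_by_keep` comment above: dividing the surviving mask by `keep_prob` keeps the expected activation unchanged, since \(\mathbb{E}[m / p] = 1\) for \(m \sim \text{Bernoulli}(p)\). A quick torch-free Monte Carlo check of that identity (the variable names here are illustrative):

```python
import random

rng = random.Random(42)
keep_prob = 0.7
n = 100_000

# Simulate the per-sample gate: Bernoulli(keep_prob), rescaled by 1/keep_prob.
# If the expectation is preserved, the empirical mean should sit near 1.0.
total = 0.0
for _ in range(n):
    gate = 1.0 if rng.random() < keep_prob else 0.0
    total += gate / keep_prob
mean = total / n

print(round(mean, 2))  # ≈ 1.0
```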

๐Ÿ“„ ์•Œ๊ฒŒ๋œ ๊ฒƒ

๋‹ค์–‘ํ•œ ๋ฉ”์„œ๋“œ์˜ ์‚ฌ์šฉ์ž…๋‹ˆ๋‹ค. ํŠนํžˆ, ํ…์„œ ๋‚ด์žฅ ๋ฉ”์„œ๋“œ๋ฅผ ๋งŽ์ด ์‚ฌ์šฉํ–ˆ์œผ๋ฉฐ, ๋ฉ”์„œ๋“œ ์ด๋ฆ„ ๋์— '_'(์–ธ๋” ๋ฐ”) ์œ ๋ฌด์— ๋”ฐ๋ผ inplace operation์ธ์ง€ ์•„๋‹Œ์ง€๊ฐ€ ๋‹ฌ๋ผ์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ๋ฉ”๋ชจ๋ฆฌ์˜ ๊ด€์ (Mutable, Immutable)์—์„œ ์™œ inplace operation์„ ์‚ฌ์šฉํ–ˆ๋Š”์ง€ ์•Œ์•„๋ณด๋Š” ๊ฒƒ๋„ ์žฌ๋ฐŒ์—ˆ์Šต๋‹ˆ๋‹ค. div_(), bernoulli_()

 

๋˜ํ•œ, Tuple์˜ ์—ฐ์‚ฐ ๊ทœ์น™์„ ์‚ฌ์šฉํ•˜์—ฌ ์ƒˆ๋กœ์šด ํ…์„œ์˜ Shape์„ ๋งŒ๋“œ๋Š” ๊ณผ์ •์ด ๋ฐฐ์šธ ์ ์ด๋ผ๊ณ  ๋Š๊ผˆ์Šต๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์—์„œ ์ด์ „์— ๋ฐฐ์šด Broadcasting Semantics๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๊ฐ ์ƒ˜ํ”Œ์— DropPath๋ฅผ ์ ์šฉํ•˜๋Š” ๊ฒƒ๋„ ์ธ์ƒ์ ์ด์—ˆ์Šต๋‹ˆ๋‹ค.


๐Ÿ“‚ Reference

https://arxiv.org/abs/1603.09382

 

Deep Networks with Stochastic Depth

Very deep convolutional networks with hundreds of layers have led to significant reductions in error on competitive benchmarks. Although the unmatched expressiveness of the many layers can be highly desirable at test time, training very deep networks comes

arxiv.org

https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py#L150

 

pytorch-image-models/timm/layers/drop.py at main · huggingface/pytorch-image-models

PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNet-V3/V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeX...

github.com

https://pytorch.org/docs/stable/generated/torch.Tensor.new_tensor.html

 

torch.Tensor.new_tensor — PyTorch 2.3 documentation

Shortcuts

pytorch.org

https://pytorch.org/docs/stable/generated/torch.Tensor.bernoulli_.html

 

torch.Tensor.bernoulli_ — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 

 

 
