Doby's Lab

Code about AI/PyTorch

x.clone()์€ ์ •๋ง Residual Connection์„ ํ• ๊นŒ? (Memory ๊ณต์œ , Immutability)

Doby · 2024. 4. 27. 19:04

๐Ÿค” Problem

์ด๋ฒˆ์— ResNet์„ PyTorch๋กœ ์ง์ ‘ ๊ตฌํ˜„ํ•ด ๋ณด๋ฉด์„œ ์•ฝ๊ฐ„์˜ ์˜๊ตฌ์‹ฌ(?)์ด ๋“ค์—ˆ๋˜ ๋ถ€๋ถ„์ด ์žˆ์Šต๋‹ˆ๋‹ค. Residual Connection์„ ๊ตฌํ˜„ํ•  ๋•Œ ํฌ๊ฒŒ 2๊ฐ€์ง€ ๋ฐฉ๋ฒ•์œผ๋กœ ๊ตฌํ˜„์„ ํ•˜๋Š”๋ฐ, '๋‘ ์ฝ”๋“œ ๋ชจ๋‘ Residual Connection์„ ์ˆ˜ํ–‰ํ•˜๋Š”๊ฐ€?'๊ฐ€ ์˜๋ฌธ์ด์ž ์ด๋ฒˆ ํฌ์ŠคํŒ…์—์„œ ์•Œ์•„๋‚ด๊ณ  ์‹ถ์€ ๋ฌธ์ œ์ ์ž…๋‹ˆ๋‹ค.

 

+ ์ฝ”๋“œ์— ๋Œ€ํ•ด์„œ๋งŒ ๋‹ค๋ฃฐ ๊ฒƒ์ด๋‹ˆ Residual Connection์— ๋Œ€ํ•œ ๊ฐœ๋…์˜ ์–ธ๊ธ‰์€ ๋”ฐ๋กœ ์—†์Šต๋‹ˆ๋‹ค.

 

The first version is the ResNet source code from the torchvision library.

ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” identity = x์™€ ๊ฐ™์€ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ณต์‚ฌ๋ฅผ ํ•ฉ๋‹ˆ๋‹ค.

( https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, line 143 )

def forward(self, x):
	... # a series of operations (1)
	identity = x
	... # a series of operations (2)
	x += identity
	return x

 

The second version is someone's (third-party) ResNet implementation.

ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” identity = x.clone()๊ณผ ๊ฐ™์€ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ณต์‚ฌ๋ฅผ ํ•ฉ๋‹ˆ๋‹ค.

( https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py, line 57 )

def forward(self, x):
	... # a series of operations (1)
	identity = x.clone()
	... # a series of operations (2)
	x += identity
	return x

 

Implementing the residual connection the first way is the more common pattern, but the second style also turns up fairly often. What made me suspicious is that, unlike the second version, the first version shares memory with the original variable (the tensor). 'If they share memory, isn't it a problem that the sharing continues while operations are applied to x? Wouldn't that make it not a residual connection, but a sum of two copies of the same computed result?' That is the question this post sets out to answer.


๐Ÿ˜€ What is x.clone()?

๊ทธ๋Ÿฌ๋ฉด ๋‘ ์ฝ”๋“œ์— ๋Œ€ํ•œ ์ฐจ์ด์ ์„ ๋ณด๊ธฐ ์œ„ํ•ด์„œ x.clone()์ด ๋ฌด์—‡์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์ฝ”๋“œ์ธ์ง€ ์•Œ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.

https://pytorch.org/docs/stable/generated/torch.clone.html

 

torch.clone — PyTorch 2.3 documentation

 

x.clone()์ด๋ž€ x๋ผ๋Š” ํ…์„œ์— ๋Œ€ํ•ด์„œ ๋ณต์‚ฌ๊ฐ€ ์ด๋ฃจ์–ด์ง€๋Š”๋ฐ ์ด๋•Œ, ๋ฉ”๋ชจ๋ฆฌ๋Š” ์ƒˆ๋กœ์šด ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํ• ๋‹นํ•˜๋ฉฐ, ๊ธฐ์กด์˜ .grad ์†์„ฑ์€ ๊ฐ€์ ธ๊ฐ€์ง€ ์•Š๊ณ , .grad_fn์€ CloneBackward๋ผ๋Š” ์—ฐ์‚ฐ ํ•จ์ˆ˜๋ฅผ ๋“ฑ๋กํ•ด ์ค๋‹ˆ๋‹ค.

์ฆ‰, x.clone()์€ ๊ธฐ์กด ํ…์„œ์˜ ๊ฐ’๋“ค(๋ฐ์ดํ„ฐ)์— ๋Œ€ํ•ด์„œ ๊ณต์œ ๊ฐ€ ์•„๋‹Œ ๋ณต์‚ฌ(์ƒˆ ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น)๋ฅผ ๋ฐ”๋กœ ํ•ด๋ฒ„๋ฆฌ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. (.grad๋Š” ์ œ์™ธ๋ฅผ ํ•˜๊ณ , ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด์„œ๋งŒ ๋ง์ž…๋‹ˆ๋‹ค.)

 

์ด ์ฐจ์ด์ ์„ ์ง์ ‘ ๋ณด๊ธฐ ์œ„ํ•ด์„œ ์•„๋ž˜์™€ ๊ฐ™์€ ์ฝ”๋“œ๋ฅผ ๋งŒ๋“ค์–ด์คฌ์Šต๋‹ˆ๋‹ค.

import torch
import warnings


def clone_test():
    warnings.filterwarnings(action='ignore')  # silence the warning for reading a non-leaf tensor's .grad
    a = torch.tensor([2, 3, 4], dtype=torch.float32, requires_grad=True)
    a.grad = torch.tensor([1, 2, 3], dtype=torch.float32)
    b = a.clone()  # fresh memory, .grad not copied, grad_fn=CloneBackward0
    c = a          # plain assignment: same object, same memory

    li = ['a', 'b', 'c']
    for tensor_ in li:
        print('=' * 60)
        print(f'Tensor: {tensor_}')
        print(f'Tensor Memory Address: {str(hex(eval(tensor_).data_ptr())).upper()}')
        print(f'Tensor Value: {eval(tensor_).data}')
        print(f'Tensor Requires Grad: {eval(tensor_).requires_grad}')
        print(f'Tensor Gradient: {eval(tensor_).grad}')
        print(f'Tensor Backward Function: {eval(tensor_).grad_fn}')

    warnings.filterwarnings(action='default')

๊ทธ์— ๋Œ€ํ•œ ๊ฒฐ๊ณผ๋Š” ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

============================================================
Tensor: a
Tensor Memory Address: 0X143E7370D40
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: tensor([1., 2., 3.])
Tensor Backward Function: None
============================================================
Tensor: b
Tensor Memory Address: 0X143E7370E00
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: None
Tensor Backward Function: <CloneBackward0 object at 0x00000143E9DFDC30>
============================================================
Tensor: c
Tensor Memory Address: 0X143E7370D40
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: tensor([1., 2., 3.])
Tensor Backward Function: None

x.clone()์ด ์–ด๋–ค ๊ฒƒ์„ ์ˆ˜ํ–‰ํ•˜๋Š”์ง€ ์ด๋ฅผ ํ†ตํ•ด์„œ ์‚ดํŽด๋ณผ ์ˆ˜ ์žˆ์—ˆ๊ณ , identity = x๋ผ๋Š” ๋ฐฉ์‹๊ณผ ํ•ต์‹ฌ์ ์ธ ์ฐจ์ด์ ์€ ๋ฐ์ดํ„ฐ๋ฅผ ๋ณต์‚ฌํ•จ์— ์žˆ์–ด์„œ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณต์œ ํ•˜๋Š๋ƒ ์•ˆ ํ•˜๋Š๋ƒ(์ƒˆ๋กœ์šด ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น)๊ฐ€ ๊ฐ€์žฅ ํฐ ์ฐจ์ด์ ์ž…๋‹ˆ๋‹ค. 


๐Ÿค” identity = x์— ๋Œ€ํ•œ ์˜์‹ฌ

๊ทธ๋Ÿฌ๋ฉด ์ €์˜ ์˜๊ตฌ์‹ฌ์˜ ๋ฐฉํ–ฅ์€ ์ฒซ ๋ฒˆ์งธ ์ฝ”๋“œ๋กœ ํ–ฅํ•ฉ๋‹ˆ๋‹ค. x.clone()์€ ๋ฉ”๋ชจ๋ฆฌ ๊ณต์œ ๋ฅผ ํ•˜๊ณ  ์žˆ์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์›๋ณธ x์— ๋Œ€ํ•ด ์–ด๋– ํ•œ ์—ฐ์‚ฐ์ด ์ด๋ฃจ์–ด์ง€๋”๋ผ๋„ ๊ด€๊ณ„๊ฐ€ ์—†๋‹ค๋Š” ๊ฒƒ์€ ํ™•์‹คํ•ด์กŒ์œผ๋‹ˆ ๋ง์ž…๋‹ˆ๋‹ค.

 

That said, the answer is already half given, because the identity = x style is what the torchvision library itself uses. Code that official is unlikely to have the problem I imagined, but I was still curious, so let's dig in.

 

๋‘ ์ฝ”๋“œ์˜ ์ฐจ์ด์ ์„ ์ง์ ‘ ๋ณด๊ธฐ ์œ„ํ•ด์„œ ๋ชจ๋ธ์„ ํ•˜๋‚˜ ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค. (๋ชจ๋ธ์ด๋ผ ํ•˜๊ธฐ์—” ๊ทธ๋ ‡์ง€๋งŒ)

# (1) MAKE RESIDUAL MODEL
input_data = torch.tensor([1.], dtype=torch.float32)
weight1 = torch.tensor([2.], dtype=torch.float32, requires_grad=True)
weight2 = torch.tensor([3.], dtype=torch.float32, requires_grad=True)
target_data = torch.tensor([8.5], dtype=torch.float32)

feature = weight1 * input_data  # 2.0

# make identity for residual connection
identity = None  # 2.0
print('HOW TO RESIDUAL CONNECTION: ', end='')
if use_clone:  # use_clone is the argument of residual_test() in the full listing below
    print('identity = feature.clone()')
    identity = feature.clone()
else:
    print('identity = feature')
    identity = feature

feature2 = weight2 * feature  # 6.0

output_data = identity + feature2  # 8.0 = 2.0 + 6.0

loss = (output_data - target_data) ** 2  # 0.25 = (8.0 - 8.5) ** 2

The code may look a bit convoluted, but the computation graph is simple: feature = weight1 · input_data, the identity branch splits off from feature, feature2 = weight2 · feature, and output_data = identity + feature2.

Residual Connection์„ ์œ„ํ•œ ๋„ˆ๋ฌด๋‚˜ ์˜๋„์ ์ธ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์œ„ ์ฝ”๋“œ์— ๋”ฐ๋ผ ์ž…๋ ฅ ๊ฐ’์ด 1์ด๋ฉด, ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณ„์† ๊ณต์œ ํ•˜์ง€ ์•Š๋Š”๋‹ค๊ณ  ํ–ˆ์„ ๋•Œ, ์ถœ๋ ฅ ๊ฐ’์ด 8์ด ๋‚˜์™€์•ผ ํ•˜๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด๊ฑธ ๋‘ ๊ฐ€์ง€ ๋ฐฉ์‹์— ๋”ฐ๋ผ ์ˆ˜ํ–‰ํ–ˆ์„ ๋•Œ ์ฐจ์ด๊ฐ€ ์žˆ์—ˆ์„๊นŒ์š”? ์ถœ๋ ฅ์„ ํ•ด๋ณด์•˜์Šต๋‹ˆ๋‹ค. (์ถœ๋ ฅ์— ๊ด€ํ•œ ์ฝ”๋“œ๋Š” ๋งจ ์•„๋ž˜์— ํ•œ๊บผ๋ฒˆ์— ์ฒจ๋ถ€ํ•˜์˜€์Šต๋‹ˆ๋‹ค.)

 

(1) Using identity = x

HOW TO RESIDUAL CONNECTION: identity = feature

=========================[Memory Address]=========================
Identity's Address(data_ptr): 0X24E65892840, Address(id): 0X24E683404A0
Feature's  Address(data_ptr): 0X24E65892840, Address(id): 0X24E683404A0
Feature2's Address(data_ptr): 0X24E65891F40, Address(id): 0X24E68358C20

=========================[Gradient Function]=========================
input_data's grad_fn:     None
weight1's grad_fn:        None
feature's grad_fn:        <MulBackward0 object at 0x0000024E6831DCF0>
identity's grad_fn:       <MulBackward0 object at 0x0000024E6831DF00>
weight2's grad_fn:        None
feature2's grad_fn:       <MulBackward0 object at 0x0000024E6831DCF0>
output_data's grad_fn:    <AddBackward0 object at 0x0000024E6831DF00>
loss's grad_fn:           <PowBackward0 object at 0x0000024E6831F5B0>

=========================[Value Information]=========================
Output: 8.00
Target: 8.50
Loss: 0.25
Weight1's Gradient to update -4.00

(2) Using identity = x.clone()

HOW TO RESIDUAL CONNECTION: identity = feature.clone()

=========================[Memory Address]=========================
Identity's Address(data_ptr): 0X23F5C9B5880, Address(id): 0X23F5F420D60
Feature's  Address(data_ptr): 0X23F5C9B6000, Address(id): 0X23F5F4085E0
Feature2's Address(data_ptr): 0X23F5C9B5580, Address(id): 0X23F5F420DB0

=========================[Gradient Function]=========================
input_data's grad_fn:     None
weight1's grad_fn:        None
feature's grad_fn:        <MulBackward0 object at 0x0000023F5F3EDCF0>
identity's grad_fn:       <CloneBackward0 object at 0x0000023F5F3EDF00>
weight2's grad_fn:        None
feature2's grad_fn:       <MulBackward0 object at 0x0000023F5F3EDCF0>
output_data's grad_fn:    <AddBackward0 object at 0x0000023F5F3EDF00>
loss's grad_fn:           <PowBackward0 object at 0x0000023F5F3EF5B0>

=========================[Value Information]=========================
Output: 8.00
Target: 8.50
Loss: 0.25
Weight1's Gradient to update -4.00

The results are identical! Both bring back the same output. But that only means forward propagation is satisfied; backpropagation is not yet certain.

 

๊ทธ๋ž˜์„œ ์ €๋Š” weight1์ด Back Propagation์„ ํ–ˆ์„ ๋•Œ, ์—…๋ฐ์ดํŠธ๋ฅผ ํ•ด์•ผ ํ•˜๋Š” ๊ฐ’์„ ์ถœ๋ ฅํ•˜๋„๋ก ํ•ด์ฃผ์—ˆ์Šต๋‹ˆ๋‹ค. ๊ทธ ๊ฒฐ๊ณผ๋„ ๊ฐ™์Šต๋‹ˆ๋‹ค! ์ฆ‰, Forward Propagation, Backward Propagation์˜ ๊ฒฐ๊ณผ๊ฐ€ ๊ฐ™๊ธฐ ๋•Œ๋ฌธ์— ์ฒซ ๋ฒˆ์งธ ์ฝ”๋“œ์˜ ๋ฐฉ์‹์ด๋“ , ๋‘ ๋ฒˆ์งธ ์ฝ”๋“œ์˜ ๋ฐฉ์‹์ด๋“  ๋ฌธ์ œ๊ฐ€ ์—†๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

 

๊ทธ๋Ÿฌ๋ฉด ๋„๋Œ€์ฒด identity = x์˜ ๋ฐฉ์‹์€ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณต์œ ํ•œ๋‹ค๊ณ  ํ–ˆ๋Š”๋ฐ ์–ด๋–ป๊ฒŒ Residual Conneciton์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ–ˆ์„๊นŒ์š”? ๋‹ค์‹œ ๋งํ•ด ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ณต์œ ํ•˜๊ณ  ์žˆ๋Š”๋ฐ ๊ฐ™์€ ๋ฉ”๋ชจ๋ฆฌ์˜ ๋‘ ๋ณ€์ˆ˜์— ๋Œ€ํ•ด์„œ ์–ด๋–ป๊ฒŒ ๋…๋ฆฝ์ ์ธ ๊ฐ’์„ ๊ฐ–๋„๋ก ํ–ˆ์„๊นŒ์š”.


๐Ÿ˜€ Tensor์˜ Mutable, or Immutable

์ด๊ฒƒ์ด ๊ฐ€๋Šฅํ•œ ์ด์œ ๋ฅผ ์ฐพ๊ธฐ ์œ„ํ•ด์„œ๋Š” ๋Œ๊ณ  ๋Œ์•„์•ผ ํ–ˆ์Šต๋‹ˆ๋‹ค. Garbage Collection์ด๋ผ๋Š” ๊ฐœ๋…๊นŒ์ง€ ๋ณผ ๋ป”ํ•˜๋‹ค๊ฐ€ ๊ฒจ์šฐ ๊ทธ๊นŒ์ง€๋Š” ๊ฐ€์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ์–ด์จŒ๋“  ์ฐพ์€ ๊ทผ๊ฑฐ๋Š” ๊ฒฐ๊ตญ์— ๋˜ ๊ธฐ์ดˆ์— ๋ชจ๋“  ๊ฒƒ์ด ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.

 

์˜ˆ์ „ ํŒŒ์ด์ฌ ๊ด€๋ จ, ํŒŒ์ดํ† ์น˜ ๊ด€๋ จ ํฌ์ŠคํŒ…์—์„œ ํŒŒ์ด์ฌ์˜ ๋ชจ๋“  ๊ฒƒ์€ ๊ฐ์ฒด์ด๋ฉฐ, ์ด๋Š” Mutableํ•œ๊ฐ€, Immutableํ•œ๊ฐ€๋ฅผ ๋‹ค๋ฃฌ ์ ์ด ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.

 

To recap briefly: a mutable object can be modified as-is, with no need to allocate new memory, while an immutable object cannot be modified in its existing memory, so modifying it requires allocating new memory. (I've only ever treated this concept in passing on this blog, so I'll cover it properly some day.)
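
A quick refresher in plain Python, with id() standing in for the memory address (a sketch of mine, not from the original post):

xs = [1, 2, 3]
print(hex(id(xs)))
xs[0] = 99           # lists are mutable: modified in place
print(hex(id(xs)))   # same address as before

s = 'abc'
print(hex(id(s)))
s = s + 'd'          # strings are immutable: a new object is built
print(hex(id(s)))    # a different address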

 

Anyway, if you've followed this far, you can probably see where this is going: 'Ah, tensors are immutable, so even without x.clone() new memory is allocated anyway, which is why the residual connection is not a problem!'

 

A genuinely good thought, but it's actually wrong, because tensors are mutable objects.

 

With one caveat: only under certain conditions. For an in-place operation, like modifying an element's value in a list, no new memory needs to be allocated. That is what mutable means.

 

ํ•˜์ง€๋งŒ, ๋‹ค๋ฅธ ํ…์„œ์™€์˜ ๋ง์…ˆ, ๋บ„์…ˆ, ๊ณฑ์…ˆ ๋“ฑ๊ณผ ๊ฐ™์€ Arithmetic Operation์— ๋Œ€ํ•ด์„œ๋Š” Immutability๊ฐ€ ์‚ฌ์šฉ๋˜์–ด ์ƒˆ๋กœ์šด ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํ• ๋‹นํ•ด์•ผ๋งŒ ํ•ฉ๋‹ˆ๋‹ค.

 

And since a residual connection uses arithmetic operations 99.999999% of the time, even when we write identity = x, the memory address of x changes as operations are applied to it, so nothing goes wrong.
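
Putting the pieces together for the residual case, a sketch of mine (x * 3.0 stands in for the conv/bn stack):

import torch

x = torch.tensor([2.0])
identity = x                                 # no clone: shares memory with x
print(identity.data_ptr() == x.data_ptr())   # True

x = x * 3.0                                  # arithmetic: x is REBOUND to a new tensor
print(identity.data_ptr() == x.data_ptr())   # False: x now lives elsewhere

x += identity                                # in-place add, but onto the new tensor
print(x, identity)                           # tensor([8.]) tensor([2.]) -- identity intact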


โœ… Result

์ž, ์—ฌ๊ธฐ๊นŒ์ง€๊ฐ€ ์ œ๊ฐ€ ์ƒ๊ฐํ•œ ๋ฌธ์ œ์ (๊ถ๊ธˆ์ )์— ๋Œ€ํ•œ ํ•ด๊ฒฐ์ฑ…(ํ•ด์†Œ์ฑ…)์ด์—ˆ์Šต๋‹ˆ๋‹ค.

 

๋‹ค์‹œ ํ•œ๋ฒˆ ์š”์•ฝ์„ ํ•ด๋ณด์ž๋ฉด, Residual Connection์„ ๊ตฌํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ 2๊ฐ€์ง€๊ฐ€ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.

 

1. identity = x.clone()์€ ๋ณ„๋„์˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ฐ€์ง€๊ธฐ ๋•Œ๋ฌธ์— ๋ฌธ์ œ๊ฐ€ ๋˜์ง€ ์•Š๋Š”๋‹ค๊ณ  ํŒ๋‹จํ•˜์—ฌ, identity = x์— ๋Œ€ํ•ด์„œ๋งŒ ์•Œ์•„๋ณด๊ธฐ ์‹œ์ž‘

2. A tensor gets separate memory as operations are performed on it. Yet a tensor is a mutable object.

3. ์ด ์ ์—์„œ Tensor๊ฐ€ In-place Operation์— ๋Œ€ํ•ด์„œ๋Š” Mutability๊ฐ€ ๋ณด์žฅ๋˜์ง€๋งŒ, Arithmetic Operation์— ๋Œ€ํ•ด์„œ๋Š” Immutability๊ฐ€ ๋ณด์žฅ๋œ๋‹ค. ์ฆ‰, ์ƒˆ๋กœ์šด ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ํ• ๋‹น๋œ๋‹ค.

4. So the choice between identity = x and identity = x.clone() makes no difference to the result.

 

ํ•˜์ง€๋งŒ, x.clone()์€ .grad ์†์„ฑ์— ๋Œ€ํ•ด์„œ๋Š” ๋ณต์‚ฌ๋ฅผ ํ•˜์ง€ ์•Š๋Š”๋‹ค๊ณ  ํ–ˆ์œผ๋ฏ€๋กœ x.clone()์ด ๋” ํšจ์œจ์ ์ผ ์ˆ˜๋„ ์žˆ๊ฒ ๋‹ค๋Š” ์ถ”์ธก์ด ๋ฉ๋‹ˆ๋‹ค.

import torch
from memory_profiler import profile


@profile
def copy_tensor():
    a = torch.tensor([1], dtype=torch.float32)
    a.grad = torch.tensor([2], dtype=torch.float32)
    b = a            # case 1: plain assignment
    # b = a.clone()  # case 2: swap in this line to profile the clone version


if __name__ == '__main__':
    copy_tensor()
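
For reference, with @profile imported from memory_profiler like this, simply running the script prints a line-by-line memory report; alternatively, you can drop the import and run it as python -m memory_profiler script.py.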

 

(memory_profiler output for identity = x vs. identity = x.clone())

Profiling with the memory_profiler library showed that, because new memory is allocated for the tensor's data, not using x.clone() is the more efficient choice. After all, new memory will be allocated anyway as x goes through its operations.

Also, identity and x are ultimately input tensors rather than weights, so they don't need gradient values, which makes me feel the need for x.clone() even less.
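
A small sketch of mine on that last point: after backward, only leaf tensors such as weights get .grad populated; an intermediate like identity does not.

import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([1.0])

feature = w * x          # non-leaf: produced by an operation
identity = feature       # the residual branch
(identity + feature).backward()

print(w.grad)            # tensor([2.]) -- the weight (a leaf) gets a gradient
print(identity.grad)     # None -- non-leaf tensors don't retain .grad (PyTorch warns here)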

 

๊ฒฐ๋ก ์ ์œผ๋กœ, 'Residual Connection์„ ํ•  ๋•Œ๋Š” x.clone()์„ ์“ธ ํ•„์š”๊ฐ€ ์—†๋‹ค. identity = x๋ฅผ ์“ฐ๋ฉด ๋œ๋‹ค.'๋กœ ๊ธ€์„ ๋งˆ๋ฌด๋ฆฌํ•  ์ˆ˜ ์žˆ์„ ๊ฑฐ ๊ฐ™์Šต๋‹ˆ๋‹ค.


๐Ÿ“„ Code for this post

import torch
import numpy as np
import argparse
import warnings


def residual_test(use_clone: bool = True):
    # (1) MAKE RESIDUAL MODEL
    input_data = torch.tensor([1.], dtype=torch.float32)
    weight1 = torch.tensor([2.], dtype=torch.float32, requires_grad=True)
    weight2 = torch.tensor([3.], dtype=torch.float32, requires_grad=True)
    target_data = torch.tensor([8.5], dtype=torch.float32)

    feature = weight1 * input_data  # 2.0

    # make identity for residual connection
    identity = None  # 2.0
    print('HOW TO RESIDUAL CONNECTION: ', end='')
    if use_clone:
        print('identity = feature.clone()')
        identity = feature.clone()
    else:
        print('identity = feature')
        identity = feature

    feature2 = weight2 * feature  # 6.0

    output_data = identity + feature2  # 8.0 = 2.0 + 6.0

    loss = (output_data - target_data) ** 2  # 0.25 = (8.0 - 8.5) ** 2

    loss.backward()

    # (2) PRINT INFORMATION
    _ANNOTATION_MARK = 25

    print('\n' + '=' * _ANNOTATION_MARK +
          '[Memory Address]' + '=' * _ANNOTATION_MARK)

    print(
        f'Identity\'s Address(data_ptr): {str(hex(identity.data_ptr())).upper()}, Address(id): {str(hex(id(identity))).upper()}')
    print(
        f'Feature\'s  Address(data_ptr): {str(hex(feature.data_ptr())).upper()}, Address(id): {str(hex(id(feature))).upper()}')
    print(
        f'Feature2\'s Address(data_ptr): {str(hex(feature2.data_ptr())).upper()}, Address(id): {str(hex(id(feature2))).upper()}')

    print('\n' + '=' * _ANNOTATION_MARK +
          '[Gradient Function]' + '=' * _ANNOTATION_MARK)
    components_li = ['input_data',
                     'weight1',
                     'feature',
                     'identity',
                     'weight2',
                     'feature2',
                     'output_data',
                     'loss']
    for component in components_li:
        temp_str = f'{component}\'s grad_fn:'
        print(f'{temp_str:25s} {eval(component).grad_fn}')
    # print(hex(weight1.data_ptr()))

    print('\n' + '=' * _ANNOTATION_MARK +
          '[Value Information]' + '=' * _ANNOTATION_MARK)
    print(f'Output: {output_data.item():.2f}')
    print(f'Target: {target_data.item():.2f}')
    print(f'Loss: {loss.item():.2f}')
    print(f'Weight1\'s Gradient to update {weight1.grad.item():.2f}')


def memory_test():
    a = np.array([1, 2, 3])
    print(f'a\'s address: {hex(id(a))}, a: {a}')
    a[2] = 2
    print(f'a\'s address: {hex(id(a))}, a: {a}')
    b = np.array([2, 4, 6])
    a = a + b
    print(f'a\'s address: {hex(id(a))}, a: {a}')


def clone_test():
    warnings.filterwarnings(action='ignore')
    a = torch.tensor([2, 3, 4], dtype=torch.float32, requires_grad=True)
    a.grad = torch.tensor([1, 2, 3], dtype=torch.float32)
    b = a.clone()
    c = a

    li = ['a', 'b', 'c']
    for tensor_ in li:
        print('=' * 60)
        print(f'Tensor: {tensor_}')
        print(
            f'Tensor Memory Address: {str(hex(eval(tensor_).data_ptr())).upper()}')
        print(f'Tensor Value: {eval(tensor_).data}')
        print(f'Tensor Requires Grad: {eval(tensor_).requires_grad}')
        print(f'Tensor Gradient: {eval(tensor_).grad}')
        print(f'Tensor Backward Function: {eval(tensor_).grad_fn}')

    warnings.filterwarnings(action='default')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--use-clone', action='store_true')
    args = parser.parse_args()

    residual_test(use_clone=args.use_clone)
    # memory_test()
    # clone_test()