์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- ๋ค์ต์คํธ๋ผ
- DP
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- ๋ถํ ์ ๋ณต
- back propagation
- ํ๋ก์ด๋ ์์ฌ
- ๋ฌธ์์ด
- NEXT
- ์ฐ์ ์์ ํ
- ๊ฐ๋์_๋ง๋ก
- ๋ฐฑํธ๋ํน
- ํฌ๋ฃจ์ค์นผ
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- c++
- BFS
- ๊ฐ๋์ ๋ง๋ก
- 2023
- object detection
- lazy propagation
- ์ด๋ถ ํ์
- Overfitting
- tensorflow
- ํ๊ณ ๋ก
- pytorch
- dfs
- ์๋ฐ์คํฌ๋ฆฝํธ
- ์กฐํฉ๋ก
- dropout
- ์๊ณ ๋ฆฌ์ฆ
- ๋๋น ์ฐ์ ํ์
- Today
- Total
Doby's Lab
x.clone()์ ์ ๋ง Residual Connection์ ํ ๊น? (Memory ๊ณต์ , Immutability) ๋ณธ๋ฌธ
x.clone()์ ์ ๋ง Residual Connection์ ํ ๊น? (Memory ๊ณต์ , Immutability)
๋๋น(Doby) 2024. 4. 27. 19:04๐ค Problem
์ด๋ฒ์ ResNet์ PyTorch๋ก ์ง์ ๊ตฌํํด ๋ณด๋ฉด์ ์ฝ๊ฐ์ ์๊ตฌ์ฌ(?)์ด ๋ค์๋ ๋ถ๋ถ์ด ์์ต๋๋ค. Residual Connection์ ๊ตฌํํ ๋ ํฌ๊ฒ 2๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ๊ตฌํ์ ํ๋๋ฐ, '๋ ์ฝ๋ ๋ชจ๋ Residual Connection์ ์ํํ๋๊ฐ?'๊ฐ ์๋ฌธ์ด์ ์ด๋ฒ ํฌ์คํ ์์ ์์๋ด๊ณ ์ถ์ ๋ฌธ์ ์ ์ ๋๋ค.
+ ์ฝ๋์ ๋ํด์๋ง ๋ค๋ฃฐ ๊ฒ์ด๋ Residual Connection์ ๋ํ ๊ฐ๋ ์ ์ธ๊ธ์ ๋ฐ๋ก ์์ต๋๋ค.
์ฒซ ๋ฒ์งธ ์ฝ๋๋ torchvision ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ด์ ResNet์ ๊ตฌํํด ๋ ์์ค์ฝ๋์ ๋๋ค.
ํด๋น ์ฝ๋์์๋ identity = x
์ ๊ฐ์ ๋ฐฉ๋ฒ์ผ๋ก ๋ณต์ฌ๋ฅผ ํฉ๋๋ค.
( https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 143 line)
def forward(self, x):
... # ์ผ๋ จ์ ๊ณผ์ (1)
identity = x
... # ์ผ๋ จ์ ๊ณผ์ (2)
x += identity
return x
๋ ๋ฒ์งธ ์ฝ๋๋ ์ต๋ช ์ ๋๊ตฐ๊ฐ(?)๊ฐ ResNet์ ๊ตฌํํ ์์ค์ฝ๋์ ๋๋ค.
ํด๋น ์ฝ๋์์๋ identity = x.clone()
๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ์ผ๋ก ๋ณต์ฌ๋ฅผ ํฉ๋๋ค.
( https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py 57 line)
def forward(self, x):
... # ์ผ๋ จ์ ๊ณผ์ (1)
identity = x.clone()
... # ์ผ๋ จ์ ๊ณผ์ (2)
x += identity
return x
์ฒซ ๋ฒ์งธ ์ฝ๋์ ๊ฐ์ ๋ฐฉ์์ผ๋ก Residual Connection์ ๊ตฌํํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ด์ง๋ง, ๋ ๋ฒ์งธ ์ฝ๋์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ๊ตฌํ๋ ๊ฒ๋ ์ข ์ข ์ฐพ์๋ณผ ์ ์์์ต๋๋ค. ํ์ง๋ง, ์กฐ๊ธ ์๊ตฌ์ฌ์ ํ์๋ ๊ณ๊ธฐ๊ฐ ์์์ต๋๋ค. ์ฒซ ๋ฒ์งธ ์ฝ๋๋ ๋ ๋ฒ์งธ ์ฝ๋์ ๋ค๋ฅธ ์ ์ด ์๋ณธ ๋ณ์(ํ ์)์ ๋ํด์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๋ค๋ ๊ฒ์ ๋๋ค. '๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๊ณ ์๋ค๋ฉด, x์ ๋ํ ์ฐ์ฐ๋ค์ด ์ผ์ด๋ ๋๋ ๊ณต์ ๋ฅผ ํ๋ ๊ฒ์ ๋ฌธ์ ๊ฐ ๋์ง ์๋๊ฐ? ๊ทธ๋ฌ๋ฉด Residual Conneciton์ด ์๋๋ผ ๋๊ฐ์ ์ฐ์ฐ์ ํ ๊ฒฐ๊ณผ 2๊ฐ๋ฅผ ํฉ์น๋ ๊ผด์ด ์๋๊ฐ?'๋ผ๋ ๋ฌธ์ ์ ์ ์ ์ํ๋ฉฐ ์ด๋ฒ ํฌ์คํ ์ ์์ฑํฉ๋๋ค.
๐ What is x.clone()
?
๊ทธ๋ฌ๋ฉด ๋ ์ฝ๋์ ๋ํ ์ฐจ์ด์ ์ ๋ณด๊ธฐ ์ํด์ x.clone()
์ด ๋ฌด์์ ์ํํ๋ ์ฝ๋์ธ์ง ์์์ผ ํฉ๋๋ค.
https://pytorch.org/docs/stable/generated/torch.clone.html
x.clone()
์ด๋ x
๋ผ๋ ํ
์์ ๋ํด์ ๋ณต์ฌ๊ฐ ์ด๋ฃจ์ด์ง๋๋ฐ ์ด๋, ๋ฉ๋ชจ๋ฆฌ๋ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ๋ฉฐ, ๊ธฐ์กด์ .grad
์์ฑ์ ๊ฐ์ ธ๊ฐ์ง ์๊ณ , .grad_fn
์ CloneBackward
๋ผ๋ ์ฐ์ฐ ํจ์๋ฅผ ๋ฑ๋กํด ์ค๋๋ค.
์ฆ, x.clone()
์ ๊ธฐ์กด ํ
์์ ๊ฐ๋ค(๋ฐ์ดํฐ)์ ๋ํด์ ๊ณต์ ๊ฐ ์๋ ๋ณต์ฌ(์ ๋ฉ๋ชจ๋ฆฌ ํ ๋น)๋ฅผ ๋ฐ๋ก ํด๋ฒ๋ฆฌ๋ ๊ฒ์
๋๋ค. (.grad
๋ ์ ์ธ๋ฅผ ํ๊ณ , ๋ฐ์ดํฐ์ ๋ํด์๋ง ๋ง์
๋๋ค.)
์ด ์ฐจ์ด์ ์ ์ง์ ๋ณด๊ธฐ ์ํด์ ์๋์ ๊ฐ์ ์ฝ๋๋ฅผ ๋ง๋ค์ด์คฌ์ต๋๋ค.
def clone_test():
warnings.filterwarnings(action='ignore')
a = torch.tensor([2, 3, 4], dtype=torch.float32, requires_grad=True)
a.grad = torch.tensor([1, 2, 3], dtype=torch.float32)
b = a.clone()
c = a
li = ['a', 'b', 'c']
for tensor_ in li:
print('=' * 60)
print(f'Tensor: {tensor_}')
print(f'Tensor Memory Address: {str(hex(eval(tensor_).data_ptr())).upper()}')
print(f'Tensor Value: {eval(tensor_).data}')
print(f'Tensor Requires Grad: {eval(tensor_).requires_grad}')
print(f'Tensor Gradient: {eval(tensor_).grad}')
print(f'Tensor Backward Function: {eval(tensor_).grad_fn}')
warnings.filterwarnings(action='default')
๊ทธ์ ๋ํ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ต๋๋ค.
============================================================
Tensor: a
Tensor Memory Address: 0X143E7370D40
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: tensor([1., 2., 3.])
Tensor Backward Function: None
============================================================
Tensor: b
Tensor Memory Address: 0X143E7370E00
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: None
Tensor Backward Function: <CloneBackward0 object at 0x00000143E9DFDC30>
============================================================
Tensor: c
Tensor Memory Address: 0X143E7370D40
Tensor Value: tensor([2., 3., 4.])
Tensor Requires Grad: True
Tensor Gradient: tensor([1., 2., 3.])
Tensor Backward Function: None
x.clone()
์ด ์ด๋ค ๊ฒ์ ์ํํ๋์ง ์ด๋ฅผ ํตํด์ ์ดํด๋ณผ ์ ์์๊ณ , identity = x
๋ผ๋ ๋ฐฉ์๊ณผ ํต์ฌ์ ์ธ ์ฐจ์ด์ ์ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ฌํจ์ ์์ด์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๋๋ ์ ํ๋๋(์๋ก์ด ๋ฉ๋ชจ๋ฆฌ ํ ๋น)๊ฐ ๊ฐ์ฅ ํฐ ์ฐจ์ด์ ์
๋๋ค.
๐ค identity = x
์ ๋ํ ์์ฌ
๊ทธ๋ฌ๋ฉด ์ ์ ์๊ตฌ์ฌ์ ๋ฐฉํฅ์ ์ฒซ ๋ฒ์งธ ์ฝ๋๋ก ํฅํฉ๋๋ค. x.clone()
์ ๋ฉ๋ชจ๋ฆฌ ๊ณต์ ๋ฅผ ํ๊ณ ์์ง ์๊ธฐ ๋๋ฌธ์ ์๋ณธ x
์ ๋ํด ์ด๋ ํ ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ผ๋ ๊ด๊ณ๊ฐ ์๋ค๋ ๊ฒ์ ํ์คํด์ก์ผ๋ ๋ง์
๋๋ค.
ํ์ง๋ง, ์ฌ๊ธฐ์ ์ด์ง ์ด๋ฏธ ์ ๋ต์ ๋์์์ต๋๋ค. ์๋ํ๋ฉด, identity = x
๋ฅผ ์ฐ๋ ๋ฐฉ์์ ์ฝ๋๋ torchvision
๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ฝ๋์
๋๋ค. ๋๋ฌด ๊ณต์์ ์ธ ์ฝ๋๋ผ ์ ๊ฐ ์๊ฐํ ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ์ง๋ ์๊ฒ ์ง๋ง, ๊ทธ๋๋ ๊ถ๊ธํ๋ ์์๋ด
์๋ค.
๋ ์ฝ๋์ ์ฐจ์ด์ ์ ์ง์ ๋ณด๊ธฐ ์ํด์ ๋ชจ๋ธ์ ํ๋ ๋ง๋ค์์ต๋๋ค. (๋ชจ๋ธ์ด๋ผ ํ๊ธฐ์ ๊ทธ๋ ์ง๋ง)
# (1) MAKE RESIDUAL MODEL
input_data = torch.tensor([1.], dtype=torch.float32)
weight1 = torch.tensor([2.], dtype=torch.float32, requires_grad=True)
weight2 = torch.tensor([3.], dtype=torch.float32, requires_grad=True)
target_data = torch.tensor([8.5], dtype=torch.float32)
feature = weight1 * input_data # 2.0
# make identity for residual connection
identity = None # 2.0
print('HOW TO RESIDUAL CONNECTION: ', end='')
if use_clone:
print('identity = feature.clone()')
identity = feature.clone()
else:
print('identity = feature')
identity = feature
feature2 = weight2 * feature # 6.0
output_data = identity + feature2 # 8.0 = 2.0 + 6.0
loss = (output_data - target_data) ** 2 # 0.25 = (8.0 - 8.5) ** 2
์ฝ๋๊ฐ ๋ค์ ๋ณต์กํด ๋ณด์ด์ง๋ง, ๊ทธ๋ฆผ์ผ๋ก ๋ํ๋ด๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
Residual Connection์ ์ํ ๋๋ฌด๋ ์๋์ ์ธ ๋ชจ๋ธ์ ๋๋ค. ์ ์ฝ๋์ ๋ฐ๋ผ ์ ๋ ฅ ๊ฐ์ด 1์ด๋ฉด, ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณ์ ๊ณต์ ํ์ง ์๋๋ค๊ณ ํ์ ๋, ์ถ๋ ฅ ๊ฐ์ด 8์ด ๋์์ผ ํ๋ ๋ชจ๋ธ์ ๋๋ค. ์ด๊ฑธ ๋ ๊ฐ์ง ๋ฐฉ์์ ๋ฐ๋ผ ์ํํ์ ๋ ์ฐจ์ด๊ฐ ์์์๊น์? ์ถ๋ ฅ์ ํด๋ณด์์ต๋๋ค. (์ถ๋ ฅ์ ๊ดํ ์ฝ๋๋ ๋งจ ์๋์ ํ๊บผ๋ฒ์ ์ฒจ๋ถํ์์ต๋๋ค.)
(1) identity = x
๋ฅผ ์ฌ์ฉํ์ ๋
HOW TO RESIDUAL CONNECTION: identity = feature
=========================[Memory Address]=========================
Identity's Address(data_ptr): 0X24E65892840, Address(id): 0X24E683404A0
Feature's Address(data_ptr): 0X24E65892840, Address(id): 0X24E683404A0
Feature2's Address(data_ptr): 0X24E65891F40, Address(id): 0X24E68358C20
=========================[Gradient Function]=========================
input_data's grad_fn: None
weight1's grad_fn: None
feature's grad_fn: <MulBackward0 object at 0x0000024E6831DCF0>
identity's grad_fn: <MulBackward0 object at 0x0000024E6831DF00>
weight2's grad_fn: None
feature2's grad_fn: <MulBackward0 object at 0x0000024E6831DCF0>
output_data's grad_fn: <AddBackward0 object at 0x0000024E6831DF00>
loss's grad_fn: <PowBackward0 object at 0x0000024E6831F5B0>
=========================[Value Information]=========================
Output: 8.00
Target: 8.50
Loss: 0.25
Weight1's Gradient to update -4.00
(2) identity = x.clone()
๋ฅผ ์ฌ์ฉํ์ ๋
HOW TO RESIDUAL CONNECTION: identity = feature.clone()
=========================[Memory Address]=========================
Identity's Address(data_ptr): 0X23F5C9B5880, Address(id): 0X23F5F420D60
Feature's Address(data_ptr): 0X23F5C9B6000, Address(id): 0X23F5F4085E0
Feature2's Address(data_ptr): 0X23F5C9B5580, Address(id): 0X23F5F420DB0
=========================[Gradient Function]=========================
input_data's grad_fn: None
weight1's grad_fn: None
feature's grad_fn: <MulBackward0 object at 0x0000023F5F3EDCF0>
identity's grad_fn: <CloneBackward0 object at 0x0000023F5F3EDF00>
weight2's grad_fn: None
feature2's grad_fn: <MulBackward0 object at 0x0000023F5F3EDCF0>
output_data's grad_fn: <AddBackward0 object at 0x0000023F5F3EDF00>
loss's grad_fn: <PowBackward0 object at 0x0000023F5F3EF5B0>
=========================[Value Information]=========================
Output: 8.00
Target: 8.50
Loss: 0.25
Weight1's Gradient to update -4.00
๊ฒฐ๊ณผ๊ฐ ๊ฐ์ต๋๋ค! ๋๊ฐ์ Output์ ๊ฐ์ง๊ณ ์จ๋ค๋ ๊ฒ๋๋ค. ํ์ง๋ง, ์ด๊ฑด Forward Propagation์ ๋ง์กฑํ๋ค๋ ์๋ฆฌ์ผ ๋ฟ, Back Propagation์ ํ์คํ์ง ์์ต๋๋ค.
๊ทธ๋์ ์ ๋ weight1
์ด Back Propagation์ ํ์ ๋, ์
๋ฐ์ดํธ๋ฅผ ํด์ผ ํ๋ ๊ฐ์ ์ถ๋ ฅํ๋๋ก ํด์ฃผ์์ต๋๋ค. ๊ทธ ๊ฒฐ๊ณผ๋ ๊ฐ์ต๋๋ค! ์ฆ, Forward Propagation, Backward Propagation์ ๊ฒฐ๊ณผ๊ฐ ๊ฐ๊ธฐ ๋๋ฌธ์ ์ฒซ ๋ฒ์งธ ์ฝ๋์ ๋ฐฉ์์ด๋ , ๋ ๋ฒ์งธ ์ฝ๋์ ๋ฐฉ์์ด๋ ๋ฌธ์ ๊ฐ ์๋ค๋ ๊ฒ์
๋๋ค.
๊ทธ๋ฌ๋ฉด ๋๋์ฒด identity = x
์ ๋ฐฉ์์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๋ค๊ณ ํ๋๋ฐ ์ด๋ป๊ฒ Residual Conneciton์ ๊ฐ๋ฅํ๊ฒ ํ์๊น์? ๋ค์ ๋งํด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณต์ ํ๊ณ ์๋๋ฐ ๊ฐ์ ๋ฉ๋ชจ๋ฆฌ์ ๋ ๋ณ์์ ๋ํด์ ์ด๋ป๊ฒ ๋
๋ฆฝ์ ์ธ ๊ฐ์ ๊ฐ๋๋ก ํ์๊น์.
๐ Tensor์ Mutable, or Immutable
์ด๊ฒ์ด ๊ฐ๋ฅํ ์ด์ ๋ฅผ ์ฐพ๊ธฐ ์ํด์๋ ๋๊ณ ๋์์ผ ํ์ต๋๋ค. Garbage Collection์ด๋ผ๋ ๊ฐ๋ ๊น์ง ๋ณผ ๋ปํ๋ค๊ฐ ๊ฒจ์ฐ ๊ทธ๊น์ง๋ ๊ฐ์ง ์์์ต๋๋ค. ์ด์จ๋ ์ฐพ์ ๊ทผ๊ฑฐ๋ ๊ฒฐ๊ตญ์ ๋ ๊ธฐ์ด์ ๋ชจ๋ ๊ฒ์ด ์์์ต๋๋ค.
์์ ํ์ด์ฌ ๊ด๋ จ, ํ์ดํ ์น ๊ด๋ จ ํฌ์คํ ์์ ํ์ด์ฌ์ ๋ชจ๋ ๊ฒ์ ๊ฐ์ฒด์ด๋ฉฐ, ์ด๋ Mutableํ๊ฐ, Immutableํ๊ฐ๋ฅผ ๋ค๋ฃฌ ์ ์ด ์์์ต๋๋ค.
๊ฐ๋จํ๊ฒ ๋ค์ ์ธ๊ธํ๋ฉด Mutable์ ์์ ์ฌํญ์ ๋ํด์ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ ํ์ ์์ด ๊ทธ๋๋ก ์์ ํ ์ ์๋ ๊ฐ์ฒด์ด๊ณ , Immutable์ ์์ ์ ํ๋๋ฐ ๊ธฐ์กด ๋ฉ๋ชจ๋ฆฌ์์ ์์ ํ ์ ์์ผ๋ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํด์ผ ํ๋ ๊ฐ์ฒด์ ๋๋ค. (ํญ์ ์ด ๊ฐ๋ ์ ๋ํด์ ๋ธ๋ก๊ทธ์์๋ ๋ถ์์ ์ผ๋ก๋ง ๋ค๋ฃจ์ด ์์ด์ ๋์ค์ ํ ๋ฒ ์ ๋๋ก ๋ค๋ฃจ์ด ๋ณด๊ฒ ์ต๋๋ค.)
์ด์จ๋ , ์ฌ๊ธฐ๊น์ง ์ค๋ช ํ๋ค๋ฉด ๊ฐ์ ์ค์ จ์ ๊ฒ๋๋ค. '์, Tensor๊ฐ Immutable ํ๊ธฐ ๋๋ฌธ์ ๊ตณ์ด x.clone()์ ํด์ฃผ์ง ์์๋ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ๋ ๊ฒ์ด๋ผ์ Residual Connection์ ํ ๋, ๋ฌธ์ ๊ฐ ๋์ง ์๋ ๊ฑฐ๊ตฌ๋!'
๋ค, ์ ๋ง ํ๋ฅญํ ์๊ฐ์ ๋๋ค๋ง ์ฌ์ค ์ด๊ฑด ํ๋ ธ์ต๋๋ค. ์๋ํ๋ฉด, Tensor๋ Mutable ํ ๊ฐ์ฒด์ด๊ฑฐ๋ ์.
๋จ, ํน์ ์กฐ๊ฑด์ ๋ํด์๋ง ๋ง์ด์์. ๋ฆฌ์คํธ์์ ์์ ๊ฐ ์์ ๊ฐ์ In-place Operation์ธ ๊ฒฝ์ฐ์๋ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ ํ์๊ฐ ์์ต๋๋ค. Mutableํ๋ค๋ ์๋ฏธ์ ๋๋ค.
ํ์ง๋ง, ๋ค๋ฅธ ํ ์์์ ๋ง์ , ๋บ์ , ๊ณฑ์ ๋ฑ๊ณผ ๊ฐ์ Arithmetic Operation์ ๋ํด์๋ Immutability๊ฐ ์ฌ์ฉ๋์ด ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํด์ผ๋ง ํฉ๋๋ค.
๊ทธ๋์, Residual Connection์์๋ Artihmetic Operation์ด 99.999999%๋ก ์ฐ์ด๊ธฐ ๋๋ฌธ์ ์ฐ๋ฆฌ๊ฐ identity = x
์ ๊ฐ์ ์ฝ๋๋ฅผ ์ฌ์ฉํ๋๋ผ๋ x๊ฐ ์ฐ์ฐ์ ์ํํจ์ ๋ฐ๋ผ ๋ฉ๋ชจ๋ฆฌ ์ฃผ์๊ฐ ๋ฌ๋ผ์ง๊ธฐ ๋๋ฌธ์ ๋ฌธ์ ๊ฐ ๋์ง ์๋ ๊ฒ์ด์์ต๋๋ค.
โ Result
์, ์ฌ๊ธฐ๊น์ง๊ฐ ์ ๊ฐ ์๊ฐํ ๋ฌธ์ ์ (๊ถ๊ธ์ )์ ๋ํ ํด๊ฒฐ์ฑ (ํด์์ฑ )์ด์์ต๋๋ค.
๋ค์ ํ๋ฒ ์์ฝ์ ํด๋ณด์๋ฉด, Residual Connection์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ 2๊ฐ์ง๊ฐ ์์์ต๋๋ค.
1. identity = x.clone()
์ ๋ณ๋์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ฐ์ง๊ธฐ ๋๋ฌธ์ ๋ฌธ์ ๊ฐ ๋์ง ์๋๋ค๊ณ ํ๋จํ์ฌ, identity = x
์ ๋ํด์๋ง ์์๋ณด๊ธฐ ์์
2. Tensor๊ฐ ์ฐ์ฐ์ ์ํํจ์ ๋ฐ๋ผ ๋ณ๋์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ฐ์ง๋ค. ํ์ง๋ง, Tensor๋ Mutable ํ ๊ฐ์ฒด์ด๋ค.
3. ์ด ์ ์์ Tensor๊ฐ In-place Operation์ ๋ํด์๋ Mutability๊ฐ ๋ณด์ฅ๋์ง๋ง, Arithmetic Operation์ ๋ํด์๋ Immutability๊ฐ ๋ณด์ฅ๋๋ค. ์ฆ, ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ ๋น๋๋ค.
4. ๊ทธ๋์ identity = x
or identity = x.clone()
์ ๋ํ ์ ํ์ ์์ด์ ๊ฒฐ๊ณผ์ ์ฐจ์ด๋ ์๋ค.
ํ์ง๋ง, x.clone()
์ .grad
์์ฑ์ ๋ํด์๋ ๋ณต์ฌ๋ฅผ ํ์ง ์๋๋ค๊ณ ํ์ผ๋ฏ๋ก x.clone()
์ด ๋ ํจ์จ์ ์ผ ์๋ ์๊ฒ ๋ค๋ ์ถ์ธก์ด ๋ฉ๋๋ค.
import torch
from memory_profiler import profile
@profile
def copy_tensor():
a = torch.tensor([1], dtype=torch.float32)
a.grad = torch.tensor([2], dtype=torch.float32)
b = a
# b = a.clone()
if __name__ == '__main__':
copy_tensor()
๊ทธ๋์ memory_profiler
๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํตํด์ ํ๋กํ์ผ๋ง์ ํ ๊ฒฐ๊ณผ, ํ
์์ ๋ฐ์ดํฐ์ ๋ํด์ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ ๋น๋๊ธฐ ๋๋ฌธ์ x.clone()
์ ์ฌ์ฉํ์ง ์๋ ๊ฒ์ด ๋ ํจ์จ์ ์
๋๋ค. ์ด์ฐจํผ, x
๋ฅผ ์ฐ์ฐํจ์ ๋ฐ๋ผ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํด์ผ ํ๊ธฐ ๋๋ฌธ์
๋๋ค.
๋ํ, identity
๋ x
๋ ๊ฒฐ๊ตญ input tensor๋ก weight๊ฐ ์๋๊ธฐ ๋๋ฌธ์ gradient ๊ฐ์ ํ์๊ฐ ์์ด์ x.clone()
์ ํ์์ฑ์ ๋ ๋ชป ๋๋ผ๊ฒ ๋๋ ๊ฒ์ด ์๋๊ฐ ์ถ์ต๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, 'Residual Connection์ ํ ๋๋ x.clone()
์ ์ธ ํ์๊ฐ ์๋ค. identity = x
๋ฅผ ์ฐ๋ฉด ๋๋ค.'๋ก ๊ธ์ ๋ง๋ฌด๋ฆฌํ ์ ์์ ๊ฑฐ ๊ฐ์ต๋๋ค.
๐ Code for this post
import torch
import numpy as np
import argparse
import warnings
def residual_test(use_clone: bool = True):
# (1) MAKE RESIDUAL MODEL
input_data = torch.tensor([1.], dtype=torch.float32)
weight1 = torch.tensor([2.], dtype=torch.float32, requires_grad=True)
weight2 = torch.tensor([3.], dtype=torch.float32, requires_grad=True)
target_data = torch.tensor([8.5], dtype=torch.float32)
feature = weight1 * input_data # 2.0
# make identity for residual connection
identity = None # 2.0
print('HOW TO RESIDUAL CONNECTION: ', end='')
if use_clone:
print('identity = feature.clone()')
identity = feature.clone()
else:
print('identity = feature')
identity = feature
feature2 = weight2 * feature # 6.0
output_data = identity + feature2 # 8.0 = 2.0 + 6.0
loss = (output_data - target_data) ** 2 # 0.25 = (8.0 - 8.5) ** 2
loss.backward()
# (2) PRINT INFORMATION
_ANNOTATION_MARK = 25
print('\n' + '=' * _ANNOTATION_MARK +
'[Memory Address]' + '=' * _ANNOTATION_MARK)
print(
f'Identity\'s Address(data_ptr): {str(hex(identity.data_ptr())).upper()}, Address(id): {str(hex(id(identity))).upper()}')
print(
f'Feature\'s Address(data_ptr): {str(hex(feature.data_ptr())).upper()}, Address(id): {str(hex(id(feature))).upper()}')
print(
f'Feature2\'s Address(data_ptr): {str(hex(feature2.data_ptr())).upper()}, Address(id): {str(hex(id(feature2))).upper()}')
print('\n' + '=' * _ANNOTATION_MARK +
'[Gradient Function]' + '=' * _ANNOTATION_MARK)
components_li = ['input_data',
'weight1',
'feature',
'identity',
'weight2',
'feature2',
'output_data',
'loss']
for component in components_li:
temp_str = f'{component}\'s grad_fn:'
print(f'{temp_str:25s} {eval(component).grad_fn}')
# print(hex(weight1.data_ptr()))
print('\n' + '=' * _ANNOTATION_MARK +
'[Value Information]' + '=' * _ANNOTATION_MARK)
print(f'Output: {output_data.item():.2f}')
print(f'Target: {target_data.item():.2f}')
print(f'Loss: {loss.item():.2f}')
print(f'Weight1\'s Gradient to update {weight1.grad.item():.2f}')
def memory_test():
a = np.array([1, 2, 3])
print(f'a\'s address: {hex(id(a))}, a: {a}')
a[2] = 2
print(f'a\'s address: {hex(id(a))}, a: {a}')
b = np.array([2, 4, 6])
a = a + b
print(f'a\'s address: {hex(id(a))}, a: {a}')
def clone_test():
warnings.filterwarnings(action='ignore')
a = torch.tensor([2, 3, 4], dtype=torch.float32, requires_grad=True)
a.grad = torch.tensor([1, 2, 3], dtype=torch.float32)
b = a.clone()
c = a
li = ['a', 'b', 'c']
for tensor_ in li:
print('=' * 60)
print(f'Tensor: {tensor_}')
print(
f'Tensor Memory Address: {str(hex(eval(tensor_).data_ptr())).upper()}')
print(f'Tensor Value: {eval(tensor_).data}')
print(f'Tensor Requires Grad: {eval(tensor_).requires_grad}')
print(f'Tensor Gradient: {eval(tensor_).grad}')
print(f'Tensor Backward Function: {eval(tensor_).grad_fn}')
warnings.filterwarnings(action='default')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--use-clone', action='store_true')
args = parser.parse_args()
residual_test(use_clone=args.use_clone)
# memory_test()
# clone_test()