์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- DP
- ํ๋ก์ด๋ ์์ฌ
- ๋ค์ต์คํธ๋ผ
- pytorch
- ๋๋น ์ฐ์ ํ์
- 2023
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- tensorflow
- dropout
- BFS
- ์กฐํฉ๋ก
- NEXT
- ์๊ณ ๋ฆฌ์ฆ
- ๋ฐฑํธ๋ํน
- ์๋ฐ์คํฌ๋ฆฝํธ
- ํฌ๋ฃจ์ค์นผ
- ๋ฌธ์์ด
- ํ๊ณ ๋ก
- dfs
- ๋ถํ ์ ๋ณต
- ๊ฐ๋์ ๋ง๋ก
- back propagation
- c++
- ์ด๋ถ ํ์
- ์ฐ์ ์์ ํ
- lazy propagation
- ๊ฐ๋์_๋ง๋ก
- Overfitting
- object detection
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- Today
- Total
Doby's Lab
optimizer.step()์ ์ ๋ง ๊ฐ์ค์น๋ฅผ ๊ฑด๋ค๊น? (Call-by-Assignment) ๋ณธ๋ฌธ
optimizer.step()์ ์ ๋ง ๊ฐ์ค์น๋ฅผ ๊ฑด๋ค๊น? (Call-by-Assignment)
๋๋น(Doby) 2023. 11. 18. 02:42๐ค Question
PyTorch์์๋ ๋ชจ๋ธ์ ํ์ต์ํฌ ๋, (ํนํ Back Propagation์ด ์ผ์ด๋ ๋) ์๋์ ๊ฐ์ ๊ณผ์ ์ ๊ฑฐ์นฉ๋๋ค.
1) optimizer์ ์ข
๋ฅ์ learning rate์ ๋ฐ๋ฅธ optimizer
๋ฅผ ์ ์ธํ๋ฉด์ model์ parameter ๊ฐ์ ๋๊น๋๋ค.
2) .grad
์์ฑ ๊ฐ์ optim.zero_grad()
๋ฅผ ํตํด ์ด๊ธฐํ์ํต๋๋ค. (.grad
accumulate ๋ฐฉ์ง)
3) ๊ทธ ๋ค์, .backward()
๋ฅผ ํธ์ถ์์ผ์ ๊ฐ Tensor์ .grad
์์ฑ์ ๊ฐฑ์ ํฉ๋๋ค.
4) ๋ง์ง๋ง์ผ๋ก, optimizer.step()
์ ํตํด์ ๊ฐ Tensor์ ์
๋ฐ์ดํธ๋ฅผ ์งํํฉ๋๋ค.
์ด ๊ณผ์ ์์ ๊ถ๊ธํ๋ ๋ถ๋ถ์ '์ ๋ง ์ค์ model์ ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํ๋๊ฐ?'์์ต๋๋ค.
๋๋ฌด๋ ๋น์ฐํ๊ฒ๋ ์ ๋ฐ์ดํธ๋ฅผ ํ๊ฒ ์ง๋ง ์ ๋ฐ์ดํธ๊ฐ ๋๊ณ ์๋์ง ํ์ธํ ์ ์์์๋ฟ๋๋ฌ ์๋์ optimizer ์ ์ธ๊ณผ Tensor(๊ฐ์ค์น) ๊ฐฑ์ ์ฝ๋์์ ์์ ๊ฐ์ ์๋ฌธ์ ์ผ์ผ์ผฐ์ต๋๋ค.
# optimizer ์ ์ธ
optmizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# ๋ชจ๋ธ์ Tensor(๊ฐ์ค์น) ๊ฐฑ์
optimizer.step()
์ด๋ฌํ ์ง๋ฌธ์ ์ด์ ๋ก๋ model.parameters()
๋ฅผ optimizer์ ๋๊ธฐ๊ธฐ๋ง ํ์ ๋ฟ ์ง์ ์ ์ผ๋ก model.parameters()
๋ฅผ ๊ฑด๋๋ ๋ถ๋ถ์ ์์๊ธฐ ๋๋ฌธ์
๋๋ค.
๐ Answer
๐ Experiment
๋ชจ๋ธ์ ๊ฐ์ค์น๊ฐ ์ค์ ๋ก ๋ณํ๋์ง ์ง์ ํ์ธํด ๋ณด๊ธฐ ์ํด์ ๊ฐ๋จํ training์ ํ๋ ์ฝ๋๋ก ์คํ์ ํด๋ดค์ต๋๋ค.
๋์ ๋๋ ๊ฒฐ๊ณผ๋ฅผ ์ํด์ Learning Rate๋ฅผ 0.1๋ก ํฌ๊ฒ ์ก๊ณ ํ์ต์์ผฐ์ต๋๋ค.
(๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ flatten ์ํจ ๊ฐ๋จํ Linear Regression ๋ชจ๋ธ์ ๋๋ค.)
๋ํ, ์ถ๋ ฅํ๋ ๊ฐ์ค์น๋ Linear Layer์ Tensor ์ค ๋์ผํ ์์์ ๋ํด ์ถ๋ ฅํ๊ฒ ํ์ผ๋ฉฐ, Batch๋ฅผ 10๋ฒ ํ์ต์ํจ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋๋ก ํ์ต๋๋ค.
def train_loop(model, dataloader, loss_fn, optim):
size = len(dataloader.dataset)
print('[Weight Change Observation]')
for batch, (X, y) in enumerate(dataloader):
# Forward Propagation
pred = model(X)
loss = loss_fn(pred, y)
print('=========================================')
print(f'ORDER: {batch+1}')
print(f'[Before Backward]\n{model.linear2.weight[0,0]}')
# Backward Propagation
optim.zero_grad() # init grad
loss.backward()
optim.step()
print(f'[After Backward]\n{model.linear2.weight[0,0]}')
if batch == (10 - 1):
print('=========================================')
print('print 10 times')
break
์ด์ ๋ํ ๊ฒฐ๊ณผ, ์ถ๋ ฅ์ ๋น์ฐํ ๊ฐ์ค์น์ ๋ณํ๋ฅผ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
[Weight Change Observation]
=========================================
ORDER: 1
[Before Backward]
0.005343694239854813
[After Backward]
0.005184032954275608
=========================================
ORDER: 2
[Before Backward]
0.005184032954275608
[After Backward]
0.00509195402264595
=========================================
ORDER: 3
[Before Backward]
0.00509195402264595
[After Backward]
0.004864421673119068
=========================================
ORDER: 4
[Before Backward]
0.004864421673119068
[After Backward]
0.00480083841830492
=========================================
ORDER: 5
[Before Backward]
0.00480083841830492
[After Backward]
0.00480083841830492
=========================================
ORDER: 6
[Before Backward]
0.00480083841830492
[After Backward]
0.004487521015107632
=========================================
ORDER: 7
[Before Backward]
0.004487521015107632
[After Backward]
0.004345327150076628
=========================================
ORDER: 8
[Before Backward]
0.004345327150076628
[After Backward]
0.004016462713479996
=========================================
ORDER: 9
[Before Backward]
0.004016462713479996
[After Backward]
0.0035967917647212744
=========================================
ORDER: 10
[Before Backward]
0.0035967917647212744
[After Backward]
0.0033088999334722757
=========================================
print 10 times
๐ Call-by-Assignment
๊ฒฐ๊ณผ๋ ํ์ธํ์์ผ๋ ์ ์ด๋ฐ ์๋ฌธ์ ํ์์๊น๋ฅผ ๋๋ฌ์ด ๊ฐ๋ณด๋ Call-by-Value์ Call-by-Reference์ ๋ํ ์๊ฐ์ ๊ธฐ๋ฐ์ ๋๊ณ ๋์จ ์๋ฌธ์ด์์ต๋๋ค.
์ฆ, ์ด ์๋ฌธ์ optimizer.step()
์ ๋ํด์ Call-by-Value๋ผ๊ณ ์๊ฐํ์์ง๋ง, Call-by-Reference์๋ ๊ฒ์ผ๋ก ์ ๋ฆฌ๊ฐ ๋ ์ ์์ต๋๋ค.
ํ์ง๋ง, ์ง์ง ์ค์ํ ๊ฑด ํ์ด์ฌ์ ์ด๋ฐ ์ฒด๊ณ๋ก ๋์๊ฐ์ง ์์ต๋๋ค. ํ์ด์ฌ์ ํจ์ ํธ์ถ ๋ฐฉ์์ Call-by-Assignment๋ง ์กด์ฌํฉ๋๋ค.
Call-by-Assignment๋ ํ์ด์ฌ์ ํจ์ ํธ์ถ ๋ฐฉ์์ผ๋ก, ํ์ด์ฌ์ ๋ชจ๋ ๊ฒ์ ๊ฐ์ฒด์ด๋ฉฐ ์ด ๊ฐ์ฒด์ ์ข ๋ฅ๋ Mutable(๋ณํ ์ ์๋), Immutable(๋ณํ ์ ์๋) ์ด 2๊ฐ์ง๋ก๋ง ์ ์๋ฉ๋๋ค.
์ฆ, ํ์ด์ฌ์ ํจ์ ํธ์ถ์ ๋ํด์ Argument๊ฐ ๊ฐ๋ณ์ฑ ์ฌ๋ถ๋ ํจ์์ ๋ฌ๋ ค์๋ ๊ฒ์ด ์๋ Argument๋ก ๋๊ธฐ๋ ๊ฐ์ฒด์ ํ์ ์ ๋ฌ๋ ค์๋ ๊ฒ์ ๋๋ค.
Mutable์ Call-by-Reference์ ๋ฐ๋ฅด๋ฉฐ list, set, dict ๋ฑ์ ๊ฐ์ฒด๊ฐ ์๊ณ , Immutable์ Call-by-Value๋ฅผ ๋ฐ๋ฅด๋ฉฐ str, int, tuple ๋ฑ์ ๊ฐ์ฒด๊ฐ ์์ต๋๋ค.
โ Summary
์ ๋ฆฌ๋ฅผ ํด๋ณด๋ฉด, ์ด๋ฒ์๋ optimizer์ ๋ํ ์๋ฌธ์ ํตํด ์์ ์ ๊ณต๋ถํ๋ ํ์ด์ฌ์ Call-by-Assignment์ ๋ํด์ ๋ณต์ต์ ํ๊ฒ ๋์๊ณ , Pytorch์์ model์ Parameter๋ Mutableํ ๊ฐ์ฒด๋ผ๋ ๊ฒ์ ์๊ฒ ๋์์ต๋๋ค.