๊ด€๋ฆฌ ๋ฉ”๋‰ด

Doby's Lab

optimizer.step()์€ ์ •๋ง ๊ฐ€์ค‘์น˜๋ฅผ ๊ฑด๋“ค๊นŒ? (Call-by-Assignment) ๋ณธ๋ฌธ

Code about AI/PyTorch

optimizer.step()์€ ์ •๋ง ๊ฐ€์ค‘์น˜๋ฅผ ๊ฑด๋“ค๊นŒ? (Call-by-Assignment)

๋„๋น„(Doby) 2023. 11. 18. 02:42

๐Ÿค” Question

PyTorch์—์„œ๋Š” ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ฌ ๋•Œ, (ํŠนํžˆ Back Propagation์ด ์ผ์–ด๋‚  ๋•Œ) ์•„๋ž˜์™€ ๊ฐ™์€ ๊ณผ์ •์„ ๊ฑฐ์นฉ๋‹ˆ๋‹ค.

 

1) optimizer์˜ ์ข…๋ฅ˜์™€ learning rate์— ๋”ฐ๋ฅธ optimizer๋ฅผ ์„ ์–ธํ•˜๋ฉด์„œ model์˜ parameter ๊ฐ’์„ ๋„˜๊น๋‹ˆ๋‹ค.

2) Reset the .grad attributes with optim.zero_grad() (to prevent .grad accumulation).

3) ๊ทธ ๋‹ค์Œ, .backward()๋ฅผ ํ˜ธ์ถœ์‹œ์ผœ์„œ ๊ฐ Tensor์˜ .grad ์†์„ฑ์„ ๊ฐฑ์‹ ํ•ฉ๋‹ˆ๋‹ค. 

4) ๋งˆ์ง€๋ง‰์œผ๋กœ, optimizer.step()์„ ํ†ตํ•ด์„œ ๊ฐ Tensor์˜ ์—…๋ฐ์ดํŠธ๋ฅผ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

 

์ด ๊ณผ์ •์—์„œ ๊ถ๊ธˆํ–ˆ๋˜ ๋ถ€๋ถ„์€ '์ •๋ง ์‹ค์ œ model์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์—…๋ฐ์ดํŠธํ•˜๋Š”๊ฐ€?'์˜€์Šต๋‹ˆ๋‹ค.

 

๋„ˆ๋ฌด๋‚˜ ๋‹น์—ฐํ•˜๊ฒŒ๋„ ์—…๋ฐ์ดํŠธ๋ฅผ ํ•˜๊ฒ ์ง€๋งŒ ์—…๋ฐ์ดํŠธ๊ฐ€ ๋˜๊ณ  ์žˆ๋Š”์ง€ ํ™•์ธํ•  ์ˆ˜ ์—†์—ˆ์„๋ฟ๋”๋Ÿฌ ์•„๋ž˜์˜ optimizer ์„ ์–ธ๊ณผ Tensor(๊ฐ€์ค‘์น˜) ๊ฐฑ์‹  ์ฝ”๋“œ์—์„œ ์œ„์™€ ๊ฐ™์€ ์˜๋ฌธ์„ ์ผ์œผ์ผฐ์Šต๋‹ˆ๋‹ค.

# optimizer ์„ ์–ธ
optmizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# ๋ชจ๋ธ์˜ Tensor(๊ฐ€์ค‘์น˜) ๊ฐฑ์‹ 
optimizer.step()

์ด๋Ÿฌํ•œ ์งˆ๋ฌธ์˜ ์ด์œ ๋กœ๋Š” model.parameters()๋ฅผ optimizer์— ๋„˜๊ธฐ๊ธฐ๋งŒ ํ–ˆ์„ ๋ฟ ์ง์ ‘์ ์œผ๋กœ model.parameters()๋ฅผ ๊ฑด๋“œ๋Š” ๋ถ€๋ถ„์€ ์—†์—ˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.


๐Ÿ˜€ Answer

๐Ÿ“„ Experiment

๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๊ฐ€ ์‹ค์ œ๋กœ ๋ณ€ํ•˜๋Š”์ง€ ์ง์ ‘ ํ™•์ธํ•ด ๋ณด๊ธฐ ์œ„ํ•ด์„œ ๊ฐ„๋‹จํžˆ training์„ ํ•˜๋Š” ์ฝ”๋“œ๋กœ ์‹คํ—˜์„ ํ•ด๋ดค์Šต๋‹ˆ๋‹ค.

๋ˆˆ์— ๋„๋Š” ๊ฒฐ๊ณผ๋ฅผ ์œ„ํ•ด์„œ Learning Rate๋ฅผ 0.1๋กœ ํฌ๊ฒŒ ์žก๊ณ  ํ•™์Šต์‹œ์ผฐ์Šต๋‹ˆ๋‹ค.

(๋ชจ๋ธ์€ ์ด๋ฏธ์ง€๋ฅผ flatten ์‹œํ‚จ ๊ฐ„๋‹จํ•œ Linear Regression ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.)

๋˜ํ•œ, ์ถœ๋ ฅํ•˜๋Š” ๊ฐ€์ค‘์น˜๋Š” Linear Layer์˜ Tensor ์ค‘ ๋™์ผํ•œ ์›์†Œ์— ๋Œ€ํ•ด ์ถœ๋ ฅํ•˜๊ฒŒ ํ–ˆ์œผ๋ฉฐ, Batch๋ฅผ 10๋ฒˆ ํ•™์Šต์‹œํ‚จ ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•˜๋„๋ก ํ–ˆ์Šต๋‹ˆ๋‹ค.

def train_loop(model, dataloader, loss_fn, optim):
    print('[Weight Change Observation]')
    for batch, (X, y) in enumerate(dataloader):
        # Forward propagation
        pred = model(X)
        loss = loss_fn(pred, y)
        
        print('=========================================')
        print(f'ORDER: {batch+1}')
        print(f'[Before Backward]\n{model.linear2.weight[0,0].item()}')
        
        # Backward propagation
        optim.zero_grad() # reset gradients
        loss.backward()
        optim.step()
        
        print(f'[After Backward]\n{model.linear2.weight[0,0].item()}')
        
        if batch == (10 - 1):
            print('=========================================')
            print('print 10 times')
            break

์ด์— ๋Œ€ํ•œ ๊ฒฐ๊ณผ, ์ถœ๋ ฅ์€ ๋‹น์—ฐํžˆ ๊ฐ€์ค‘์น˜์˜ ๋ณ€ํ™”๋ฅผ ๋ณด์—ฌ์ฃผ์—ˆ์Šต๋‹ˆ๋‹ค.

[Weight Change Observation]
=========================================
ORDER: 1
[Before Backward]
0.005343694239854813
[After Backward]
0.005184032954275608
=========================================
ORDER: 2
[Before Backward]
0.005184032954275608
[After Backward]
0.00509195402264595
=========================================
ORDER: 3
[Before Backward]
0.00509195402264595
[After Backward]
0.004864421673119068
=========================================
ORDER: 4
[Before Backward]
0.004864421673119068
[After Backward]
0.00480083841830492
=========================================
ORDER: 5
[Before Backward]
0.00480083841830492
[After Backward]
0.00480083841830492
=========================================
ORDER: 6
[Before Backward]
0.00480083841830492
[After Backward]
0.004487521015107632
=========================================
ORDER: 7
[Before Backward]
0.004487521015107632
[After Backward]
0.004345327150076628
=========================================
ORDER: 8
[Before Backward]
0.004345327150076628
[After Backward]
0.004016462713479996
=========================================
ORDER: 9
[Before Backward]
0.004016462713479996
[After Backward]
0.0035967917647212744
=========================================
ORDER: 10
[Before Backward]
0.0035967917647212744
[After Backward]
0.0033088999334722757
=========================================
print 10 times

๐Ÿ“„ Call-by-Assignment

๊ฒฐ๊ณผ๋Š” ํ™•์ธํ•˜์˜€์œผ๋‹ˆ ์™œ ์ด๋Ÿฐ ์˜๋ฌธ์„ ํ’ˆ์—ˆ์„๊นŒ๋ฅผ ๋”๋“ฌ์–ด ๊ฐ€๋ณด๋‹ˆ Call-by-Value์™€ Call-by-Reference์— ๋Œ€ํ•œ ์ƒ๊ฐ์— ๊ธฐ๋ฐ˜์„ ๋‘๊ณ  ๋‚˜์˜จ ์˜๋ฌธ์ด์—ˆ์Šต๋‹ˆ๋‹ค.

 

์ฆ‰, ์ด ์˜๋ฌธ์€ optimizer.step()์— ๋Œ€ํ•ด์„œ Call-by-Value๋ผ๊ณ  ์ƒ๊ฐํ–ˆ์—ˆ์ง€๋งŒ, Call-by-Reference์˜€๋˜ ๊ฒƒ์œผ๋กœ ์ •๋ฆฌ๊ฐ€ ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

 

ํ•˜์ง€๋งŒ, ์ง„์งœ ์ค‘์š”ํ•œ ๊ฑด ํŒŒ์ด์ฌ์€ ์ด๋Ÿฐ ์ฒด๊ณ„๋กœ ๋Œ์•„๊ฐ€์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํŒŒ์ด์ฌ์˜ ํ•จ์ˆ˜ ํ˜ธ์ถœ ๋ฐฉ์‹์€ Call-by-Assignment๋งŒ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค.

 

Call-by-Assignment๋ž€ ํŒŒ์ด์ฌ์˜ ํ•จ์ˆ˜ ํ˜ธ์ถœ ๋ฐฉ์‹์œผ๋กœ, ํŒŒ์ด์ฌ์˜ ๋ชจ๋“  ๊ฒƒ์€ ๊ฐ์ฒด์ด๋ฉฐ ์ด ๊ฐ์ฒด์˜ ์ข…๋ฅ˜๋Š” Mutable(๋ณ€ํ•  ์ˆ˜ ์žˆ๋Š”), Immutable(๋ณ€ํ•  ์ˆ˜ ์—†๋Š”) ์ด 2๊ฐ€์ง€๋กœ๋งŒ ์ •์˜๋ฉ๋‹ˆ๋‹ค.

 

์ฆ‰, ํŒŒ์ด์ฌ์˜ ํ•จ์ˆ˜ ํ˜ธ์ถœ์— ๋Œ€ํ•ด์„œ Argument๊ฐ€ ๊ฐ€๋ณ€์„ฑ ์—ฌ๋ถ€๋Š” ํ•จ์ˆ˜์— ๋‹ฌ๋ ค์žˆ๋Š” ๊ฒƒ์ด ์•„๋‹Œ Argument๋กœ ๋„˜๊ธฐ๋Š” ๊ฐ์ฒด์˜ ํƒ€์ž…์— ๋‹ฌ๋ ค์žˆ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

 

Mutable์€ Call-by-Reference์— ๋”ฐ๋ฅด๋ฉฐ list, set, dict ๋“ฑ์˜ ๊ฐ์ฒด๊ฐ€ ์žˆ๊ณ , Immutable์€ Call-by-Value๋ฅผ ๋”ฐ๋ฅด๋ฉฐ str, int, tuple ๋“ฑ์˜ ๊ฐ์ฒด๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.


โœ… Summary

์ •๋ฆฌ๋ฅผ ํ•ด๋ณด๋ฉด, ์ด๋ฒˆ์—๋Š” optimizer์— ๋Œ€ํ•œ ์˜๋ฌธ์„ ํ†ตํ•ด ์˜ˆ์ „์— ๊ณต๋ถ€ํ–ˆ๋˜ ํŒŒ์ด์ฌ์˜ Call-by-Assignment์— ๋Œ€ํ•ด์„œ ๋ณต์Šต์„ ํ•˜๊ฒŒ ๋˜์—ˆ๊ณ , Pytorch์—์„œ model์˜ Parameter๋Š” Mutableํ•œ ๊ฐ์ฒด๋ผ๋Š” ๊ฒƒ์„ ์•Œ๊ฒŒ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
