๊ด€๋ฆฌ ๋ฉ”๋‰ด

Doby's Lab

torch.where(), ๊ทผ๋ฐ ์ด์ œ loss๋ฅผ ๋งŒ๋“ค ๋•Œ ๋งŽ์ด ๊ณ๋“ค์ธ ๋ณธ๋ฌธ

Code about AI/PyTorch

torch.where(), ๊ทผ๋ฐ ์ด์ œ loss๋ฅผ ๋งŒ๋“ค ๋•Œ ๋งŽ์ด ๊ณ๋“ค์ธ

๋„๋น„(Doby) 2024. 7. 23. 09:14

๐Ÿค” Problem

PyTorch์—์„œ ์ž์ฃผ ์“ฐ๋Š” ๋ฉ”์„œ๋“œ๋“ค์€ ์–ด๋Š ์ •๋„ ์ฒด๋“ํ•˜๋ฉด์„œ ๋น ๋ฅด๊ฒŒ ํ•ธ๋“ค๋งํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์€ ๊ฑฐ ๊ฐ™์Šต๋‹ˆ๋‹ค. ์ด๋ฒˆ ๊ธ€์€ ๊ทธ๋Ÿฌํ•œ ๋ฉ”์„œ๋“œ๋“ค ์ค‘์— ์ž์ฃผ ์“ฐ์ด๋Š” torch.where()์— ๋Œ€ํ•ด์„œ ์ •๋ฆฌํ•ด ๋ณด์•˜์Šต๋‹ˆ๋‹ค.

 

์‚ฌ์‹ค์€ ๋งค์šฐ ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค๋งŒ, ์™„์ „ํžˆ ์ฒด๋“์„ ํ•ด๋‘๋Š” ๊ฒŒ ์•ž์œผ๋กœ ์ž‘์—…์— ํŽธํ•  ๊ฑฐ ๊ฐ™์•„์„œ ์ผ๋ถ€๋Ÿฌ ๊ธ€์„ ์”๋‹ˆ๋‹ค.

out = torch.where(condition, input, other)

condition์€ BoolTensor์ž…๋‹ˆ๋‹ค. True์™€ False๋งŒ์„ ๋‹ด๊ณ  ์žˆ์œผ๋ฉฐ, True์ธ element ์ž๋ฆฌ์— input์ด ๋“ค์–ด๊ฐ€๊ณ , False์ธ element ์ž๋ฆฌ์— other์ด ๋“ค์–ด๊ฐ€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

1. condition์€ Boolean Masking์œผ๋กœ ๋‚˜ํƒ€๋‚ด๊ธฐ

๋ณดํ†ต condition์€ ์›๋ž˜์˜ tensor์— Boolean masking์„ ํ•˜์—ฌ, ํ•˜๋‚˜์˜ BoolTensor๋กœ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.

(Boolean masking์ด๋ž€ Broadcasting์„ ํ†ตํ•ด ๋‹ค์ฐจ์› ๋ฐฐ์—ด ํ˜น์€ ํ…์„œ์—์„œ ์Šค์นผ๋ผ ๊ฐ’์— ๋Œ€ํ•œ ์กฐ๊ฑด์„ ๋งŒ์กฑํ•˜๋ฉด, ํ•ด๋‹น ์ž๋ฆฌ์— True, ๋ฐ˜๋Œ€์ด๋ฉด False๋กœ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ธฐ๋ฒ•์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.) 

๊ทธ๋ž˜์„œ, torch.where()๋ฅผ ์“ธ ๋•Œ ์ž์ฃผ ์“ฐ๋Š” ๊ฒƒ์€ Boolean Masking์œผ๋กœ condition์„ ๋‚˜ํƒ€๋‚ด๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

a = torch.tensor([1, 2, 3, 4, 5])
condition = a > 2
print(a)
>>> tensor([False, False, True, True, True])

2. input๊ณผ other๋Š” ๊ผญ ์Šค์นผ๋ผ ๊ฐ’์ด ์•„๋‹ˆ์–ด๋„ ๋ฉ๋‹ˆ๋‹ค.

๊ตณ์ด ๊ธ€๋กœ ๋‚จ๊ธฐ๋ ค๋Š” ์ด์œ ๋Š” ์ด ์Šคํ‚ฌ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ์œ„ ๊ธ€๋งŒ ๋ณด๋ฉด, input๊ณผ other๋Š” ํŠน์ • ์Šค์นผ๋ผ ๊ฐ’์ด์–ด์•ผ๋งŒ ํ•  ๊ฑฐ ๊ฐ™์Šต๋‹ˆ๋‹ค. 

๊ณต์‹ ๋ฌธ์„œ๋ฅผ ๋ณด๋ฉด, ๊ผญ ๊ทธ๋ ‡์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค. input๊ณผ output์ด ์Šค์นผ๋ผ ๊ฐ’์ด๋ฉด ์Šค์นผ๋ผ ๊ฐ’์„ ๊ทธ๋Œ€๋กœ ๋„ฃ์–ด์ฃผ๋Š” ๊ฒƒ๋„ ๋งž์ง€๋งŒ, tensor๋ผ๋ฉด ํ•ด๋‹น ์œ„์น˜(indices)์— ๋Œ€ํ•œ tensor์˜ ๊ฐ’์„ ๋„ฃ์–ด์ค„ ์ˆ˜ ์žˆ๋‹ค๊ณ  ๋ช…์‹œ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

๊ทธ๋ž˜์„œ, Boolean masking ์ด์ „ ์›๋ž˜ tensor์— ๋Œ€ํ•ด์„œ '๋‚ด ์กฐ๊ฑด์„ ๋งŒ์กฑํ•œ๋‹ค๋ฉด ๊ทธ๋Œ€๋กœ ๋†”๋‘๊ณ , ์•„๋‹ˆ๋ผ๋ฉด ๋ชจ๋‘ 0์œผ๋กœ ์ฒ˜๋ฆฌํ•ด๋ผ' ๊ฐ™์€ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด, ์•„๋ž˜์™€ ๊ฐ™์€ ์ฝ”๋“œ๋กœ ๊ฐ„๋‹จํ•˜๊ฒŒ ์ฒ˜๋ฆฌ๊ฐ€ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

x = torch.tensor([~~~~~])
condition = x > 0 # tensor element์—์„œ 0๋ณด๋‹ค ๋†’์œผ๋ฉด True
out = torch.where(condition, x, 0)

์ด๋ฅผ ์กฐ๊ธˆ ๋” ์‘์šฉํ•ด์„œ ์›ํ•˜๋Š” ์œ„์น˜์— ๋Œ€ํ•ด์„œ๋งŒ ํŠน์ • ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ณต์žกํ•œ ์ฝ”๋“œ๋„ ํ•œ ์ค„์ด๋ฉด, ์ž‘์„ฑ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

 

์ด๋Ÿฌํ•œ torch.where()๋Š” ํŠน์ • ์กฐ๊ฑด์— ์ƒ์‘ํ•˜๋Š” ์†์‹ค ๊ฐ’์„ ๋ฐ˜์˜์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ Loss function์„ ๊ตฌํ˜„ํ•  ๋•Œ, ๋งŽ์ด ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.


๐Ÿ“‚ Reference

https://pytorch.org/docs/stable/generated/torch.where.html

 

torch.where — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 

728x90