์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
31 |
- 2023
- ํฌ๋ฃจ์ค์นผ
- ํ๊ณ ๋ก
- ๋ฐฑํธ๋ํน
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- Overfitting
- ์๋ฐ์คํฌ๋ฆฝํธ
- dropout
- ์ฐ์ ์์ ํ
- DP
- lazy propagation
- ์ด๋ถ ํ์
- back propagation
- ๋ฌธ์์ด
- ๋ถํ ์ ๋ณต
- tensorflow
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- ์กฐํฉ๋ก
- ์๊ณ ๋ฆฌ์ฆ
- ๋๋น ์ฐ์ ํ์
- pytorch
- ๊ฐ๋์_๋ง๋ก
- BFS
- ๊ฐ๋์ ๋ง๋ก
- dfs
- NEXT
- c++
- ํ๋ก์ด๋ ์์ฌ
- object detection
- ๋ค์ต์คํธ๋ผ
- Today
- Total
Doby's Lab
torch.where(), ๊ทผ๋ฐ ์ด์ loss๋ฅผ ๋ง๋ค ๋ ๋ง์ด ๊ณ๋ค์ธ ๋ณธ๋ฌธ
torch.where(), ๊ทผ๋ฐ ์ด์ loss๋ฅผ ๋ง๋ค ๋ ๋ง์ด ๊ณ๋ค์ธ
๋๋น(Doby) 2024. 7. 23. 09:14๐ค Problem
PyTorch์์ ์์ฃผ ์ฐ๋ ๋ฉ์๋๋ค์ ์ด๋ ์ ๋ ์ฒด๋ํ๋ฉด์ ๋น ๋ฅด๊ฒ ํธ๋ค๋งํ ์ ์๋๋ก ํ๋ ๊ฒ์ด ์ข์ ๊ฑฐ ๊ฐ์ต๋๋ค. ์ด๋ฒ ๊ธ์ ๊ทธ๋ฌํ ๋ฉ์๋๋ค ์ค์ ์์ฃผ ์ฐ์ด๋ torch.where()
์ ๋ํด์ ์ ๋ฆฌํด ๋ณด์์ต๋๋ค.
์ฌ์ค์ ๋งค์ฐ ๊ฐ๋จํฉ๋๋ค๋ง, ์์ ํ ์ฒด๋์ ํด๋๋ ๊ฒ ์์ผ๋ก ์์ ์ ํธํ ๊ฑฐ ๊ฐ์์ ์ผ๋ถ๋ฌ ๊ธ์ ์๋๋ค.
out = torch.where(condition, input, other)
condition์ BoolTensor
์
๋๋ค. True
์ False
๋ง์ ๋ด๊ณ ์์ผ๋ฉฐ, True
์ธ element ์๋ฆฌ์ input์ด ๋ค์ด๊ฐ๊ณ , False
์ธ element ์๋ฆฌ์ other์ด ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค.
1. condition์ Boolean Masking์ผ๋ก ๋ํ๋ด๊ธฐ
๋ณดํต condition์ ์๋์ tensor์ Boolean masking์ ํ์ฌ, ํ๋์ BoolTensor
๋ก ๋ํ๋
๋๋ค.
(Boolean masking์ด๋ Broadcasting์ ํตํด ๋ค์ฐจ์ ๋ฐฐ์ด ํน์ ํ
์์์ ์ค์นผ๋ผ ๊ฐ์ ๋ํ ์กฐ๊ฑด์ ๋ง์กฑํ๋ฉด, ํด๋น ์๋ฆฌ์ True
, ๋ฐ๋์ด๋ฉด False
๋ก ์ฒ๋ฆฌํ๋ ๊ธฐ๋ฒ์ ์๋ฏธํฉ๋๋ค.)
๊ทธ๋์, torch.where()
๋ฅผ ์ธ ๋ ์์ฃผ ์ฐ๋ ๊ฒ์ Boolean Masking์ผ๋ก condition์ ๋ํ๋ด๋ ๊ฒ์
๋๋ค.
a = torch.tensor([1, 2, 3, 4, 5])
condition = a > 2
print(a)
>>> tensor([False, False, True, True, True])
2. input๊ณผ other๋ ๊ผญ ์ค์นผ๋ผ ๊ฐ์ด ์๋์ด๋ ๋ฉ๋๋ค.
๊ตณ์ด ๊ธ๋ก ๋จ๊ธฐ๋ ค๋ ์ด์ ๋ ์ด ์คํฌ ๋๋ฌธ์ ๋๋ค. ์ ๊ธ๋ง ๋ณด๋ฉด, input๊ณผ other๋ ํน์ ์ค์นผ๋ผ ๊ฐ์ด์ด์ผ๋ง ํ ๊ฑฐ ๊ฐ์ต๋๋ค.
๊ณต์ ๋ฌธ์๋ฅผ ๋ณด๋ฉด, ๊ผญ ๊ทธ๋ ์ง๋ ์์ต๋๋ค. input๊ณผ output์ด ์ค์นผ๋ผ ๊ฐ์ด๋ฉด ์ค์นผ๋ผ ๊ฐ์ ๊ทธ๋๋ก ๋ฃ์ด์ฃผ๋ ๊ฒ๋ ๋ง์ง๋ง, tensor๋ผ๋ฉด ํด๋น ์์น(indices)์ ๋ํ tensor์ ๊ฐ์ ๋ฃ์ด์ค ์ ์๋ค๊ณ ๋ช ์๋์ด ์์ต๋๋ค.
๊ทธ๋์, Boolean masking ์ด์ ์๋ tensor์ ๋ํด์ '๋ด ์กฐ๊ฑด์ ๋ง์กฑํ๋ค๋ฉด ๊ทธ๋๋ก ๋๋๊ณ , ์๋๋ผ๋ฉด ๋ชจ๋ 0์ผ๋ก ์ฒ๋ฆฌํด๋ผ' ๊ฐ์ ์์ ์ ์ํํ๊ณ ์ถ๋ค๋ฉด, ์๋์ ๊ฐ์ ์ฝ๋๋ก ๊ฐ๋จํ๊ฒ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
x = torch.tensor([~~~~~])
condition = x > 0 # tensor element์์ 0๋ณด๋ค ๋์ผ๋ฉด True
out = torch.where(condition, x, 0)
์ด๋ฅผ ์กฐ๊ธ ๋ ์์ฉํด์ ์ํ๋ ์์น์ ๋ํด์๋ง ํน์ ์ฐ์ฐ์ ์ํํ๋ ๋ณต์กํ ์ฝ๋๋ ํ ์ค์ด๋ฉด, ์์ฑ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ด๋ฌํ torch.where()
๋ ํน์ ์กฐ๊ฑด์ ์์ํ๋ ์์ค ๊ฐ์ ๋ฐ์์ํค๊ธฐ ์ํด์ Loss function์ ๊ตฌํํ ๋, ๋ง์ด ์ฌ์ฉํฉ๋๋ค.
๐ Reference
https://pytorch.org/docs/stable/generated/torch.where.html
torch.where — PyTorch 2.3 documentation
Shortcuts
pytorch.org