์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- ๋ฐฑํธ๋ํน
- tensorflow
- ์ฐ์ ์์ ํ
- DP
- ๊ฐ๋์ ๋ง๋ก
- BFS
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- c++
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- ํฌ๋ฃจ์ค์นผ
- ๋ถํ ์ ๋ณต
- dropout
- ๊ฐ๋์_๋ง๋ก
- ๋ค์ต์คํธ๋ผ
- 2023
- lazy propagation
- dfs
- ํ๋ก์ด๋ ์์ฌ
- ๋๋น ์ฐ์ ํ์
- object detection
- pytorch
- ์๊ณ ๋ฆฌ์ฆ
- ์ด๋ถ ํ์
- ๋ฌธ์์ด
- ์๋ฐ์คํฌ๋ฆฝํธ
- ํ๊ณ ๋ก
- NEXT
- ์กฐํฉ๋ก
- Overfitting
- back propagation
- Today
- Total
Doby's Lab
torch.where(), ๊ทผ๋ฐ ์ด์ loss๋ฅผ ๋ง๋ค ๋ ๋ง์ด ๊ณ๋ค์ธ ๋ณธ๋ฌธ
torch.where(), ๊ทผ๋ฐ ์ด์ loss๋ฅผ ๋ง๋ค ๋ ๋ง์ด ๊ณ๋ค์ธ
๋๋น(Doby) 2024. 7. 23. 09:14๐ค Problem
PyTorch์์ ์์ฃผ ์ฐ๋ ๋ฉ์๋๋ค์ ์ด๋ ์ ๋ ์ฒด๋ํ๋ฉด์ ๋น ๋ฅด๊ฒ ํธ๋ค๋งํ ์ ์๋๋ก ํ๋ ๊ฒ์ด ์ข์ ๊ฑฐ ๊ฐ์ต๋๋ค. ์ด๋ฒ ๊ธ์ ๊ทธ๋ฌํ ๋ฉ์๋๋ค ์ค์ ์์ฃผ ์ฐ์ด๋ torch.where()
์ ๋ํด์ ์ ๋ฆฌํด ๋ณด์์ต๋๋ค.
์ฌ์ค์ ๋งค์ฐ ๊ฐ๋จํฉ๋๋ค๋ง, ์์ ํ ์ฒด๋์ ํด๋๋ ๊ฒ ์์ผ๋ก ์์ ์ ํธํ ๊ฑฐ ๊ฐ์์ ์ผ๋ถ๋ฌ ๊ธ์ ์๋๋ค.
out = torch.where(condition, input, other)
condition์ BoolTensor
์
๋๋ค. True
์ False
๋ง์ ๋ด๊ณ ์์ผ๋ฉฐ, True
์ธ element ์๋ฆฌ์ input์ด ๋ค์ด๊ฐ๊ณ , False
์ธ element ์๋ฆฌ์ other์ด ๋ค์ด๊ฐ๊ฒ ๋ฉ๋๋ค.
1. condition์ Boolean Masking์ผ๋ก ๋ํ๋ด๊ธฐ
๋ณดํต condition์ ์๋์ tensor์ Boolean masking์ ํ์ฌ, ํ๋์ BoolTensor
๋ก ๋ํ๋
๋๋ค.
(Boolean masking์ด๋ Broadcasting์ ํตํด ๋ค์ฐจ์ ๋ฐฐ์ด ํน์ ํ
์์์ ์ค์นผ๋ผ ๊ฐ์ ๋ํ ์กฐ๊ฑด์ ๋ง์กฑํ๋ฉด, ํด๋น ์๋ฆฌ์ True
, ๋ฐ๋์ด๋ฉด False
๋ก ์ฒ๋ฆฌํ๋ ๊ธฐ๋ฒ์ ์๋ฏธํฉ๋๋ค.)
๊ทธ๋์, torch.where()
๋ฅผ ์ธ ๋ ์์ฃผ ์ฐ๋ ๊ฒ์ Boolean Masking์ผ๋ก condition์ ๋ํ๋ด๋ ๊ฒ์
๋๋ค.
a = torch.tensor([1, 2, 3, 4, 5])
condition = a > 2
print(a)
>>> tensor([False, False, True, True, True])
2. input๊ณผ other๋ ๊ผญ ์ค์นผ๋ผ ๊ฐ์ด ์๋์ด๋ ๋ฉ๋๋ค.
๊ตณ์ด ๊ธ๋ก ๋จ๊ธฐ๋ ค๋ ์ด์ ๋ ์ด ์คํฌ ๋๋ฌธ์ ๋๋ค. ์ ๊ธ๋ง ๋ณด๋ฉด, input๊ณผ other๋ ํน์ ์ค์นผ๋ผ ๊ฐ์ด์ด์ผ๋ง ํ ๊ฑฐ ๊ฐ์ต๋๋ค.
๊ณต์ ๋ฌธ์๋ฅผ ๋ณด๋ฉด, ๊ผญ ๊ทธ๋ ์ง๋ ์์ต๋๋ค. input๊ณผ output์ด ์ค์นผ๋ผ ๊ฐ์ด๋ฉด ์ค์นผ๋ผ ๊ฐ์ ๊ทธ๋๋ก ๋ฃ์ด์ฃผ๋ ๊ฒ๋ ๋ง์ง๋ง, tensor๋ผ๋ฉด ํด๋น ์์น(indices)์ ๋ํ tensor์ ๊ฐ์ ๋ฃ์ด์ค ์ ์๋ค๊ณ ๋ช ์๋์ด ์์ต๋๋ค.
๊ทธ๋์, Boolean masking ์ด์ ์๋ tensor์ ๋ํด์ '๋ด ์กฐ๊ฑด์ ๋ง์กฑํ๋ค๋ฉด ๊ทธ๋๋ก ๋๋๊ณ , ์๋๋ผ๋ฉด ๋ชจ๋ 0์ผ๋ก ์ฒ๋ฆฌํด๋ผ' ๊ฐ์ ์์ ์ ์ํํ๊ณ ์ถ๋ค๋ฉด, ์๋์ ๊ฐ์ ์ฝ๋๋ก ๊ฐ๋จํ๊ฒ ์ฒ๋ฆฌ๊ฐ ๊ฐ๋ฅํฉ๋๋ค.
x = torch.tensor([~~~~~])
condition = x > 0 # tensor element์์ 0๋ณด๋ค ๋์ผ๋ฉด True
out = torch.where(condition, x, 0)
์ด๋ฅผ ์กฐ๊ธ ๋ ์์ฉํด์ ์ํ๋ ์์น์ ๋ํด์๋ง ํน์ ์ฐ์ฐ์ ์ํํ๋ ๋ณต์กํ ์ฝ๋๋ ํ ์ค์ด๋ฉด, ์์ฑ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ด๋ฌํ torch.where()
๋ ํน์ ์กฐ๊ฑด์ ์์ํ๋ ์์ค ๊ฐ์ ๋ฐ์์ํค๊ธฐ ์ํด์ Loss function์ ๊ตฌํํ ๋, ๋ง์ด ์ฌ์ฉํฉ๋๋ค.
๐ Reference
https://pytorch.org/docs/stable/generated/torch.where.html