์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- ๊ฐ๋์ ๋ง๋ก
- NEXT
- ํฌ๋ฃจ์ค์นผ
- c++
- ์๊ณ ๋ฆฌ์ฆ
- dfs
- ๋ค์ต์คํธ๋ผ
- back propagation
- 2023
- object detection
- BFS
- ์ฐ์ ์์ ํ
- ๋ฐฑํธ๋ํน
- ํ๋ก์ด๋ ์์ฌ
- ๋ถํ ์ ๋ณต
- ๋๋น ์ฐ์ ํ์
- ์๋ฐ์คํฌ๋ฆฝํธ
- pytorch
- ์กฐํฉ๋ก
- ๊ฐ๋์_๋ง๋ก
- ์ด๋ถ ํ์
- lazy propagation
- Overfitting
- ๋ฌธ์์ด
- tensorflow
- ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ
- DP
- ํ๊ณ ๋ก
- dropout
- ๋ฏธ๋๋_ํ์ฌ์_๊ณผ๊ฑฐ๋ก
- Today
- Total
Doby's Lab
DataLoader์ collate_fn, ์๋ก ๋ค๋ฅธ ์ํ์ ํฌ๊ธฐ๋ฅผ ํ๋์ ๋ฐฐ์น๋ก ๋ฌถ๋ ๋ฐฉ๋ฒ ๋ณธ๋ฌธ
DataLoader์ collate_fn, ์๋ก ๋ค๋ฅธ ์ํ์ ํฌ๊ธฐ๋ฅผ ํ๋์ ๋ฐฐ์น๋ก ๋ฌถ๋ ๋ฐฉ๋ฒ
๋๋น(Doby) 2024. 7. 26. 00:53๐ค Problem
์์
์ ํ๋ค๊ฐ DataLoader
๋ฅผ ์ ์ธํ๋ ์ฝ๋ ๋ถ๋ถ์์ ์ฒ์ ๋ณด๋ Argument๊ฐ ์์์ต๋๋ค. ์ด Argument์ ๋ํด์ ๊ณต๋ถ๋ฅผ ํ๋ฉด์ '์ด๊ฑด ์ธ์ ๊ฐ ์ ์ฉํ๊ฒ ์ฐ์ผ ๊ธฐ๋ฅ์ด๋ค!'๋ผ๊ณ ํ๋จ์ด ๋ค์ด ๊ธ์ ๊ธฐ๋กํ๊ฒ ๋์์ต๋๋ค.
์ฐ๋ฆฌ๋ ๋๋ถ๋ถ ๋ชจ๋ธ์ ํ์ต์ ์ํฌ ๋, ๊ฐ ์ํ์ shape์ด ๊ฑฐ์ ๋ค ๊ฐ๋๋ก ์ ์ฒ๋ฆฌ๋ฅผ ํด์ ํ์ต์ ์ํค๊ธฐ ๋๋ฌธ์ ์ฌ์ค ์ด ๊ธฐ๋ฅ์ด ํํ๊ฒ ์ฐ์ด์ง๋ ์์ ๊ฒ์ ๋๋ค. ํ์ง๋ง, ํน์ํ ๊ฒฝ์ฐ์๋ ๊ฐ ๋ฐ์ดํฐ ์ํ์ shape์ด ๋๊ฐ์ด ์ฒ๋ฆฌ๋ ์ ์๋ ๊ฒฝ์ฐ๋ค์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, Object Detection์ ๊ดํ ํ๋ก์ ํธ๋ฅผ ํ๋ค๊ณ ๊ฐ์ ํ๋ฉด, ๊ฐ ์ด๋ฏธ์ง์ ๋ํด Bounding Box์ ์๊ฐ ๋ชจ๋ ๊ฐ๋์? ๊ฑฐ์ ๋๋ถ๋ถ ๊ทธ๋ ์ง๋ ์์ต๋๋ค.
์ด๋ฌํ ์ํฉ์ ๋ํด์ ํ๋์ ๋ฐ์ดํฐ ์ํ์ ํฌ๊ธฐ๋ ๋ชจ๋ ์ ๊ฐ๊ฐ์ ๋๋ค. ์ฆ, ๋ ์ด๋ธ์ ํฌ๊ธฐ๊ฐ ๋ค์ํด์ง๋ค๋ ๋ง์ ๋๋ค. (๋ฌผ๋ก , ์ฌ์ฉํ๋ ค๋ ๋ชจ๋ธ์ ๋ฐ๋ผ ๋ฌธ์ ๊ฐ ๋์ง๋ ์์ ์๋ ์์ฃ ) ๊ทธ๋ฌ๋ฉด, ๊ฐ ๋ฐ์ดํฐ ์ํ์ ํฌ๊ธฐ๊ฐ ๋ค๋ฅธ ๊ฒ ์ ๋ฌธ์ ๊ฐ ๋๋์?
๋ฐฐ์น ๋จ์๋ก ํ์ต์ ํ๊ธฐ ์ํด์๋ ๋ชจ๋ธ์ ๋ฐฐ์น๋ฅผ ๋๊ฒจ์ค ๋, ์ฌ๋ฌ ๊ฐ์ ์ํ์ ๋ฌถ์ ๋ฐฐ์น๋ฅผ ํ๋์ ํ ์๋ก ๋ชจ๋ธ์ ๋๊ฒจ์ฃผ๊ฒ ๋ฉ๋๋ค. ๊ทธ๋์, ํ๋์ ํ ์๋ก ๋ฌถ์ด์ฃผ๋ ค๋ฉด ๊ฐ ์ํ์ ํฌ๊ธฐ๊ฐ ์์ ํ ๊ฐ์์ผ ํฉ๋๋ค. (๋ฌผ๋ก , ๋ฐฐ์น์ ์ฌ์ด์ฆ๊ฐ ํ๋๋ผ๋ฉด ์ ๊ฐ๊ฐ์ด์ด๋ ๋ฉ๋๋ค.)
๐ Solution
๋ค์ ๋งํด, ๋ฐฐ์น๋ฅผ ๊ฐ์ ธ์ค๋ ์ํฉ์์ ๊ฐ ์ํ์ ํฌ๊ธฐ๊ฐ ๋ค๋ฅด๋ค๋ฉด, ์ ์ผ ํฌ๊ธฐ๊ฐ ์์ ์ํ์ ๋ํด์ Zero-Padding์ ํ๋ , ํน์ Constant๋ก ๊ฐ๋ค์ ์ฑ์ ๋ฃ์ด ํฌ๊ธฐ๋ฅผ ๋ง์ถ์ด์ฃผ๋ ์ถ๊ฐ์ ์ธ ์์
์ ํ์๋ก ํฉ๋๋ค. ์ด๋ฌํ ์์
์ ํด์ค ์ ์๋ ๊ฒ์ด ๋ฐ๋ก DataLoader
์ collate_fn
์ด๋ผ๋ Argument์
๋๋ค.
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
์์ ๊ฐ์ด ์ ์ธํ ์ ์์ผ๋ฉฐ, Argument๋ก๋ ๊ฐ ํ ์๋ฅผ ์ด๋ป๊ฒ ํฉ์น ๊ฒ์ธ์ง์ ๋ํ ํจ์๋ฅผ ๋ฃ์ด์ฃผ์ด์ผ ํฉ๋๋ค. ๊ทธ๋ฌ๋ฉด, ์ด๋ป๊ฒ ํจ์๋ฅผ ์ ์ธํด์ผ ํ๋์ง ์์๋ด ์๋ค.
collate_fn
์ Argument๋ก๋ batch๊ฐ ๋ค์ด์ค๊ฒ ๋ฉ๋๋ค. ์ด Batch๋ Dataset
์ __getitem__
๋ฉ์๋์ ๋ฆฌํด ๊ฐ๋ค๋ก ๋ฆฌ์คํธ๋ฅผ ๊ตฌ์ฑํ๊ฒ ๋ฉ๋๋ค.
์๋์ ๊ฐ์ ๋ฐ์ดํฐ์ ์ฝ๋๊ฐ ์๋ค๊ณ ํฉ์๋ค.(๋ ํผ๋ฐ์ค ๋งํฌ)
class CustomDataset(Dataset):
def __init__(self, num=10):
self.num = num
def __len__(self):
return self.num
def __getitem__(self, idx):
x = torch.tensor([idx] * (idx+1))
y = torch.tensor([idx])
return x, y
๋ ์ด๋ธ์ ์ ๋ถ ๊ฐ์ง๋ง, input์ผ๋ก ๋ค์ด์ค๋ ํ ์์ ํฌ๊ธฐ๋ ๋ค ๋ค๋ฅด๋ค์.
๋ง์ฝ, ์๋์ ๊ฐ์ด Batch Size๊ฐ 2์ธ DataLoader๋ฅผ ์ ์ธํด ์ค ๊ฒ์ด๋ผ๋ฉด, ์๋์ ๊ฐ์ด ์ ์ธํ ๊ฒ์ ๋๋ค.
dl = DataLoader(ds, batch_size=2, collate_fn=collate_fn)
๊ทธ๋ฆฌ๊ณ , ์ด๋ input์ผ๋ก ๋ค์ด๊ฐ ํ
์์ ํฌ๊ธฐ๊ฐ ์ ๊ฐ๊ฐ์ด๊ธฐ ๋๋ฌธ์ collate_fn
์ ํตํด์ 'ํฌ๊ธฐ๊ฐ ๋ค๋ฅธ ํ
์๋ค์ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ ๊ฒ์ด๋ฉฐ, ์ด๋ป๊ฒ ํ๋์ Batch๋ก ํฉ์ณ์ค ๊ฒ์ธ๊ฐ'์ ๋ํ ๊ณผ์ ์ ์๋์ ๊ฐ์ ํจ์๋ก ๋ง๋ค์ด์ค์ผ ํฉ๋๋ค.
def collate_fn(batch):
"""
Args:
- batch : [(sample_1), (sample_2), ...]
- sample_n : Dataset.__getitem__(n) => Dataset์ getitem ๋ฉ์๋์์ ์ ์ธํ๋๋ก ํํ์ด ๋ค์ด์ด
Description:
- ๊ฒฐ๋ก ์ ์ผ๋ก collate_fn์ argument๋ Batch์ ๋ํ ๊ฐ ์ํ์ด ํ List์ ๋ํ ์์๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ,
- ์ด ๋ชจ๋ ์์(์ํ)๋ค์ ๋ณธ ํจ์ ๋ด์์ ์ฒ๋ฆฌํ์ฌ, ๊ฐ์ ํฌ๊ธฐ์ Tensor๋ฅผ ๊ฐ๋๋ก ํ ๋ค์์
- ๋ง์ง๋ง์ผ๋ก, ์ด Tensor๋ค์ ๋ชจ๋ ํฉ์น๋ ๊ณผ์ ์ ๊ฑฐ์น๋๋ก ํ๋ค.
"""
# Batch ์ค์ ๊ฐ์ฅ ๊ธด size๋ฅผ ๊ฐ๋ ํ
์์ size ์ฐพ๊ธฐ
longest_size = max([sample[0].shape[0] for sample in batch])
# ๊ธด ์ฌ์ด์ฆ์ ํ
์์ ๋ง์ถฐ์ Zero-padding
x_batch = [F.pad(sample[0], (0, longest_size - sample[0].shape[0]), 'constant', 0) for sample in batch]
# ์ฌ์ด์ฆ๋ฅผ ๋ง์ถ ํ
์๋ค์ ํ๋์ ํ
์๋ก ๊ฒฐํฉ
x_batch = torch.stack([x for x in x_batch], dim=0)
# label ๋ํ ๊ฒฐํฉ
y_batch = torch.stack([sample[1] for sample in batch], dim=0)
return x_batch, y_batch
์์ ๊ฐ์ด Argument๋ก ๋ค์ด์ค๋ Batch๊ฐ ์ด๋ค ํํ์ธ์ง๋ฅผ ์ ์ดํดํ๊ณ ์๋ค๋ฉด, ์ํ๋ ํํ์ Batch๋ก ๋ฆฌํดํ์ฌ ๋์ผํ ํฌ๊ธฐ์ ํ ์๋ก ๋ง๋ค์ด์ ํ์ต์ ์งํํ ์ ์์ต๋๋ค.
์ด๊ฒ์ ์ดํดํ๊ธฐ ์ ์ผ ์ฝ๋๋ก ํ๋ ๊ฒ์ 'Batch๋ก ๋ค์ด์ค๋ Argument๋ ๊ธฐ์กด ๋ฐ์ดํฐ์
์ฝ๋์__getitem__
๋ฉ์๋๊ฐ ์ฌ๋ฌ ๋ฒ ํธ์ถ๋ ๋ฆฌํด ๊ฐ๋ค์ด ๋ฆฌ์คํธ๋ก ๋ด๊ฒจ์์ ๋ฟ์ด๊ณ , ์ด๋ฅผ ํจ์ ๋ด์์ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ์ฌ ์ํ๋ ํํ๋ก ๋ํ๋ผ ๊ฒ์ธ๊ฐ'๋ฅผ ์ ์ผ ํต์ฌ์ผ๋ก ์ฌ๊ธด๋ค๋ฉด ์ด๋ ต์ง ์์ต๋๋ค.
๐ ์๋์ DataLoader๋ก Batch๋ฅผ ํ๋์ฉ ์ถ๋ ฅํ์ ๋
tensor([[0]])
tensor([[0]])
---
tensor([[1, 1]])
tensor([[1]])
---
tensor([[2, 2, 2]])
tensor([[2]])
---
tensor([[3, 3, 3, 3]])
tensor([[3]])
---
tensor([[4, 4, 4, 4, 4]])
tensor([[4]])
---
tensor([[5, 5, 5, 5, 5, 5]])
tensor([[5]])
---
tensor([[6, 6, 6, 6, 6, 6, 6]])
tensor([[6]])
---
tensor([[7, 7, 7, 7, 7, 7, 7, 7]])
tensor([[7]])
---
tensor([[8, 8, 8, 8, 8, 8, 8, 8, 8]])
tensor([[8]])
---
tensor([[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])
tensor([[9]])
---
๐ collate_fn์ ๊ตฌ์ฑํ๊ณ , Batch Size๋ฅผ 2๋ก ์ค์ ํ์ฌ DataLoader๋ฅผ ์ถ๋ ฅํ์ ๋
for x, y in dl:
print(x, y, end='\n---\n', sep='\n')
[Output]
tensor([[0, 0],
[1, 1]])
tensor([[0],
[1]])
---
tensor([[2, 2, 2, 0],
[3, 3, 3, 3]])
tensor([[2],
[3]])
---
tensor([[4, 4, 4, 4, 4, 0],
[5, 5, 5, 5, 5, 5]])
tensor([[4],
[5]])
---
tensor([[6, 6, 6, 6, 6, 6, 6, 0],
[7, 7, 7, 7, 7, 7, 7, 7]])
tensor([[6],
[7]])
---
tensor([[8, 8, 8, 8, 8, 8, 8, 8, 8, 0],
[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])
tensor([[8],
[9]])
---
๐ Reference
https://seokhee0516.tistory.com/entry/Pytorch-collatefn-%EC%9D%B4%EB%9E%80
Pytorch collate_fn ์ด๋?
DataLoader์๋ ์ฌ๋ฌ ํ๋ผ๋ฏธํฐ๊ฐ ์์ด ํ์์ ์ ์ ํ ํ๋ผ๋ฏธํฐ๋ฅผ ํ์ฉํด ์ฌ๋ฌ ์ค์ ์ ์ค ์ ์๋ค. ๊ทธ ์ค์์๋ collate_fn์ variable length๊ฐ ๋ฌ๋ผ์, ํจ๋ฉํด์ค ๋ ์ฌ์ฉํ๋ค. from torch.utils.data import Dataset,
seokhee0516.tistory.com