Doby's Lab

DataLoader's collate_fn: How to Combine Samples of Different Sizes into One Batch

Doby, 2024. 7. 26. 00:53

🤔 Problem

While working, I ran into an argument I had never seen before in some code that declares a DataLoader. As I studied it, I thought "this is a feature that will come in handy someday!", so I decided to write it down.

 

Most of the time, we preprocess the data so that nearly every sample has the same shape before training a model, so this feature probably won't come up often. In some cases, though, the samples cannot all be brought to the same shape. Take an Object Detection project: does every image have the same number of bounding boxes? Almost never.

 

In that situation, the size of each data sample varies; that is, the size of the label differs from sample to sample. (Of course, depending on the model you use, this may not be a problem at all.) So why is it a problem when data samples differ in size?

 

To train in batches, we bundle several samples together and hand them to the model as a single tensor. To be stacked into one tensor, the samples must have exactly the same size. (Of course, if the batch size is 1, the sizes can vary freely.)
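To make the failure concrete, here is a minimal sketch of what happens when two tensors of different lengths are stacked directly. The tensors `a` and `b` are illustrative values, not taken from a real dataset.

```python
import torch

a = torch.tensor([0])      # shape (1,)
b = torch.tensor([1, 1])   # shape (2,)

try:
    torch.stack([a, b], dim=0)  # sizes differ, so stacking fails
    stacked = True
except RuntimeError as err:
    stacked = False
    print("stack failed:", err)

# A "batch" of a single sample always stacks, whatever its size.
single = torch.stack([b], dim=0)
print(tuple(single.shape))  # (1, 2)
```

This is the same kind of error the default batching logic runs into when samples in one batch have mismatched sizes.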


😀 Solution

In other words, if the samples fetched for a batch differ in size, extra work is needed to make them match: zero-pad the shorter samples, or fill them with some other constant value. The DataLoader argument that lets us do this is collate_fn.
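As a small sketch of the two padding options mentioned above, `torch.nn.functional.pad` can right-pad a 1-D tensor with zeros or with another constant. The target length of 5 and the fill value -1 are arbitrary choices made for illustration.

```python
import torch
import torch.nn.functional as F

short = torch.tensor([2, 2, 2])
target_len = 5  # hypothetical target length for this example
missing = target_len - short.shape[0]

# (0, missing) means: pad 0 elements on the left, `missing` on the right.
zero_padded = F.pad(short, (0, missing), mode="constant", value=0)
const_padded = F.pad(short, (0, missing), mode="constant", value=-1)

print(zero_padded.tolist())   # [2, 2, 2, 0, 0]
print(const_padded.tolist())  # [2, 2, 2, -1, -1]
```

A non-zero fill value can be useful when 0 is itself a meaningful value in the data, since it keeps padding distinguishable from real entries.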

dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

It is declared as shown above; the argument takes a function that defines how the samples should be merged. Let's see how to write that function.

 

collate_fn receives batch as its argument. This batch is a list made up of the return values of the Dataset's __getitem__ method.
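One way to check this for yourself is a collate function that only records what it receives. The sketch below uses a hypothetical `TinyDataset` and `inspect_collate` (neither is from the original post) and assumes the default single-process loading (`num_workers=0`), so the recording list lives in the same process.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):  # hypothetical dataset, for illustration only
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.tensor([idx] * (idx + 1)), torch.tensor([idx])


seen = []  # what collate_fn received on each call


def inspect_collate(batch):
    # Record the container type, its length, and the shape of each x.
    seen.append((type(batch).__name__, len(batch), [tuple(x.shape) for x, _ in batch]))
    return batch  # hand the list back unchanged


for _ in DataLoader(TinyDataset(), batch_size=2, collate_fn=inspect_collate):
    pass

print(seen)
```

Each call receives a plain Python list with `batch_size` entries, and each entry is exactly the `(x, y)` tuple returned by `__getitem__`.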

 

Suppose we have the dataset code below (see the reference link).

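import torch
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, num=10):
        self.num = num

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        # x has length idx + 1, so every sample has a different size
        x = torch.tensor([idx] * (idx + 1))
        # y always has shape (1,)
        y = torch.tensor([idx])
        return x, y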

The labels all have the same shape, but the input tensors all differ in size.

 

To declare a DataLoader with a batch size of 2, we would write the following.

ds = CustomDataset()
dl = DataLoader(ds, batch_size=2, collate_fn=collate_fn)

Because the input tensors differ in size, collate_fn has to spell out "how to handle tensors of different sizes, and how to merge them into a single batch", as in the function below.

import torch.nn.functional as F

def collate_fn(batch):
    """
    Args:
        - batch : [(sample_1), (sample_2), ...]
        - sample_n : Dataset.__getitem__(n) => the tuple exactly as returned by the Dataset's __getitem__ method
    Description:
        - In short, the argument to collate_fn is a list whose elements are the samples of one batch.
        - All of these elements (samples) are processed inside this function so that their tensors have the same size,
        - and finally those tensors are merged into one.
    """

    # Find the length of the longest input tensor in the batch
    longest_size = max(sample[0].shape[0] for sample in batch)

    # Zero-pad every input on the right up to that length
    x_batch = [F.pad(sample[0], (0, longest_size - sample[0].shape[0]), 'constant', 0) for sample in batch]

    # Stack the size-matched tensors into a single tensor
    x_batch = torch.stack(x_batch, dim=0)

    # Stack the labels as well
    y_batch = torch.stack([sample[1] for sample in batch], dim=0)
    return x_batch, y_batch

As shown above, once you understand what form the incoming batch takes, you can return a batch in whatever form you want, with tensors of matching size, and train on it.
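As an alternative to padding by hand, PyTorch also provides `torch.nn.utils.rnn.pad_sequence`, which pads a list of tensors to the longest one. The sketch below is not from the original post: `pad_collate` is a hypothetical name, and returning the original lengths is an extra design choice that some models find useful for masking out the padding.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):  # hypothetical name, not from the original post
    xs = [x for x, _ in batch]
    # Keep the true lengths so padding can be masked out later if needed.
    lengths = torch.tensor([x.shape[0] for x in xs])
    # batch_first=True gives shape (batch, longest_length).
    x_batch = pad_sequence(xs, batch_first=True, padding_value=0)
    y_batch = torch.stack([y for _, y in batch], dim=0)
    return x_batch, y_batch, lengths


# A hand-built batch in the same (x, y) format used in this post.
batch = [
    (torch.tensor([2, 2, 2]), torch.tensor([2])),
    (torch.tensor([3, 3, 3, 3]), torch.tensor([3])),
]
x, y, lengths = pad_collate(batch)
print(x.tolist())        # [[2, 2, 2, 0], [3, 3, 3, 3]]
print(y.tolist())        # [[2], [3]]
print(lengths.tolist())  # [3, 4]
```

It would be passed to the DataLoader the same way, through the `collate_fn` argument.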

 

The easiest way to think about it is this: "the batch argument is just a list holding the return values of several calls to the dataset's __getitem__ method; the only question is how to process that list inside the function into the form you want." Keep that at the center and it isn't hard.
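Since the return value can take any form, padding is not the only option. For the Object Detection case mentioned earlier, one common approach is to stack the images and simply keep the per-image boxes as a Python list. The sketch below assumes that all images already share one size and that each sample is an `(image, boxes)` pair; `detection_collate` is a hypothetical name, and whether a list of targets is acceptable depends on the model you use.

```python
import torch


def detection_collate(batch):  # hypothetical name, for illustration only
    # Images are assumed to share one shape, so they can be stacked.
    images = torch.stack([img for img, _ in batch], dim=0)
    # Box tensors have a different number of rows per image: keep them as a list.
    boxes = [b for _, b in batch]
    return images, boxes


# Two fake samples: one image with 1 box, one with 5 boxes.
batch = [
    (torch.zeros(3, 8, 8), torch.zeros(1, 4)),
    (torch.zeros(3, 8, 8), torch.zeros(5, 4)),
]
images, boxes = detection_collate(batch)
print(tuple(images.shape))            # (2, 3, 8, 8)
print([tuple(b.shape) for b in boxes])  # [(1, 4), (5, 4)]
```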

📄 Output when printing batches one at a time from the original DataLoader

tensor([[0]])
tensor([[0]])
---
tensor([[1, 1]])
tensor([[1]])
---
tensor([[2, 2, 2]])
tensor([[2]])
---
tensor([[3, 3, 3, 3]])
tensor([[3]])
---
tensor([[4, 4, 4, 4, 4]])
tensor([[4]])
---
tensor([[5, 5, 5, 5, 5, 5]])
tensor([[5]])
---
tensor([[6, 6, 6, 6, 6, 6, 6]])
tensor([[6]])
---
tensor([[7, 7, 7, 7, 7, 7, 7, 7]])
tensor([[7]])
---
tensor([[8, 8, 8, 8, 8, 8, 8, 8, 8]])
tensor([[8]])
---
tensor([[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])
tensor([[9]])
---
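The listing above is consistent with iterating a DataLoader that uses the default batch size of 1 and no collate_fn. A minimal sketch that should reproduce it, assuming the CustomDataset from this post (repeated here so the block runs on its own); the last part also shows what the default batching does when the batch size is raised to 2 without a custom collate_fn.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    def __init__(self, num=10):
        self.num = num

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        return torch.tensor([idx] * (idx + 1)), torch.tensor([idx])


shapes = []
# Default batch_size=1: every batch holds one sample, so sizes never clash.
for x, y in DataLoader(CustomDataset()):
    shapes.append((tuple(x.shape), tuple(y.shape)))
    print(x, y, sep='\n', end='\n---\n')

# With batch_size=2 and no collate_fn, the default collation tries to
# stack inputs of length 1 and 2 and fails.
try:
    next(iter(DataLoader(CustomDataset(), batch_size=2)))
    default_ok = True
except RuntimeError:
    default_ok = False
print(default_ok)
```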

📄 Output from the DataLoader with collate_fn set and a batch size of 2

for x, y in dl:
    print(x, y, end='\n---\n', sep='\n')

[Output]
tensor([[0, 0],
        [1, 1]])
tensor([[0],
        [1]])
---
tensor([[2, 2, 2, 0],
        [3, 3, 3, 3]])
tensor([[2],
        [3]])
---
tensor([[4, 4, 4, 4, 4, 0],
        [5, 5, 5, 5, 5, 5]])
tensor([[4],
        [5]])
---
tensor([[6, 6, 6, 6, 6, 6, 6, 0],
        [7, 7, 7, 7, 7, 7, 7, 7]])
tensor([[6],
        [7]])
---
tensor([[8, 8, 8, 8, 8, 8, 8, 8, 8, 0],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])
tensor([[8],
        [9]])
---

📂 Reference

https://seokhee0516.tistory.com/entry/Pytorch-collatefn-%EC%9D%B4%EB%9E%80

 
