Doby's Lab

CLIP, 단순한 분류의 시대는 지났다 본문

AI/Concepts: Multi-Modal

CLIP, 단순한 분류의 시대는 지났다

도비(Doby) 2024. 2. 14. 16:58

✅ Intro

LLaVA-Med를 공부하면서 Visual Encoder로 CLIP(Contrastive Language-Image Pre-training)이 사용되어 이번 기회에 공부를 해보았습니다.

 

CLIP은 기존 Classification 방식에서 새로운 메커니즘을 제안했습니다. Classification은 수많은 데이터셋에서 라벨링 된 클래스로 분류하는 것이 일반적인 특징입니다. 하지만, 세상에는 여러 가지 사물이 존재하며, 이 사물 또한 어떠한 상태에 있냐에 따라 분류를 할 수 있는 범위는 셀 수 없을 정도로 많습니다. 예를 들어, '일반적인 자전거'와 '바퀴가 없는 자전거'라는 Task로 수많은 사물들이 더 디테일한 description을 원할 때, 단순한 Classification Task만으로는 큰 어려움이 있습니다.

 

그래서 CLIP은 Image-Text Pair 형태의 데이터를 활용하여 더 디테일한 description, 확장성을 가진 caption이라는 특징과 함께 주어진 Image 대해 Description(Text), 혹은 주어진 Text에 대한 Image를 가져올 수 있는 Task를 수행합니다.


✅ Architecture & Clone

그렇다면, 이러한 Task를 어떻게 수행하는지를 알아보기 위해서 CLIP을 간단하게 구현해 놓은 아래 포스팅을 참고했습니다. 이에 따라 클론 코딩을 하여 CLIP에 대한 이해도를 높일 수 있었습니다.

 

Simple Implementation of OpenAI CLIP model: A Tutorial

A tutorial on simple implementation of CLIP model in PyTorch.

towardsdatascience.com

또한, 클론한 파일을 전부 깃허브에 커밋해 두었습니다 :)

 

GitHub - drawcodeboy/CLIP-Implementation-clone: CLIP의 구현을 Clone을 통해서 공부했습니다.

CLIP의 구현을 Clone을 통해서 공부했습니다. Contribute to drawcodeboy/CLIP-Implementation-clone development by creating an account on GitHub.

github.com

CLIP을 학습시키기 위해서는 'Image-Text Pair' 형태의 데이터셋을 사용해야 합니다. 각 이미지가 어떤 것을 나타내는지에 대한 정보를 Text로 담고 있어야 합니다. 참고한 포스팅에서는 Kaggle의 Flickr8k 데이터셋을 사용했습니다.

📄  Image Encoder, ResNet50

CLIP이라는 모델이 Image와 Text를 이해하기 위해서는 Encoding이 필요합니다. 또한, 단순한 Encoding이 아닌 시각적 정보, 언어적 정보를 잘 추출할 수 있는 Encoding을 필요로 합니다. 그러기 위해서는 기존에 Pre-trained된 모델을 사용하여 시각적 정보(Visual features), 언어적 정보(Text features)를 추출하는 것이 좋을 거라는 것은 당연합니다.

 

 

시각적 정보를 추출하기 위해서 본 포스팅에서는 ResNet50이 ImageNet1K를 사전 학습한 모델을 Image Encoder로 사용했습니다. 이때, ResNet50의 역할은 시각적 정보를 추출하는 것일 뿐, 더 성능을 높일 필요는 없기 때문에 모델 프리징을 해주었습니다.

class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet50(weights='IMAGENET1K_V1')
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self.flatten = nn.Flatten() # Image output Vector size = 2,048
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model(x)
        x = self.flatten(x)
        return x

📄  Text Encoder, DistilBERT

이번에는 언어적 정보를 추출하기 위해서 DistilBERT라는 모델을 사용하였습니다. DistilBERT는 BERT의 경량화 버전이라 생각하면 되고, 사전 학습 모델은 Hugging Face의 모델을 사용하였으며, distilbert-base-uncased라는 weight의 사전 학습 모델을 가져왔습니다. 추가적으로, 이때의 uncased의 의미는 '대소문자를 구분하지 않겠다'는 의미입니다.

결론적으로, 여기서도 언어적 정보만을 추출해 주면 되기 때문에 모델 프리징을 해줍니다.

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')

        for param in self.model.parameters():
            param.requires_grad = False

        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

📄  Features to Embeddings (Projection Head)

각 Encoder(Pre-trained Model)를 통해서 Image features, Text features를 추출해 주었습니다. 각 features는 (batch_size, 2048), (batch_size, 768)의 shape을 가지고 있습니다. 이제 이 features를 CLIP을 위해서 동일한 차원으로 Embedding을 만들어주어야 합니다. '아까부터 왜 features를 추출하고 이젠 Embedding까지 만드느냐?'에 대한 질문은 바로 아래의 Similarity에서 다룰 것이니 조금만 기다려주세요:)

 

256이라는 차원을 가지게 하기 위한 Projection Head는 아래와 같이 구성되었습니다. 단순히 Projection을 하기 위함이라는 목적을 갖고 있는 것에 비해 구성이 조금 복잡합니다. 그에 대한 이유는 features를 추출한 두 모델에서 CLIP으로 넘어가는 다리의 역할을 하고 있기 때문에 Projection Head 부분은 전부 trainable해서 features to embeddings이라는 중요한 역할을 수행해야 하기 때문에 단순한 nn.Linear와 같은 Projection보다는 복잡한 구성을 갖게 됩니다.

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, 256)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(256)
    
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

📄  Simiarity, CLIP

여기서 Feature를 추출하고, Embedding을 계산한 이유를 밝힐 수 있었습니다. CLIP을 공부하다 보면, 제일 많이 보게 되는 figure가 아마 아래 figure일 듯합니다.

 처음에 보았을 때는 조금 의문을 가진 부분이 많았으나 직접 구현해 보고 다시 보니 정말 잘 표현했다는 생각이 듭니다. \([T_1,\dotso,T_N]\)은 Text Embeddings,  \([I_1,\dotso,I_N]\)는 Image Embeddings를 의미합니다. 그런데 figure에서는 이를 Inner Product 해줍니다. Inner Product의 의미에는 Similarity를 측정하기 위함도 있습니다. 그럼 저희가 원하는 건 Batch 내의 Sample(image-text)이 서로 가장 높은 Similarity를 갖게 하는 것이 CLIP이 학습해야 할 목표가 됩니다.

 

그러면, target은 Identity Matrix일 거라 추측할 수도 있겠지만, target으로 사용하는 건 아래와 같습니다. 이러한 이유는 Batch 내의 서로 다른 Sample이 실제로 유사할 수도 있기 때문에 그런 것으로 추측됩니다.

images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T

targets = F.softmax(
    (images_similarity + texts_similarity) / 2.0 * self.temperature, dim=-1
)

결론적으로, CLIP Model의 구성은 아래와 같습니다.

class CLIPModel(nn.Module):
    def __init__(self, temperature, image_embedding, text_embedding):
        super().__init__()
        
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Get Image, Text features through ResNet50, DistilBERT
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        # Same DIMENSION EMBEDDING
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        targets = F.softmax(
            (images_similarity + texts_similarity) / 2.0 * self.temperature, dim=-1
        )

        texts_loss = self.cross_entropy(logits, targets, reduction='none')
        images_loss = self.cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0
        
        return loss.mean()

    def cross_entropy(self, preds, targets, reduction='none'):
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = (-targets * log_softmax(preds)).sum(1)
        if reduction == 'none':
            return loss
        elif reduction == 'mean':
            return loss.mean()

📄  Architecture Review

정리를 해보면, Pre-trained Model을 통해 Image features, Text features를 추출하고, Projection Head를 통해 Embedding을 만들어주었으며 Similarity(Inner Product)를 계산하여 'Image와 Text가 유사한 정보를 담고 있다'라는 것을 학습시켜 주었습니다. 이걸 figure로 조금 더 정리해 보면, 아래와 같이 나타낼 수 있습니다.

📄  Inference

하지만, 추론하는 방식에 있어서는 '음.. 아쉽다'라고 느꼈었습니다. 기존 Classification 메커니즘을 뛰어넘으려면, input을 넣는 대로 output이 있어야 하지만, input의 Embedding에 대해 가장 높은 Similarity를 가지는 데이터를 찾기 위해 다른 포맷의 Embedding을 계속 대조해 보아야 하기 때문입니다.

def get_image_embeddings(valid_df):
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    valid_loader = build_loaders(valid_df, tokenizer, mode='valid')

    valid_image_embeddings = []
    model.eval()
    with torch.no_grad():
        n_data = 0
        for batch in valid_loader:
            image_features = model.image_encoder(batch['image'].to(device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
            n_data += 8
            print(f'\r{100.0*n_data/len(valid_loader.dataset):.2f}%', end='')
    return torch.cat(valid_image_embeddings)
    
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch['input_ids'], attention_mask=batch['attention_mask']
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Similarity 값이 높을수록 좋기 때문에 topk 함수를 통해 가장 높은 것 n*5개를 뽑는다.
    # 그리고, 그 중에 9개를 plt.show() 한다.
    values, indices = torch.topk(dot_similarity.squeeze(0), n*5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(6, 6))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f'F:/Doby/CLIP/Flickr8k/Images/{match}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (224, 224))
        ax.imshow(image)
        ax.axis('off')
    plt.suptitle(f'Query [{query}]')
    plt.show()
    
find_matches(
    model,
    image_embeddings,
    query = 'a group of people dancing in a party',
    image_filenames=valid_df['image'].values,
    n=9
)

 

그래서 'Embedding에 대한 Decoding을 해서 데이터를 발생시킬 수는 없을까??'라는 생각도 해보게 되는 CLIP이었습니다.


✅ Outro

이번에는 CLIP을 배워보면서 어떤 원리로 이루어지는가를 클론 코딩을 통해서 배웠습니다. 또한, Hugging Face를 처음 다루어봤고, NLP 쪽 모델도 이번이 첫 핸들링인 만큼 NLP 모델에 대해서도 더 공부를 해볼 필요가 있음을 느꼈습니다.

또한, CLIP 논문에 디테일한 부분 (Temperature, Target이 저런 형태인 이유)에 대해서는 아직 공부하지 않았으니 더 공부해 볼 필요가 있습니다.

마지막으로, CLIP에 대한 개념을 알게 되었으니 원래 목적이었던 LLaVA-Med에서 어떻게 적용되었는가를 살펴보아야 합니다.


✅ Reference

[1] Simple Implementation of OpenAI CLIP model: A Tutorial

[2] Learning Transferable Visual Models From Natural Language Supervision