LLM & LMM

[PyTorch] Qwen-VL model을 이용한 이미지 캡셔닝 (Image Captioning)

cherie-ssom 2025. 11. 26. 14:45

오늘은 오픈 소스 모델 중에서도 비전-언어 이해와 생성 작업에 탁월한 Qwen2.5-VL 모델을 활용해서 이미지 캡셔닝 작업을 진행해 보았다.

앞서 설명했던 다국어 임베딩 능력을 위해서 지식 증류 학습을 진행했던 CLIP 모델에 또 다른(나의 데이터 셋에 적합한) 이미지-텍스트 데이터쌍을 추가로 학습시켜 contrastive learning을 진행하려고 했다. 그런데 그 이전에 필요한 데이터 셋이 필요한데, 이를 만들어주기 위해서 다양한 방법을 시도했고, 그중 하나인 Qwen2.5-VL을 활용해서 이미지를 텍스트로 캡셔닝하는 방법을 짧게 공유해보고자 한다.

1. 필요한 라이브러리 및 모델 로딩

import os, json
import pandas as pd
from tqdm import tqdm

import boto3
import pyarrow.parquet as pymysql
from dotenv import load_dotenv
from io import BytesIO
import io
from PIL import Image

import gc
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
warnings.filterwarnings("ignore")


# 1. 모델 로드 (양자화 없음)
print("모델 및 토크나이저 로딩 중..")
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
tokenizer = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
print("모델 로딩 완료.")

모델은 허깅페이스에 사용하는 방식이 간단히 설명이 되어있기 때문에 참고해서 모델을 불러오면 된다.

 

2. 이미지 로드 및 모델 설정

def load_image_from_s3(bucket, key):
    """S3에서 이미지 로드"""
    try:
        obj = s3_client.get_object(Bucket=bucket, Key=key)
        img_bytes = obj["Body"].read()
        img = Image.open(BytesIO(img_bytes)).convert("RGB")
        return img
    except Exception as e:
        print(f"S3 로드 오류 ({key}): {str(e)}")
        return None

def clean_caption(caption):
    """캡션에서 불필요한 토큰 및 아티팩트 제거"""
    if caption is None:
        return None
    
    # 내부 토큰 제거
    caption = caption.replace("addCriterion", "").strip()
    caption = caption.replace("自动生成", "").strip()
    
    # 연속 개행 정리
    while "\n\n\n" in caption:
        caption = caption.replace("\n\n\n", "\n\n")
    
    # 앞뒤 공백/개행 제거
    caption = caption.strip()
    
    return caption if caption else None

def caption_image(image, model, tokenizer):
    """이미지에서 설명과 감성 추출"""
    try:
        # 프롬프트: 간단한 설명 + 감성 분석
        query = """Describe this image briefly in 1-2 sentences. 
                    Then, describe the emotional tone or mood of this image using adjectives like: warm, cold, energetic, calm, melancholic, joyful, mysterious, etc."""
        
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query}
            ]}
        ]
        
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(model.device)
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=False,
            )
        
        # 생성된 토큰만 추출
        generated_ids = output[0][inputs['input_ids'].shape[1]:]
        caption = tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        return caption
    except Exception as e:
        print(f"캡셔닝 오류: {str(e)}")
        traceback.print_exc()
        return None
    finally:
        torch.cuda.empty_cache()

여기부터 조금 복잡해 보일 수 있는데 맨 앞 함수부터 차곡차곡 살펴보겠다. 우선 나의 경우에는 이미지를 s3에서 불러와야 하기 때문에, 맨 앞에서 이미지를 불러오는 함수를 작성했다.
(각자가 가지고 있는 데이터 셋을 알맞게 불러와주면 된다.)

그리고 두 번째는 이미지를 캡셔닝할 때, 불필요한 토큰들이 같이 출력되어서 나오는 경우가 있어서 이런 부분은 삭제하고 출력할 수 있게끔 수정을 했다. 사실 이건 몇 번 샘플로 돌려보고 나서 이런 특수 토큰들이 출력된다는 걸 알게 돼서 추후에 삽입하게 된 함수다. 이제 이렇게 특수 토큰을 삭제하는 함수도 만들어둔다면, 다음에는 진짜로 이미지 캡셔닝을 위한 지시사항인 프롬프트를 입력해 주면 된다. (사실 한국어로 하고 싶은데, 한국어 자체는 잘 인식이 안 되는 듯했다.)

내가 하고 싶었던 건 해당 이미지를 간단히 1-2줄로 설명을 하되, 해당 이미지의 특유의 감성이나 무드 같은 것을 형용사랑 함께 설명해 주면 좋겠다고 생각했다. 다시 말해서 특정 이미지가 가지고 있는 어떤 감성을 라벨링 하고 싶었다. 이렇게 프롬프트를 입력하고 나면 여타 모델을 쓰는 것과 비슷하게 와 같은 함수를 통해서 지정해 준 프롬프트와 형식을 전달해 주면 된다. 그리고 그렇게 생성된 토큰을 다시 디코드 해주면 원하는 이미지를 넣고 이를 설명하는 캡셔닝 데이터를 얻을 수 있다.

 

3. 이미지 캡셔닝 처리

# 2. 배치 처리
bucket_name = 'data'
keys = final_img_path_df['s3_path'].to_list()[:]  # 전체 데이터
batch_size = 100 
output_file = "qwen_emotion_captions.jsonl"

# 기존 결과 확인 (중단됐을 경우 재개)
processed_keys = set()
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            processed_keys.add(data['key'])
    print(f"이전에 처리한 파일: {len(processed_keys)}개 (재개 모드)")
else:
    print("새로 시작합니다")

# 배치 처리
print(f"\n총 {len(keys)}개 이미지 캡셔닝 시작...")
print(f"배치 크기: {batch_size}, 예상 배치 수: {len(keys)//batch_size + 1}")

with open(output_file, 'a', encoding='utf-8') as f:
    for i, key in enumerate(tqdm(keys, desc="이미지 캡셔닝")):
        # 이미 처리한 것은 스킵
        if key in processed_keys:
            continue
        
        # 이미지 로드
        img = load_image_from_s3(bucket_name, key)
        caption = None
        
        if img is not None:
            caption = caption_image(img, model, tokenizer)
            caption = clean_caption(caption)  # 아티팩트 제거
        
        # 결과 저장
        result = {"key": key, "caption": caption}
        f.write(json.dumps(result, ensure_ascii=False) + '\n')
        f.flush()
        
        # 메모리 정리
        if img is not None:
            del img
        gc.collect()
        
        # 배치마다 GPU 캐시 정리
        if (i + 1) % batch_size == 0:
            torch.cuda.empty_cache()
            completed = len(processed_keys) + (i + 1)
            percentage = (completed / len(keys)) * 100
            print(f"\n배치 {(i+1)//batch_size} 완료 ({completed}/{len(keys)}, {percentage:.1f}%), GPU 메모리 정리됨")

print("\n" + "="*50)
print("캡셔닝 완료!")
print("="*50)

# 3. 최종 통계
total_processed = len(processed_keys) + len(keys)
with open(output_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    total_lines = len(lines)
    success_count = sum(1 for line in lines if json.loads(line).get('caption') is not None)

print(f"\n최종 통계:")
print(f"  총 이미지: {len(keys)}개")
print(f"  처리 완료: {total_lines}개")
print(f"  성공: {success_count}개 ({success_count/total_lines*100:.1f}%)")
print(f"  실패: {total_lines - success_count}개")

# 4. 샘플 결과
print("\n=== 샘플 결과 (마지막 3개) ===")
try:
    with open(output_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines[-3:]:
            data = json.loads(line)
            if data['caption']:
                print(f"\n키: {data['key']}")
                print(f"캡션:\n{data['caption']}")
            else:
                print(f"\n키: {data['key']}")
                print(f"캡션: [실패]")
except Exception as e:
    print(f"샘플 읽기 오류: {e}")

# 메모리 정리
del model
torch.cuda.empty_cache()



이렇게 함수가 다 완성이 됐으면, 실제로 데이터를 불러오고 만들어둔 함수를 활용해서 이미지 캡셔닝을 진행하면 된다. 다만, 리소스를 워낙 많이 먹기 때문에 중간중간 각 배치마다 캐시를 삭제해 주는 코드를 작성했다. (그런데 사실 이미지 자체를 데이터로더 같은 거로 병렬로 불러오는 게 아니라서 사실 이 배치는 그저 GPU와 메모리를 초기화해주기 위한 장치로 생각하면 될 것 같다. (이렇게 중간중간 캐시를 초기화해주지 않으면 램이 터져버리게 된다..)

조금 더 효율적인 코드를 짜고 싶었는데, 어떻게 해야 이 이미지 데이터를 더 효율적으로 불러올 수 있을지 고민이 조금 더 필요할 것 같다.. (나의 경우에는 10만 개가 넘는 데이터를 돌리다 보니 일주일이라는 상당한 시간이 소요가 됐다.. 그래서 이 부분에 대해서는 추후에 고도화가 된다면 업데이트를 진행.. 해보고자 한다..) 어짜피 generate 하려면 이미지가 하나하나 들어가야하는 게 맞지만, 그 전에 이미지를 불러오는 방식이나 여러 이미지 자체를 모델에 하는건 가능한 것으로 알고 있기 때문에, 이런 부분들을 수정해준다면 조금은 빨라지지 않을까 생각한다.

오늘은 이렇게 해서 오픈 소스 모델 중에서도 다방면에서 성능이 좋기로 유명한 모델을 활용해서 이미지-텍스트 캡셔닝 작업을 진행해 보자 했다.

오늘도 도움이 되셨다면 좋겠습니다.

감사합니다.