Paper Review

[Paper Review] Altas : Few-shot Learning with Retrieval Augmented Lagnguage Mode

cherie-ssom 2025. 10. 24. 11:01

오늘은 다양한 RAG 관련 논문에 이어 Altas 논문에 대해서 정리를 해보고자 한다.

1. 서론

대규모 언어 모델은 다양한 작업에서 few-shot 성능을 보이긴 했지만, Q&A와 같은 사실 확인과 같은 지식 집중형 작업에서는 막대한 크기의 파라미터와 사전 학습 데이터의 규모가 커져야 좋은 성능을 냈다.  

이는 다시 말해서 few-shot 성능은 모델의 파라미터 크기와 학습 데이터 크기와 관련이 있다고 말할 수 있다. 하지만 few-shot의 성능은 단순히 파라미터 내부에 암기한 결과와 동일하진 않다. 즉, 모델의 파라미터가 커야만 few-shot의 성능이 나오는 것인지, 혹은 더 효율적인 방법이 있는지에 대한 의문이 남아있기 때문에 얼마나 많은 파라미터와 데이터가 필요한지에 대한 명확한 기준이나 효율적인 정도가 불분명하다는 것이다.

따라서, 해당 논문에서는 파라미터를 무작정 늘려 지식을 암기하는 llm의 한계를 극복하고, 외부 지식에 접근하여 효율적으로 최신 지식을 활용할 수 있는   즉, non-parametric knowledge를 이용하기 위한 새로운 retrieval-augmented architecture를 제안하여 더 적은 파라미터를 가졌음에도 불구하고 뛰어난 few-shot 학습 능력을 보여주는 것을 목표로 한다.

 

2. Model Architecture 

그렇다면 이제 모델 구조와 학습 방식에 대해서 한번 살펴보도록 하자.
Altas는 크게 retriever(검색기)와 언어모델(Reader)의 두 가지로 이루어져 있다.

Retriever

먼저, retriever는 dual-encoder 아키텍처를 사용한 contriever를 기반으로 하는 검색 모듈이다.

먼저, dual-encoder architecture란 query와 document를 서로 다른 encoder를 통해 독립적으로 임베딩을 진행하는 방식이다. 그리고 average pooling을 거쳐서 최종 임베딩이 나온다. 즉, 서로 다른 encoder를 통해 임베딩을 진행하지만 최종적으로 같은 임베딩 공간에서 유사도를 계산하는 구조이다. (query와 document를 각각 임베딩한 후에 각각의 유사도를 같은 공간에서 구하는 구조) 이때, query와 document의 유사도는 dot-product를 통해 얻어진다.

(추가적으로 설명을 하자면, 이 방식의 장점은 모든 문서를 미리 임베딩해서 벡터 DB에 저장을 해두고 검색 시 query 하나만 인코딩을 하면 되기 때문에, 수억 개의 문서에 대해서도 빠른 검색이 가능하고, 서로 다른 목적으로 각각의 encoder를 fine-tuning 할 수 있기 때문에 retriever를 설계할 때, 대부분 dual-encoder 구조를 쓴다.)

그렇다면 이제 본격적으로 contriever에 대해서 알아보자.
이 contriever는 meta의 "unsupervised dense inforamtion retrieval with contrastive learning"이라는 논문에서 처음 나온 개념이다.

기존의 DPR은 supervised 학습이었다. 즉, QA 데이터 셋에 있는 question, positive passage 쌍을 이용했다. 하지만, 이런 데이터는 사실상 많이 존재하지가 않는다. 따라서, contriever는 unsupervised 방식으로 retriever를 학습하고자 했다. 다시 말해, 라벨이 없는 일반 문서 데이터만 가지고 학습이 가능하게 했다는 방식이다. 그래서 이렇게 라벨이 없는 데이터를 활용해서 contrastive learning을 진행하여 비슷한 의미의 문장은 embedding space에서 가깝게, 다른 문장은 멀게 학습을 진행했다. 즉, supervised QA 데이터 쌍이 없이도 semantic retrieval 능력을 학습하게 했다.

정리하자면, retriever를 단독으로 학습하여 (pre-training 하여) 자체의 퀄리티를 개선했다.

(즉, Contriever는 retrieval을 더 잘 학습해서 LM을 돕는 것, 반면 REAML은 LM을 retrieval을 이용해서 더 잘 학습하게 하는 것)


LLM / Reader

llm의 경우에는 T5 기반의 sequence-to-sequence 모델을 사용하며, fusion-in-decoder (FiD) 아키텍처를 활용한다.

즉, 검색된 k개의 문서를 각각 독립적으로 인코더에 입력하고, 문서별로 독립처리한 후에 marginalize 하는 것이 아니라 문서별 encoder ouput을 합친상태(concat)에서 decoder로 넘겨주어 cross-attention을 활용하게 한다. 즉, 각 문서의 인코더 출력과 질문을 디코더가 함께 처리하여 최종 답변을 생성한다.
(앞서 FiD에 대한 논문도 정리한 바가 있는데 참고하길 바란다.)


모델 학습 방법

위에서 Altas 논문에서 제안한 모델의 아키텍처와 각각의 개념이 무엇인지에 대해서 살펴보았다. 

 

그렇다면 이러한 모델은 어떤 방법을 통해서 학습을 진행할까?

 

Altas 논문에서는 단순히 retriever + generator를 붙였다가 아니라 이걸 어떻게 학습해서 (pre-train + fine-tune) 해서 few-show 성능을 끌어올렸는가가 핵심이다. 그 방법은 바로 end-to-end 학습이다.

 

Altas는 학습된 contriever를 기반으로 문서를 검색하고 T5 기반 LM을 fine-tining 하면서 retrieval 결과를 few-shot example처럼 사용하는 end-to-end 학습 구조를 제시했다. 그렇기 때문에, retriever와 lm 모델의 loss가 무엇인지, 이를 어떻게 학습하여야 end-to-end로 학습되는지 알아보고자 한다.

 

Training for LM

먼저, LM에 대한 학습을 살펴보면 contriever로 검색한 결과를 이용해서 T5를 학습하는 단계가 있는데 이는 말 그대로 언어 모델의 생성 능력을 학습하는 것이다. (즉, 검색된 문서를 조건으로 하여, T5의 text-to-text generation objective를 통해 정답 시퀀스를 생성하도록 학습한다.) 이때의 generative loss는 cross-entropy loss라고  할 수 있다.

Training objectives for retriever

그리고 검색기인 retriever 학습의 경우 contrastive learning을 통해 나온 contrastive loss가 있고, Retriever loss의 경우 generator가 생성한 log-likelihood를 바탕으로 유용했던 문서를 좀 더 높은 점수로 ranking 하도록 retriever를 미세조정 했다. 즉, llm이 실제로 읽어서 중요한 문서에 attention 준 걸 retriever가 모방하도록 distillation 한다. 이 학습을 통해서 altras가 few-shot에서 강력해진다. 

 

(Retriever가 top-k개 문서를 검색 → T5가 문서를 읽으며 cross-attention을 진행 → 어떤 문서에 attention을 많이 줬는지 드러나고 → retriever의 ranking score를 attention이 높으면 점수를 많이 주고 낮으면 낮게 되도록 학습을 진행했다. )


논문에서는 4가지 종류의 distillation loss를 설명하고 있다. 

하지만 논문에서 명시적으로 사용한 loss는 4가지 중에서도 ADist Loss (Attention Distillation) 를 활용했다.

 

Jointly pre-train

 

그래서 앞서 LM이 학습하는 방식과 Retriever가 학습하는 방식 두 가지가 서로 분리되어 진행되는 것이 아니라 상호작용하면서 학습이 된다. 즉, Retriever는 LM이 실제로 참고한 문서에 높은 점수를 주도록 distillation을 통해 조정되고, LM은 Retriever가 제공한 문서를 조건으로 생성 학습을 수행한다. 이 상호작용 구조 자체를 사전학습 단계에서 최적화함으로써, Atlas는 별도의 많은 파라미터 확장 없이도 Few-shot 환경에서 강력한 성능을 보이게 된다.

그래서 최종적으로 Contrastive loss + Distillation loss + Generation loss 이 세 가지 loss들을 가중합 하여 전체 모델을 공동 학습을 시킨다.

 

여기서 왜 갑자기 contrastive loss가 나오냐고 생각할 수 있는데, 기본적으로 retriever 자체를 contriever라는 걸 사용하고 있고, 이 과정에서 contrastive learning이 일어나고 있기 때문에 함께 표현된다고 생각하면 된다. 


정리하자면, 이러한 공동 학습을 통해 Altas는 llm의 내재적 지식과 retriever의 외부 지식 활용 능력을 유기적으로 통합하여 few-shot 환경에서 기존 llm보다 적은 매개변수로도 높은 성능을 달성할 수 있었다.

Pretext tasks

retriever와 language model을 jointly pre-train할 때 사용한 pretext task가 있다.

Altas의 pre-train은 언어모델이 retrieval 결과를 실제로 생성 태스크에서 활용하는 법을 미리 배우도록 하는 사전학습의 일종이지만, 단순 MLM이나 BERT 식의 언어모델링이 아니라 retriever와 language model이 함께 쓰이는 태스크를 사전학습으로 정의했다.

 

1) Prefix Language Modeling

  • 긴 텍스트 조각을 앞뒤로 분할해서 앞부분을 query, 뒷부분을 target generation으로 삼아 학습
  • query를 기반으로 대응 텍스트를 생성하면서 필요한 문서를 참조하는 법을 배우게 함
  • text-to-text 형식으로 language modeling을 수행할 수 있도록 함

2) Masked Language Modeling

  • T5처럼 입력 텍스트의 일부 토큰/스팬을 mask
  • mask된 토큰을 예측하는 과제를 수행
  • 하지만 일반 T5와는 달리 retrieval 모듈을 통해 관련 문서들을 가져오고 그 문서들을 context로 같이 보면서 마스크를 풀도록 학습
  • 실험 결과 이 부분이 가장 효과적인 pretext task로 나타남

3) Title-to-section / Wiki Generation

  • 위키피디아 문서의 제목을 가지고 section 본문을 생성하는 태스크
  • 문서의 query로 전체 문서를 생성하는 task
  • 해당 과정에서도 retirever와 llm이 상호작용하도록 설계

3. Experiment & Results

Altas 논문은 여러 downstream task에서 few-shot 효과를 검증하고 있다. 

크게 3가지로 나눌 수 있을 것 같은데, joint pre-training이 정말로 few-shot 성능을 올리는가?

그리고 retriever-LM 공동 학습에서 무엇이 중요한지, retriever가 보는 index(외부 지식)의 성질이 성능에 어떤 영향을 주는지를 확인했다. 

 

즉, 이 Altas에 대한 성능을 어떻게 검증하고 있는가에 대한 부분을 정리해 보겠다.

하지만 그전에 모델의 성능을 높여주기 위해 진행했던 pre-training에서 pretext task의 성능에 대해서 간략하게 먼저 설명하자면, retrieval-augmented pre-training이 few-shot performance에 중요하며, closed-book baseline을 크게 능가했습니다. masked language modelling이 다른 pretext tasks에 비해 약간의 이점을 보였다. (해당 부분이 논문의 Pre-training Loss & Tasks 부분에 작성되어 있는데, 순서상 다른 실험 결과보다 먼저 설명하는 것이 조금 더 이해하기가 쉬워서 앞서 간략하게 설명하고자 했다.

 

1) Few-shot Learning 성능 평가

 

그렇다면 이제 Altas가 실제로 잘 되는지를 확인하기 위해서 다양한 실험을 진행했다.
NaturalQuestions, TriviaQA, FEVER 등을 64-shot, 1024-shot 같은 적은 labeled examples로 fine-tuning
Atlas-11B 모델이 64개 예제만으로도 NaturalQuestions에서 42% 이상의 정확도를 달성
→ PaLM 540B 같은 대형 모델과 비교해서도 높은 성능을 보임
 
2) 몇 가지 벤치마크 테스트

다음과 같은 다양한 knowledge-intensive task들을 실험함:

MMLU (Multi-task Language Understanding)
KILT 작업 모음 (QA, fact checking 등)
WEBQuestions, etc.
→ Atlas는 대부분의 경우 few-shot 기준에서 strong performance를 보였음.

 

여기서, 1)과 2)는 같은 실험 축으로 Altas가 잘 되는가? 에 대한 실험을 진행한 부분이다. 

3) Index Content Influence 실험


해당 실험에서는 retrieval index에 어떤 데이터를 넣느냐에 따라 성능이 달라지는지를 실험했다. 

즉, Retriever가 무엇을 검색하느냐가 얼마나 중요한가에 대해서 살펴보는 실험을 진행했다.

Wikipedia + Common Crawl 등을 섞어서 index를 구성하며,
이 문서 인덱스를 업데이트함으로써 retrieval quality가 향상되는 것을 확인

4) Fine-tuning 전략 실험 (Retriever)

그렇다면, 1) 번과 2) 번 실험에 이어서 Altas가 왜 잘되는지, 무엇이 중요한 건지를 밝혀내기 위해서 논문은 retriever를 fine-tune 하는 여러 방식도 비교합니다:

Full index update – 전체 문서 임베딩을 재계산
Re-ranking – 기존 임베딩에서 후보를 뽑고 재정렬

Query-side fine-tuning – document encoder를 freeze 하고 query encoder만 갱신

이때, retriever를 고정하는 것은 성능 저하를 초래했다. 
→ few-shot 설정에서는 query-side fine-tuning이 full index update와 유사한 성능을 보였다. 즉, 효율적이고 좋은 성능을 보여줌

 

4. Conclusion

그래서 정리하자면 단순히 pre-trained 된 llm은 few-shot 성능이 떨어지는데, contriever 즉, retriever 자체를 contrastive learning으로 매우 강력하게 학습시켜 둔 dense retriever를 이용하고 이를 T5 기반 LM 앞단에 붙여 retrieval-augmented few-shot learning을 구현하여 parametric memory가 아니라 external memory를 통해 reasoning 할 수 있게 되어서, few-shot 상황에서도 답변의 성능을 올릴 수 있다는 것이다.

되게 어렵게 말하는 것처럼 보이는데, retrieval 자체가 일종의 fine-tuning 된 걸 쓰고 있으며, 이를 llm의 답변에 활용할 수 있게 구조화했다. 그리고 그러한 모델 구조 자체를 상호작용하며 pre-training 할 수 있게 만들어 llm 답변의 few-shot 성능을 올리고 있다고 이해하면 될 것 같다.

 

처음으로 retriever-augmented 구조를 pre-train 했다는 것에 의의가 있는 논문이다.