오늘은 순서상 조금 늦었지만, REALM이라는 논문에 대해서 정리해보고자 한다.
사실 RAG를 공부하다 보면서 보게 되는 논문이고, 시간으로 따지면 DPR 보다 이전에 나온 논문이긴 하다..
조금은 쉽고 간단하게 본 논문을 정리해보고자 한다.
https://arxiv.org/abs/2002.08909
REALM: Retrieval-Augmented Language Model Pre-Training
Language model pre-training has been shown to capture a surprising amount of world knowledge, crucial for NLP tasks such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network, requiring ever-larger networ
arxiv.org
1. Introduction
BERT, GPT 계열의 LM 모델들은 pre-training으로 QA와 같은 NLP task 영역에서 많은 양의 world knowledge를 포착한다. 하지만 학습 과정에서 명확히 어디에 어떤 정보를 저장하는지 알 수 없을뿐더러 neural network의 parameter에 모든 세계의 지식을 담는 데는 한계가 있다.
심지어 이러한 지식은 시간이 지나면서 바므로, 모델을 다시 학습시키는 것 자체가 비용과 속도 측면에서 매우 비효율적이다.
따라서, pret-train, fine-tuning 및 inference 중에 사용되는 위키피디아와 같은 지식을 두고 모델 외부에 두고 필요할 때 검색하는 Retrieval-Augmented LM이 주목받기 시작했다.
본 논문에서는 LM pre-training 자체에 retrieval을 끼워넣어 retriever를 unsupervised 신호로 학습하는 방법론을 소개한다.
2. REALM의 핵심 아이디어
Pre-train과 Fine-tuning에서 input이 주어질 때의 답의 분포를 학습하는 것을 목표로 한다.
(pre-train의 경우 MLM task를 수행하고, Fine-tuning 시에는 Open-QA task를 수행한다.)
수식으로 표현하자면 p(y|x)를 구하는 것을 목표한다.
즉, input으로 주어졌을 때, model의 vocab에 대한 분포를 학습하는 것이다.
REALM은 크게 두 부분으로 구성되며, p(z|x)를 구하는 Neural Knowlege Retriever와 p(y|z, x)를 구하는 knowlege-augemented encoder가 있다.
p(y|x)를 구하는 과정을 retrieve, predict 두 단계로 나누어 수행한다.
- REALM's generative process
- 쿼리(질문)와 문서(지식)를 같은 임베딩 공간으로 매핑해 벡터 검색 수행
- BM25 같은 전통 IR이 아니라 딥러닝 기반 DPR-style dense retrieval 사용
- 먼저, input x가 주어지면 knowledge corpus z로 부터 document z를 가져온다 → Retrieve 과정이자 p(z|x)를 구하는 과정
- 이후, retrieve 된 document z와 x를 입력하여 최종 output인 y를 산출한다.
- 즉, p(y|x)의 분포를 구하기 위해 z에 대해 marginalize 하여 p(y|x)를 산출한다.
- knowledge coupus z는 여러 document로 이루어져 있고, 결국 p(z|x)의 분포는 x가 주어졌을 때, z가 출현할 확률의 모임이고, p(y|z, x)의 경우에도 x와 z가 주어졌을 때, 각 token y가 출현할 확률인 것이다. 따라서, model input으로 x를 넣어서 output으로 y가 나오는 것은 결국 특정 document가 선택된 사건과 y가 선택될 사건이 동시에 발생하는 것이다.
- 즉, y를 구하기 위해 document z에 대해 구해진 곱사건의 확률을 모두 더하는 (marginalize) 과정을 거치게 된다.

- Neural Knowlege Retriever
- x가 주어졌을 때, document z에 대한 p(z|x)를 modeling 한다.
- 각 input x와 document z는 특정 d 차원으로 embedding 한 벡터고, 이러한 벡터는 BERT를 활용해서 BERT의 input으로 넣은 다음에 [CLS] token 위치의 output 값을 추출해서 inner product로 relevance score를 계산한다.
→ 그렇게 해서 최종적으로 p(z|x)를 구해주게 된다.


- Knowlege-augmented Encoder
- input x와 document z가 주어졌을 때, output y에 관한 확률인 p(y|z, x)를 모델링한다.
- 즉, x와 z를 input으로 받아 y를 산출해내는 아키텍처이다.
- 다만, pre-training시와 Fine-tuning 시의 작동 방식은 살짝 다르다.
- pre-training 시에는 MLM task를 수행하기 때문에 [mask] token 위치의 원래 token을 예측하게 된다. 그런데 이 mask를 예측하는 방식은 앞서 설명했던 방식을 거쳐서 구하게 된다.
조금 더 쉽게 풀어서 설명하자면, [mask] 부분을 예측하는걸 retriever가 찾아온 document z로 하게 된다. 그렇기 때문에 따로 DPR처럼 정답 데이터가 필요없기 때문에, unsupervised task라고 하는 것!
→ Retrieval은 LM이 스스로 외부에서 필요한 정보를 찾는 행위이고, Learning은 그 검색을 통해 MLM 성능을 높이는 행위기 때문에, LM이 단어를 예측할 때, "어떤 문서를 참고해야 가장 잘 맞출 수 있는지"를 retriever가 점점 배워갈 수 있어 Retreiver도 함께 학습이 된다는 의미!
즉, retriever는 LM을 보조하는 수단이자 LM의 pre-training 과정에 통합시켜 unsupervised 학습을 진행한 것.
- Fine-tuning 시에는 Open-QA task를 수행한다. 단, REALM은 BERT 기반의 model이기 때문에 특정 corpus안에서 알맞은 부분만 추출하는 방식으로 Open-QA task를 수행하게 된다.
- 그런데 이때, 모든 document에 대한 확률을 더하게 되면 계산량이 많아지기 때문에 top-k개의 document에 대해서만 summation함.
- 이때, top-k개의 document를 어떻게 찾을 것인가 하는 부분에 대해서는 MIPS(Maximize Inner Product Search)라는 알고리즘을 활용했다.

* 이외에도 논문에서는 REALM을 훈련시킬 때, retrieval이 잘 학습되게 하기 위해서 추가적인 방법들을 활용함.
- Salient span masking
- Null document
- Prohibiting trivial retrievals
- Initialization
3. Contribution
- Retrieval과 LM을 동시에 학습 → Retriever가 downstream task에 최적화됨
- 지식 업데이트 유연성
- 외부 corpus만 교체하면 되므로, 모델을 다시 pretrain 하지 않아도 새로운 지식 반영 가능
- Knowledge-intensive task에서 강력한 성능
- Open-domain QA에서 SOTA 달성 (2020 기준)
- 특히, T5와 비교했을 때, parameter의 수에서도 차이를 보이지만 RELAM이 성능이 더 우수함

4. Conclusion & Limit
- 단순히 Retrieval을 inference에만 쓰는 게 아니라, Pre-training 단계에서 Retrieval을 끼워넣어 함께 학습한다는 점이 핵심
→ DPR 방식처럼 Q-A쌍의 데이터를 주는 것이 아니라 LM이 잘 예측하는 방식으로 학습하는 unsupervised 방식을 채택 - 즉, LM이 Masked LM을 풀 때 외부 지식(위키피디아)을 찾아오도록 설계
→ Retrieval과 Language Model이 End-to-End 공동 학습됨. - Input: 질문 + context (retrieved docs)
- Retriever: 쿼리를 인코딩 → 전체 위키피디아 문서 중 관련 문서 Top-K 검색
- Reader: 검색된 문서를 참고해 MLM 또는 QA 태스크 수행
- Retrieval 과정이 end-to-end라 학습/추론 속도가 느림
- Index가 커질수록 retriever 관리 비용 ↑
- FiD(Fusion-in-Decoder) 같은 후속 연구에 비해 retrieved document 활용 방식이 단순 (Concat만 함)
REALM은 단순히 "검색 붙인 LM"이 아니라, 검색과 언어모델을 end-to-end로 학습한 첫 시도 중 하나라는 점에서 의의가 큼.
오늘날 ChatGPT+검색, LLM+RAG 시스템의 뿌리 중 하나로, 외부 지식을 참고하는 학습 패러다임을 정립했다는 점에서 여전히 중요한 milestone이다.
오늘은 이렇게 해서 REALM 논문에 대해서 정리를 해보았다.
오늘날 RAG의 발전에 영향을 끼친 여러 논문들을 계속해서 정리할 예정이라, 추후 타 논문들과의 차이점도 같이 작성해 볼 예정이다.
Reference
(논문을 읽고 이해하는 과정은 다음 블로그 글을 참고했습니다 ㅎ 감사합니다..)
[논문 리뷰] REALM: Retrieval-Augmented Language Model Pre-Training
이번 게시물에서는 최초로 retrieval와 language model을 같이 pre-training을 진행한 REALM을 제안한 논문인 REALM: Retrieval-Augmented Language Model Pre-Training에 대해 다뤄보겠다. 원문 링크는 아래와 같다. REALM: Ret
gbdai.tistory.com