One retrieval is never enough; CoRAG teaches LLMs to ask again and answer better.
Conventional RAG struggles with complex queries because it retrieves only once. This paper introduces CoRAG, which enables step-by-step retrieval and reasoning by training LLMs on intermediate retrieval chains generated via rejection sampling, yielding large gains on multi-hop QA tasks.
-----
Paper - https://arxiv.org/abs/2501.14342
Original Problem 🤔:
→ Traditional Retrieval Augmented Generation (RAG) systems perform retrieval only once before generation.
→ This single retrieval step is often insufficient for complex queries.
→ Retrieval models struggle to capture all the necessary information in a single pass, especially for multi-hop reasoning (a minimal sketch of this single-shot pipeline follows).
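For contrast with CoRAG, here is a minimal sketch of the single-retrieval pipeline described above. `retrieve` and `generate` are hypothetical stubs standing in for a real retriever and LLM, not the paper's API:

```python
def retrieve(query: str, k: int = 5) -> list[str]:
    """Stub: return the top-k passages for a query."""
    return [f"passage {i} for '{query}'" for i in range(k)]

def generate(prompt: str) -> str:
    """Stub: return an LLM completion for the prompt."""
    return "final answer"

def single_shot_rag(question: str) -> str:
    # Exactly one retrieval before generation: every fact a multi-hop
    # question needs must surface in this single pass.
    passages = retrieve(question)
    prompt = "\n".join(passages) + f"\n\nQuestion: {question}\nAnswer:"
    return generate(prompt)

print(single_shot_rag("Who directed the film whose lead actor was born in Lyon?"))
```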
-----
Solution in this Paper 💡:
→ The paper proposes Chain-of-Retrieval Augmented Generation (CoRAG).
→ CoRAG enables LLMs to retrieve and reason in steps.
→ It uses rejection sampling to create intermediate retrieval chains from existing RAG datasets.
→ These chains consist of sub-queries and sub-answers, generated by an LLM.
→ CoRAG trains an LLM to predict the next sub-query, sub-answer, and final answer.
→ At test time, decoding strategies such as greedy decoding, best-of-N sampling, and tree search trade off compute against accuracy (sketched in code below).
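Putting the pieces together, here is a hedged Python sketch of that test-time loop. `retrieve` and `llm` are hypothetical stubs, and the stop condition and best-of-N scorer are simplified stand-ins for the paper's learned stopping behavior and chain-scoring penalty:

```python
import random

def retrieve(query: str, k: int = 5) -> list[str]:
    """Stub retriever: return the top-k passages for a query."""
    return [f"passage {i} for '{query}'" for i in range(k)]

def llm(prompt: str, temperature: float = 0.0) -> str:
    """Stub LLM call; occasionally decides to stop generating sub-queries."""
    if "Next sub-query" in prompt and random.random() < 0.3:
        return "STOP"
    return "model output"

def corag_chain(question: str, max_hops: int = 10,
                temperature: float = 0.0) -> tuple[list, str]:
    """Chain-of-retrieval: alternate sub-query -> retrieval -> sub-answer
    until the model signals it is done or the hop budget (L) is exhausted,
    then produce the final answer conditioned on the whole chain."""
    chain = []
    for _ in range(max_hops):
        sub_q = llm(f"Question: {question}\nChain so far: {chain}\n"
                    f"Next sub-query:", temperature)
        if sub_q == "STOP":
            break
        sub_a = llm(f"Sub-query: {sub_q}\nPassages: {retrieve(sub_q)}\n"
                    f"Sub-answer:", temperature)
        chain.append((sub_q, sub_a))
    final = llm(f"Question: {question}\nChain: {chain}\n"
                f"Passages: {retrieve(question)}\nFinal answer:")
    return chain, final

def best_of_n(question: str, n: int = 8) -> str:
    """Best-of-N: sample N chains at nonzero temperature and keep the one
    the scorer prefers. The random scorer here is a stand-in for the
    paper's retrieval-confidence penalty over each chain."""
    candidates = [corag_chain(question, temperature=0.7) for _ in range(n)]
    best_chain, best_answer = max(candidates, key=lambda c: random.random())
    return best_answer

print(best_of_n("Who directed the film whose lead actor was born in Lyon?"))
```

Greedy decoding is the `temperature=0.0` path through `corag_chain`; best-of-N and tree search spend more tokens exploring alternative chains, which is how test-time compute becomes a tunable knob.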
-----
Key Insights from this Paper 🧐:
→ Iterative retrieval improves performance on multi-hop questions.
→ Test-time compute can be scaled by adjusting the retrieval chain length (L) and the number of sampled chains (N).
→ Performance scales log-linearly with test-time token consumption (see the note after this list).
→ CoRAG can decompose complex queries and reformulate queries dynamically.
→ CoRAG is robust to varying retriever quality.
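On the log-linear point above: as a rough reading (our paraphrase, not a fitted formula from the paper), accuracy behaves like EM ≈ a + b · log(total decoded tokens) for task-dependent constants a and b. In other words, each doubling of the test-time token budget, whether via longer chains or more sampled chains, buys a roughly constant EM gain until the trend flattens.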
-----
Results 🚀:
→ CoRAG-8B improves EM by more than 10 points on multi-hop QA benchmarks compared to baselines.
→ CoRAG-8B sets a new state-of-the-art on the KILT hidden test set across a diverse range of tasks.
→ On 2WikiMultihopQA, CoRAG-8B (L=10, best-of-8) reaches 72.5% EM and 77.3% F1.