Forget one-shot RAG; CoRAG brings iterative search to language models for deeper reasoning.
CoRAG teaches LLMs to ask again and answer better.
Conventional RAG struggles with complex queries because it retrieves only once before generating. This paper introduces CoRAG, which trains LLMs on intermediate retrieval chains generated via rejection sampling, enabling step-by-step retrieval and reasoning and improving performance on multi-hop QA tasks.
-----
Paper - https://arxiv.org/abs/2501.14342
Original Problem 🤔:
→ Traditional Retrieval Augmented Generation (RAG) systems perform retrieval only once, before generation begins (see the sketch after this list).
→ This single retrieval step is often insufficient for complex queries.
→ Retrieval models struggle to capture all necessary information in one go, especially for multi-hop reasoning.
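For contrast, here is a minimal sketch of that one-shot pipeline. The `retrieve` and `generate` callables are hypothetical stand-ins for any retriever/LLM pair, not the paper's components:

```python
from typing import Callable, List

def one_shot_rag(
    query: str,
    retrieve: Callable[[str, int], List[str]],  # hypothetical retriever: (query, k) -> docs
    generate: Callable[[str], str],             # hypothetical LLM: prompt -> completion
    k: int = 5,
) -> str:
    """Retrieve once, then generate: the single-step baseline."""
    docs = retrieve(query, k)  # retrieval happens exactly once, before any reasoning
    prompt = "\n\n".join(docs) + f"\n\nQuestion: {query}\nAnswer:"
    return generate(prompt)    # the model cannot issue follow-up searches mid-reasoning
```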
-----
Solution in this Paper 💡:
→ The paper proposes Chain-of-Retrieval Augmented Generation (CoRAG).
→ CoRAG enables LLMs to retrieve and reason over multiple steps instead of committing to a single retrieval.
→ It uses rejection sampling to create intermediate retrieval chains from existing RAG datasets.
→ These chains consist of sub-queries and sub-answers, generated by an LLM.
→ CoRAG trains an LLM to predict the next sub-query, sub-answer, and final answer.
→ At test time, decoding strategies such as greedy decoding, best-of-N sampling, and tree search control how much compute is spent (a greedy-decoding sketch follows this list).
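Putting those pieces together, greedy decoding at test time might look roughly like this sketch. The prompt templates, the `model` and `retrieve` interfaces, and the stopping convention are assumptions for illustration, not the paper's exact formats:

```python
from typing import Callable, List

def corag_greedy(
    question: str,
    model: Callable[[str], str],            # fine-tuned LLM: prompt -> completion
    retrieve: Callable[[str, int], List[str]],
    max_steps: int = 6,                     # chain length L: the main compute knob
    k: int = 5,
) -> str:
    """Greedily decode one retrieval chain, then answer.

    Each step: predict a sub-query, retrieve for it, predict a sub-answer,
    and append both to the chain that conditions the next step.
    """
    chain = ""
    for _ in range(max_steps):
        sub_query = model(f"Question: {question}\n{chain}Next sub-query:")
        docs = retrieve(sub_query, k)       # a fresh retrieval at every step
        context = "\n".join(docs)
        sub_answer = model(
            f"Question: {question}\n{chain}"
            f"Sub-query: {sub_query}\nRetrieved:\n{context}\nSub-answer:"
        )
        chain += f"Sub-query: {sub_query}\nSub-answer: {sub_answer}\n"
        if "enough information" in sub_answer.lower():  # assumed stop signal
            break
    return model(f"Question: {question}\n{chain}Final answer:")
```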
-----
Key Insights from this Paper 🧐:
→ Iterative retrieval improves performance on multi-hop questions.
→ Test-time compute can be scaled by adjusting the retrieval chain length and the number of sampled chains (sketched after this list).
→ Performance scales log-linearly with test-time token consumption.
→ CoRAG can decompose complex queries and reformulate queries dynamically.
→ CoRAG is robust to varying retriever quality.
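The two compute knobs above, chain length L and the number of sampled chains N, compose naturally in best-of-N decoding. A minimal sketch, where `sample_chain` is a hypothetical wrapper around temperature-sampled chain generation and the scoring rule is an assumption:

```python
from typing import Callable, Tuple

def best_of_n(
    question: str,
    sample_chain: Callable[[str, int], Tuple[str, float]],  # (answer, chain score)
    chain_length: int = 10,
    n: int = 8,
) -> str:
    """Sample N retrieval chains and keep the answer from the best-scored one.

    `sample_chain` is a hypothetical wrapper that decodes one chain with
    temperature sampling and returns its final answer plus a score (e.g. the
    chain's likelihood under the model; the exact rule is an assumption here).
    """
    candidates = [sample_chain(question, chain_length) for _ in range(n)]
    best_answer, _ = max(candidates, key=lambda c: c[1])
    return best_answer
```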
-----
Results 🚀:
→ CoRAG-8B achieves a more than 10-point EM improvement over baselines on multi-hop QA.
→ CoRAG-8B sets a new state of the art on the KILT hidden test set across tasks.
→ On 2WikiMultihopQA, CoRAG-8B (L=10, best-of-8) reaches 72.5% EM and 77.3% F1.