When LLMs take their time, patients get better answers
This paper shows that giving LLMs more inference-time compute improves medical reasoning, yielding 6-11% accuracy gains through journey learning and extended thought processes.
-----
https://arxiv.org/abs/2501.06458
Original Problem 🤔:
LLMs struggle with complex medical reasoning tasks such as diagnostics and treatment planning. Traditional scaling approaches, like increasing model size or training data, have limitations, so a different axis of scaling is needed.
-----
Solution in this Paper 💡:
→ Introduces inference-time scaling through journey learning: giving LLMs more processing time on complex medical tasks
→ Uses two key methods: LongStep and LongMonolog for generating extended reasoning chains
→ Synthesizes training data from 500 samples across MedQA and JAMA Clinical Challenges
→ Implements majority voting combined with journey learning to enhance model predictions
→ Employs knowledge distillation from GPT-4 to create high-quality demonstration data
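The core inference-time recipe above (sample several extended reasoning chains, then aggregate with majority voting) can be sketched roughly as follows. This is a minimal illustration, not the paper's implementation: `extract_answer`, its `Answer:` format, and the stubbed chains are all hypothetical.

```python
from collections import Counter

def extract_answer(chain: str) -> str:
    """Toy extractor: assumes each chain ends with 'Answer: X' (hypothetical format)."""
    return chain.rsplit("Answer:", 1)[-1].strip()

def majority_vote(answers: list[str]) -> str:
    """Return the most common final answer across sampled reasoning chains."""
    return Counter(answers).most_common(1)[0][0]

# Stand-ins for long reasoning chains a model might produce in LongStep /
# LongMonolog style; real chains would come from sampled model generations.
chains = [
    "Step 1: note the presenting symptoms ... Step 5: rule out sepsis. Answer: B",
    "Internal monolog: the labs point away from option A ... Answer: B",
    "Considering the differential diagnosis ... Answer: C",
]

final = majority_vote([extract_answer(c) for c in chains])
print(final)  # → B
```

Spending compute on several long chains and voting trades latency for accuracy, which is the trade-off the paper studies on medical benchmarks.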
-----
Key Insights 🔍:
→ Harder medical tasks require longer reasoning chains and more inference time
→ Model size directly impacts inference-time scaling effectiveness
→ Journey learning outperforms simple majority voting approaches
→ Base model capability is crucial for successful inference-time scaling
-----
Results 📊:
→ 6-11% accuracy improvement with 500 training samples
→ Qwen2.5-72B achieves 77.18% accuracy with LongMonolog
→ Average response length grows from ~400 to ~1000 tokens with journey learning
→ Larger gains on complex JAMA Clinical Challenge cases than on simpler MedQA questions