
"O1 Replication Journey -- Part 3: Inference-time Scaling for Medical Reasoning"

A podcast on this paper, generated with Google's Illuminate.

When LLMs take their time, patients get better answers

This paper shows how increasing inference-time compute improves medical reasoning in LLMs, yielding 6-11% accuracy gains through journey learning and extended reasoning chains.

-----

https://arxiv.org/abs/2501.06458

Original Problem 🤔:

LLMs struggle with complex medical reasoning tasks like diagnostics and treatment planning. Traditional scaling methods, such as increasing model size or training data, have limitations, so a better approach is needed.

-----

Solution in this Paper 💡:

→ Introduces inference-time scaling through journey learning, giving LLMs more processing time for complex medical tasks

→ Uses two key methods: LongStep and LongMonolog for generating extended reasoning chains

→ Synthesizes training data from 500 samples across MedQA and JAMA Clinical Challenges

→ Implements majority voting combined with journey learning to enhance model predictions

→ Employs knowledge distillation from GPT-4 to create high-quality demonstration data
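The majority-voting step above can be sketched in a few lines: sample several reasoning chains for the same question, extract each chain's final answer, and return the most common one. This is a minimal illustration, not the paper's implementation; the `majority_vote` helper and the sampled answers are hypothetical.

```python
from collections import Counter

def majority_vote(answers):
    """Return the most frequent final answer among sampled reasoning chains."""
    counts = Counter(answers)
    answer, _ = counts.most_common(1)[0]
    return answer

# Example: final multiple-choice letters extracted from five sampled chains
sampled = ["B", "B", "C", "B", "A"]
print(majority_vote(sampled))  # -> B
```

Journey learning then improves the quality of each sampled chain, so the votes being aggregated are themselves more reliable.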

-----

Key Insights 🔍:

→ Harder medical tasks require longer reasoning chains and more inference time

→ Model size directly impacts inference-time scaling effectiveness

→ Journey learning outperforms simple majority voting approaches

→ Base model capability is crucial for successful inference-time scaling

-----

Results 📊:

→ 6-11% accuracy improvement with 500 training samples

→ Qwen2.5-72B achieves 77.18% accuracy with LongMonolog

→ Average response length grows from ~400 to ~1000 tokens with journey learning

→ Higher performance gains on complex JAMA cases vs simpler MedQA tasks
