Fine-tuning strategy to maximize the marginal log-likelihood of correct answers with Chain-of-Thought (CoT)
The paper proposes optimizing LLM reasoning through probabilistic latent-variable modeling.
📚 https://arxiv.org/abs/2312.02179
Original Problem 🔍:
LLMs struggle with chain-of-thought (CoT) reasoning, and fine-tuning for it typically requires expensive human-written rationales.
-----
Solution in this Paper 🧠:
• TRICE: A latent-variable inference method for CoT training
• Uses MCMC-EM algorithm to sample rationales conditioned on correct answers
• Incorporates a novel control-variate technique to reduce gradient variance
• Maximizes marginal log-likelihood of correct answers, averaging over possible rationales
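To make the MCMC-EM idea concrete, here is a minimal toy sketch (not the paper's code): rationales are latent variables, an independence Metropolis-Hastings step samples them conditioned on the correct answer, and a gradient step reinforces the sampled rationale. The three-way rationale space and the likelihood numbers are hypothetical stand-ins for illustration.

```python
import math
import random

random.seed(0)

# Toy TRICE-style MCMC-EM sketch (illustrative only).
# Latent rationale z in {0, 1, 2}; hypothetical likelihood of the
# correct answer y given each rationale (rationale 2 is usually valid):
P_CORRECT = [0.05, 0.10, 0.90]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def sample(probs):
    r, acc = random.random(), 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1

theta = [0.0, 0.0, 0.0]  # logits of the model p(z | x; theta)
z = 0                    # current Markov-chain state
lr = 0.1

for step in range(2000):
    probs = softmax(theta)
    # E-step: independence MH move targeting the posterior
    # p(z | x, y) ∝ p(z|x) p(y|x,z), with proposal q = p(z|x),
    # so the acceptance ratio reduces to p(y|z') / p(y|z).
    z_prop = sample(probs)
    if random.random() < min(1.0, P_CORRECT[z_prop] / P_CORRECT[z]):
        z = z_prop
    # M-step: gradient ascent on log p(z | x; theta) for the sampled z
    # (gradient of log-softmax: one-hot(z) - probs).
    for i in range(3):
        theta[i] += lr * ((1.0 if i == z else 0.0) - probs[i])

final = softmax(theta)
print(final)  # probability mass concentrates on the high-likelihood rationale
```

Because the proposal is the model itself, the chain quickly settles on rationales that explain the correct answer, and the M-step then raises their probability — the same self-reinforcing loop that maximizes the marginal log-likelihood in the paper.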
-----
Key Insights from this Paper 💡:
• CoT methods are probabilistic latent-variable models
• TRICE outperforms STaR and direct tuning on GSM8K and BIG-Bench Hard tasks
• Learns from both correct and incorrect rationales
• Avoids ignoring difficult examples during training
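The control-variate idea can be illustrated with a toy score-function (REINFORCE-style) gradient estimator: subtracting a baseline from the reward leaves the gradient unbiased but cuts its variance. The Bernoulli rationale, the reward values, and the expected-reward baseline below are hypothetical choices for the demo, not the paper's exact estimator.

```python
import math
import random

random.seed(1)

# Variance-reduction sketch for score-function gradients, the kind of
# estimator TRICE stabilizes with a control variate (toy numbers).
theta = 0.3
p1 = 1.0 / (1.0 + math.exp(-theta))  # p(z = 1) = sigmoid(theta)
f = {0: 0.1, 1: 0.8}                 # hypothetical reward: p(correct | z)

def grad_sample(baseline):
    """One-sample estimate of d/dtheta E[f(z)] with an optional baseline."""
    z = 1 if random.random() < p1 else 0
    score = (1 - p1) if z == 1 else -p1  # d/dtheta log p(z)
    return (f[z] - baseline) * score

# Expected reward as the control variate; E[baseline * score] = 0,
# so subtracting it does not bias the gradient.
baseline = p1 * f[1] + (1 - p1) * f[0]

N = 20000
plain = [grad_sample(0.0) for _ in range(N)]
cv = [grad_sample(baseline) for _ in range(N)]

def var(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

print(var(plain), var(cv))  # the control-variate estimator has lower variance
```

Note that the baseline uses the reward of *every* rationale, correct or not — which is why incorrect rationales still contribute useful signal during training.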
-----
Results 📊:
• GSM8K: 74.7% accuracy (greedy decoding), 82.3% (self-consistency)
• BIG-Bench Hard: 76.7% average accuracy
• Generates valid rationales for 98-99% of training examples
• Outperforms supervised fine-tuning on human-generated rationales