Clinical diagnosis prediction models face challenges due to limited patient data and a vast number of potential diseases. This paper introduces MERA, a clinical diagnosis prediction model leveraging LLMs.
MERA fine-tunes an LLM to "memorize" medical codes and their definitions, then uses hierarchical contrastive learning on disease rankings. This approach bridges natural language knowledge with medical codes.
-----
https://arxiv.org/abs/2501.17326
📌 MERA shifts from token-level prediction to a ranking-based approach, directly optimizing the output probability distribution over the entire International Classification of Diseases (ICD) code space with contrastive learning. This framing is the paper's key innovation.
📌 The hierarchical contrastive loss exploits the structure of the International Classification of Diseases ontology. This forces the LLM to learn fine-grained distinctions between similar diseases, which mirrors clinical differential diagnosis.
📌 Memorizing medical code definitions as a preliminary fine-tuning step is key: it builds a strong knowledge base that semantically grounds the LLM before it tackles the complex temporal reasoning of diagnosis prediction.
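The memorization stage can be pictured as bidirectional supervised pairs. A minimal sketch below, assuming simple prompt templates (the templates and helper name are hypothetical, not the paper's exact format; the ICD-9 codes shown are real):

```python
# Hypothetical sketch: building bidirectional code <-> definition training
# pairs for the memorization fine-tuning stage. Prompt wording is assumed.

def build_memorization_pairs(code_defs):
    """Create both code->definition and definition->code examples."""
    pairs = []
    for code, definition in code_defs.items():
        # Forward direction: recall the definition given the code.
        pairs.append({
            "prompt": f"Define ICD-9 code {code}:",
            "target": definition,
        })
        # Backward direction: recall the code given the definition.
        pairs.append({
            "prompt": f"Which ICD-9 code means '{definition}'?",
            "target": code,
        })
    return pairs

code_defs = {
    "428.0": "Congestive heart failure, unspecified",
    "250.00": "Diabetes mellitus without mention of complication",
}
pairs = build_memorization_pairs(code_defs)
print(len(pairs))  # 4 examples: two directions per code
```

Training on both directions is what makes the mapping bidirectional, so the model can go from code to meaning and back.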
----------
Methods Explored in this Paper 🔧:
→ MERA fine-tunes an LLM to associate medical codes with their natural language definitions, creating a bidirectional mapping between the two.
→ Hierarchical contrastive learning distinguishes true diagnoses from pools of increasingly relevant candidates drawn from the ICD coding hierarchy, operating on the output probabilities for all candidate diseases.
→ A dynamic confidence threshold governs how many diagnoses the model predicts per visit by learning where to place the end-of-visit (EOV) token. Intra-visit diagnosis patterns are learned through teacher forcing, optimizing medical code ranking given partial sets of a visit's diagnoses.
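The hierarchical contrastive idea can be sketched as an InfoNCE-style loss applied per hierarchy level, where negatives get progressively closer to the true code (distant chapters first, then siblings in the same category). Everything below is a hedged illustration, not the paper's actual loss: the function name, level weighting, and scores are assumptions.

```python
import numpy as np

# Hypothetical sketch of a hierarchical contrastive (ranking) loss over
# model scores for candidate ICD codes. `levels` groups negatives by how
# close they sit to the true code in the ICD hierarchy (coarse -> fine).

def hierarchical_contrastive_loss(scores, positive, levels, temperature=1.0):
    """Average InfoNCE loss across hierarchy levels.

    scores:   dict mapping candidate code -> model score (logit)
    positive: the true diagnosis code
    levels:   list of negative-code sets, ordered coarse to fine
    """
    total = 0.0
    for negatives in levels:
        logits = np.array([scores[positive]] + [scores[c] for c in negatives])
        logits = logits / temperature
        # Softmax probability that the true code outranks this negative pool.
        p_pos = np.exp(logits[0]) / np.exp(logits).sum()
        total += -np.log(p_pos)
    return total / len(levels)

scores = {"428.0": 3.1, "250.00": 0.4, "428.1": 2.0, "V45.0": -1.2}
levels = [
    {"250.00", "V45.0"},  # distant codes from other ICD chapters
    {"428.1"},            # sibling code within the same category
]
loss = hierarchical_contrastive_loss(scores, "428.0", levels)
```

The fine-grained levels dominate the difficulty: a near-identical sibling like 428.1 vs. 428.0 contributes most of the loss, which is what pushes the model toward differential-diagnosis-style distinctions.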
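One way to picture the EOV mechanism: generation stops once the model scores "visit complete" above the best remaining candidate. This is an assumed simplification of the paper's procedure; the function and stopping rule are illustrative only.

```python
# Hypothetical sketch of confidence-thresholded stopping: emit ranked
# diagnoses until the end-of-visit (EOV) score beats the next candidate.

def predict_visit(ranked, eov_score):
    """ranked: list of (code, score) pairs sorted by descending score."""
    predictions = []
    for code, score in ranked:
        if eov_score >= score:  # model is more confident the visit is complete
            break
        predictions.append(code)
    return predictions

print(predict_visit([("428.0", 2.9), ("250.00", 1.1), ("V45.0", -0.5)], 0.2))
# -> ['428.0', '250.00']
```

Because the threshold is learned per context rather than fixed, the model can emit many codes for complex visits and few for simple ones.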
-----
Key Insights 💡:
→ Pre-trained LLMs alone struggle with medical code recall and diagnosis prediction. Fine-tuning is critical to close the significant performance deficit.
→ MERA effectively bridges the gap between natural language and medical codes. This improves diagnosis prediction.
→ Hierarchical contrastive learning, dynamic confidence thresholds, and intra-visit pattern modeling are crucial for performance. These guide the model to understand disease relationships, output confidence, and clinical practices.
-----
Results 📊:
→ MERA (BioMistral-7B) achieves 33.24 weighted F1 and 49.01 R@20 on MIMIC-III diagnosis prediction, outperforming the previous best model (KGxDP: 27.35 weighted F1, 41.29 R@20).
→ MERA achieves 99.61% code accuracy and 99.58% definition accuracy on medical code memorization (BioMistral-7B).
→ On the heart failure prediction task, MERA (BioMistral-7B) improves to 90.78 AUC and 79.13 F1, versus the best baseline (KGxDP: 86.57 AUC, 74.74 F1).