
"Memorize and Rank: Elevating LLMs for Clinical Diagnosis Prediction"

The accompanying podcast on this paper was generated with Google's Illuminate.

Clinical diagnosis prediction models face challenges due to limited patient data and a vast number of potential diseases. This paper introduces MERA, a clinical diagnosis prediction model leveraging LLMs.

MERA fine-tunes an LLM to "memorize" medical codes and their definitions, then uses hierarchical contrastive learning on disease rankings. This approach bridges natural language knowledge with medical codes.

-----

https://arxiv.org/abs/2501.17326

📌 MERA shifts from token-level prediction to a ranking-based approach, directly optimizing the output probability distribution over the entire International Classification of Diseases (ICD) code space with contrastive learning.

📌 The hierarchical contrastive loss exploits the structure of the International Classification of Diseases ontology. This forces the LLM to learn fine-grained distinctions between similar diseases, which mirrors clinical differential diagnosis.

📌 Memorization of medical code definitions as a pre-training step is key. It creates a strong knowledge base. This semantically grounds the LLM before tackling the complex temporal reasoning in diagnosis prediction.

----------

Methods Explored in this Paper 🔧:

→ MERA fine-tunes an LLM to associate medical codes with natural language definitions. This creates a bidirectional mapping.
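A minimal sketch of what the "memorize" step's training data could look like: bidirectional code↔definition pairs for supervised fine-tuning. The prompt templates and the tiny ICD-9 sample below are assumptions for illustration, not the paper's exact format.

```python
# Hypothetical ICD-9 codes and definitions (a tiny illustrative sample).
ICD9_DEFINITIONS = {
    "401.9": "Unspecified essential hypertension",
    "250.00": "Diabetes mellitus without mention of complication, type II",
    "428.0": "Congestive heart failure, unspecified",
}

def build_memorization_pairs(definitions):
    """Return (prompt, target) fine-tuning pairs in both directions."""
    pairs = []
    for code, definition in definitions.items():
        # Direction 1: code -> definition
        pairs.append((f"What is the definition of ICD-9 code {code}?", definition))
        # Direction 2: definition -> code
        pairs.append((f"Which ICD-9 code means: {definition}?", code))
    return pairs

pairs = build_memorization_pairs(ICD9_DEFINITIONS)
print(len(pairs))  # two directions per code -> 6 pairs
```

Training on both directions is what makes the mapping bidirectional: the model can recall a definition from a code and retrieve a code from its description.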

→ Hierarchical contrastive learning is used. It distinguishes true diagnoses from a pool of increasingly relevant candidates within the International Classification of Diseases coding system, operating on the output probabilities for all candidate diseases.
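One way to picture "increasingly relevant candidates" is to exploit ICD's prefix structure: codes sharing a longer prefix with the true diagnosis are more similar diseases, hence harder negatives. The tier boundaries and the InfoNCE-style loss below are illustrative assumptions, not the paper's exact formulation.

```python
import math

def shared_prefix_len(a, b):
    """Length of the common prefix of two ICD-9 code strings."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def tiered_negatives(true_code, candidates):
    """Group candidate codes into similarity tiers relative to the true code."""
    tiers = {"easy": [], "medium": [], "hard": []}
    for c in candidates:
        if c == true_code:
            continue
        p = shared_prefix_len(c, true_code)
        if p >= 3:      # same 3-digit category, e.g. 428.x vs 428.y
            tiers["hard"].append(c)
        elif p >= 1:
            tiers["medium"].append(c)
        else:
            tiers["easy"].append(c)
    return tiers

def contrastive_loss(pos_score, neg_scores):
    """InfoNCE-style loss pushing the true code's score above the negatives'."""
    denom = math.exp(pos_score) + sum(math.exp(s) for s in neg_scores)
    return -math.log(math.exp(pos_score) / denom)

tiers = tiered_negatives("428.0", ["428.1", "427.31", "250.00", "401.9"])
print(tiers["hard"])  # ['428.1'] shares the 428 (heart failure) category
```

Progressively contrasting against the "hard" tier is what forces fine-grained distinctions between clinically similar diseases.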

→ A dynamic confidence threshold determines when the model is confident enough to predict additional diagnoses and models the placement of the end-of-visit (EOV) token. Intra-visit diagnosis patterns are learned through teacher-forcing, optimizing medical code ranking given partial diagnoses.
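The two ideas above can be sketched as a decoding loop: keep emitting the highest-scoring remaining diagnosis, conditioned on the diagnoses emitted so far, until the top score falls below a confidence threshold, then close the visit. The scoring function, threshold value, and "[EOV]" marker are assumptions for this sketch.

```python
def predict_visit(score_fn, candidate_codes, threshold=0.5, max_codes=20):
    """Emit ranked diagnoses for one visit until confidence drops below threshold."""
    predicted = []
    remaining = list(candidate_codes)
    while remaining and len(predicted) < max_codes:
        # Re-score candidates given the partial diagnosis list; at training
        # time, teacher-forcing conditions on ground-truth partial diagnoses.
        best = max(remaining, key=lambda c: score_fn(c, predicted))
        if score_fn(best, predicted) < threshold:
            break  # confidence too low: stop and place the end-of-visit token
        predicted.append(best)
        remaining.remove(best)
    return predicted + ["[EOV]"]

# Toy scorer: fixed base probabilities that decay as more codes are emitted.
probs = {"401.9": 0.9, "250.00": 0.7, "428.0": 0.4}
result = predict_visit(lambda c, ctx: probs[c] - 0.1 * len(ctx), probs.keys())
print(result)  # ['401.9', '250.00', '[EOV]']
```

With the toy scorer, "428.0" never clears the 0.5 threshold, so the model stops after two diagnoses and places the EOV token; raising the threshold would shorten the predicted list.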

-----

Key Insights 💡:

→ Pre-trained LLMs alone struggle with medical code recall and diagnosis prediction. Fine-tuning is critical to close the significant performance deficit.

→ MERA effectively bridges the gap between natural language and medical codes. This improves diagnosis prediction.

→ Hierarchical contrastive learning, dynamic confidence thresholds, and intra-visit pattern modeling are crucial for performance. These guide the model to understand disease relationships, output confidence, and clinical practices.

-----

Results 📊:

→ MERA (BioMistral-7B) achieves 33.24 weighted F1 and 49.01 recall@20 on MIMIC-III diagnosis prediction, outperforming the previous best model (KGxDP: 27.35 weighted F1, 41.29 recall@20).

→ MERA achieves 99.61% code accuracy and 99.58% definition accuracy on medical code memorization (BioMistral-7B).

→ MERA (BioMistral-7B) improves heart failure prediction to 90.78 AUC and 79.13 F1, compared with the best baseline (KGxDP: 86.57 AUC, 74.74 F1).
