RAG (Retrieval-Augmented Generation) is a technique that combines the strengths of retrieval systems and generative language models. In essence, it works like this:
- Retrieval: When a user asks a question, the system first retrieves relevant documents from a knowledge base.
- Augmentation: The retrieved documents are then combined with the original question to provide context.
- Generation: A language model uses this augmented information to generate a coherent and informative answer.
This approach allows the model to provide more accurate and contextually relevant responses by grounding its knowledge in real-world data.
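To make that flow concrete before diving in, here is a pseudocode-level sketch of the loop; the names `retriever`, `build_prompt`, and `llm` are placeholders for illustration, not real APIs:

```python
# Pseudocode only: placeholder objects, not a real library API
def rag_answer(question: str) -> str:
    docs = retriever.retrieve(question)    # 1. Retrieval: fetch relevant documents
    prompt = build_prompt(question, docs)  # 2. Augmentation: combine question + context
    return llm.generate(prompt)            # 3. Generation: produce the final answer
```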
Project Goals
- Ingest and preprocess clinical notes from the MIMIC-IV-Ext Direct dataset.
- Build a retrieval system to find relevant clinical notes based on user queries.
- Use a generative model to produce accurate, coherent answers grounded in the retrieved context.
- Create a user-friendly interface for querying the system.
Building the RAG System
Here’s a detailed look at the steps I took to build the RAG system:
Dataset Familiarization and Preprocessing
The first step was to dive into the MIMIC-IV-Ext Direct dataset. This involved understanding its structure, the different types of records it contained, and how the data was organized. The data consists of JSON files, each representing a clinical note.
Data Loading and Exploration
I started by loading the JSON files from the ‘Finished’ directory of the MIMIC-IV-Ext Direct dataset. Here’s the code:
```python
import os
import json
import glob
from typing import List, Dict, Any
from langchain.docstore.document import Document
import streamlit as st

DATA_DIR = "./mimic-iv-ext-direct-1.0.0"  # Adjust if your path differs

def extract_dict_text(obj: Any) -> str:
    """Recursively extracts text from nested dict/list structures."""
    out = []
    if isinstance(obj, dict):
        for val in obj.values():
            if isinstance(val, (dict, list)):
                out.append(extract_dict_text(val))
            else:
                out.append(str(val))
    elif isinstance(obj, list):
        for item in obj:
            if isinstance(item, (dict, list)):
                out.append(extract_dict_text(item))
            else:
                out.append(str(item))
    else:
        out.append(str(obj))
    return " ".join(filter(None, out))  # Drop empty strings

def load_mimic_finished(data_dir: str) -> List[Document]:
    """Loads and preprocesses documents from the MIMIC 'Finished' directory."""
    finished_dir = os.path.join(data_dir, "Finished")
    json_files = glob.glob(os.path.join(finished_dir, "**", "*.json"), recursive=True)
    docs = []
    for path in json_files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            continue  # Skip files that can't be opened or parsed
        text_parts = []
        metadata = {"source": os.path.basename(path)}
        for k, v in data.items():
            if isinstance(v, dict) and ("$Intermedia_" in k or "$Cause_" in k):
                text_parts.append(extract_dict_text(v))
            elif k.startswith("input") and isinstance(v, str):
                text_parts.append(v.strip())
        combined_text = "\n".join(filter(None, text_parts)).strip()
        if not combined_text:
            continue
        docs.append(Document(page_content=combined_text, metadata=metadata))
    return docs
```
This function loads the JSON data and extracts the relevant text, creating a LangChain `Document` object that contains the content plus metadata such as the source filename. It focuses on the specific keys within the JSON structure that contain clinical information.
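A quick, illustrative way to sanity-check the loader (document counts and contents depend on your local copy of the dataset):

```python
# Illustrative check; output depends on your local copy of the dataset
docs = load_mimic_finished(DATA_DIR)
print(f"Loaded {len(docs)} documents")
print(docs[0].metadata, docs[0].page_content[:200])
```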
Preprocessing
- Text Extraction: I extracted relevant text from the JSON files, focusing on sections containing clinical notes, diagnoses, and other important information.
- Data Cleaning: This involved removing unnecessary characters, standardizing text formats, and handling missing values.
- Synonym Expansion: To improve the accuracy of the retrieval process, I expanded common clinical abbreviations and acronyms to their full forms. For example, “HF” was expanded to “heart failure.” This helps the system better understand the meaning of the text.
```python
import re

def expand_synonyms(text: str) -> str:
    """Basic expansions for domain abbreviations."""
    expansions = {
        r"\bHF\b": "heart failure", r"\bHFrEF\b": "heart failure with reduced ejection fraction",
        r"\bCHF\b": "congestive heart failure", r"\bHTN\b": "hypertension",
        r"\bDM2?\b": "type 2 diabetes mellitus", r"\bT2DM\b": "type 2 diabetes mellitus",
        r"\bCAD\b": "coronary artery disease", r"\bMI\b": "myocardial infarction",
        r"\bCABG\b": "coronary artery bypass graft", r"\bPCI\b": "percutaneous coronary intervention",
        r"\bAFib\b": "atrial fibrillation", r"\bCKD\b": "chronic kidney disease",
        r"\bESRD\b": "end-stage renal disease", r"\bPE\b": "pulmonary embolism",
        r"\bDVT\b": "deep vein thrombosis", r"\bCOPD\b": "chronic obstructive pulmonary disease",
        r"\bUA\b": "unstable angina", r"\bNSTEMI\b": "non-ST elevation myocardial infarction",
        r"\bSTEMI\b": "ST elevation myocardial infarction",
        r"LVEF\s*<\s*40%": "heart failure with reduced ejection fraction",
    }
    processed_text = text
    for pattern, repl in expansions.items():
        processed_text = re.sub(pattern, repl, processed_text, flags=re.IGNORECASE)
    return processed_text
```
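A quick illustrative check of the expansion on a made-up sentence (not taken from the dataset):

```python
# Made-up example sentence, not from the dataset
print(expand_synonyms("Pt with HFrEF, HTN, and CKD presents with DVT."))
# -> Pt with heart failure with reduced ejection fraction, hypertension,
#    and chronic kidney disease presents with deep vein thrombosis.
```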
Chunking
Because LLMs have context limits, I split the documents into smaller chunks.
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128

def chunk_docs(_docs: List[Document], chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) -> List[Document]:
    """Splits documents into smaller chunks."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
        length_function=len,
    )
    new_docs = []
    for doc in _docs:
        try:
            chunks = splitter.split_text(doc.page_content)
            for chunk in chunks:
                if chunk.strip():
                    new_docs.append(Document(page_content=chunk, metadata=doc.metadata.copy()))
        except Exception as e:
            print(f"Error chunking document from {doc.metadata.get('source', 'unknown')}: {e}")
            continue
    return new_docs
```
This code uses `RecursiveCharacterTextSplitter` to divide the documents into smaller, manageable chunks. The `chunk_size` and `chunk_overlap` parameters control the size of the chunks and the amount of overlap between them, respectively. The ordered separators matter: the splitter tries to break on paragraph and line boundaries first, falling back to spaces and finally individual characters, so chunks stay as semantically coherent as possible.
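Tying the preprocessing steps together might look like the sketch below; applying `expand_synonyms` to the documents before chunking (rather than only at query time) is one reasonable ordering, not the only one:

```python
# Assumes `docs` from load_mimic_finished(); expansion-before-chunking is one possible ordering
for doc in docs:
    doc.page_content = expand_synonyms(doc.page_content)
chunked_docs = chunk_docs(docs)
print(f"{len(docs)} documents -> {len(chunked_docs)} chunks")
```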
Designing the Retrieval Component
The retrieval component is responsible for finding the most relevant documents in the dataset for a given user query. I chose to use dense retrieval with embeddings for this task, as it generally provides better semantic search capabilities compared to keyword-based methods.
Embedding Generation
I used the GoogleGenerativeAIEmbeddings model to generate embeddings for each document chunk. These embeddings are numerical representations of the text, capturing its semantic meaning.
```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import os
from dotenv import load_dotenv

load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GOOGLE_EMBEDDING_MODEL = "models/text-embedding-004"

# Initialize Google embeddings
embeddings = GoogleGenerativeAIEmbeddings(
    google_api_key=GOOGLE_API_KEY,
    model=GOOGLE_EMBEDDING_MODEL
)
```
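A minimal sanity check that the embedding client is configured correctly (the query text is arbitrary):

```python
# Arbitrary test string; embed_query returns a list of floats
vector = embeddings.embed_query("shortness of breath with elevated BNP")
print(f"Embedding dimensionality: {len(vector)}")
```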
Indexing with FAISS
I used FAISS (Facebook AI Similarity Search) to index the document embeddings. FAISS provides efficient methods for similarity search, allowing us to quickly find the documents that are most similar to the query embedding.
```python
from langchain_community.vectorstores import FAISS

VECTOR_STORE_PATH = "faiss_vectorstore"

# Build the vector store
vectorstore = FAISS.from_documents(chunked_docs, embeddings)

# Save the index
vectorstore.save_local(VECTOR_STORE_PATH)
```
This code creates a FAISS index from the document embeddings and saves it to disk, which lets us load the index later without having to regenerate the embeddings. The `VECTOR_STORE_PATH` variable specifies the directory where the index will be saved.
To speed up the process for future runs, I added the ability to load the index if it already exists:
```python
import os
from typing import List

from langchain.docstore.document import Document
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings

VECTOR_STORE_PATH = "faiss_vectorstore"

def get_vectorstore(
    _docs_to_index: List[Document],
    embedding_model_name: str,
    index_path: str,
    google_api_key: str  # Required for Google embeddings
) -> FAISS:
    """Loads the FAISS index from disk if available, otherwise builds and saves it."""
    # Initialize Google embeddings (as shown in the previous snippets)
    embeddings = GoogleGenerativeAIEmbeddings(
        google_api_key=google_api_key,
        model=embedding_model_name,
    )
    if os.path.exists(index_path):
        print(f"Loading existing vector store from {index_path}...")
        vectorstore = FAISS.load_local(
            index_path,
            embeddings,  # Pass the Google embeddings object
            allow_dangerous_deserialization=True  # Required: the saved index metadata is pickled
        )
        print(f"Vector store loaded successfully from {index_path}.")
        return vectorstore
    else:
        print(f"No existing vector store found at {index_path}.")
        # --- Build and Save Logic ---
        print("Building new vector store...")
        vectorstore = FAISS.from_documents(_docs_to_index, embeddings)  # Use Google embeddings
        # Save the index
        print(f"Saving vector store to {index_path}...")
        vectorstore.save_local(index_path)
        print(f"Vector store saved successfully to {index_path}.")
        return vectorstore
```
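Calling it looks like this, reusing the chunked documents and constants from the earlier snippets:

```python
# Reuses chunked_docs, GOOGLE_EMBEDDING_MODEL, VECTOR_STORE_PATH, GOOGLE_API_KEY from above
vectorstore = get_vectorstore(
    _docs_to_index=chunked_docs,
    embedding_model_name=GOOGLE_EMBEDDING_MODEL,
    index_path=VECTOR_STORE_PATH,
    google_api_key=GOOGLE_API_KEY,
)
```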
Query Formulation
When a user enters a query, the system generates an embedding for the query text using the same GoogleGenerativeAIEmbeddings model. This query embedding is then used to search the FAISS index for the most similar document embeddings.
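Under the hood this is a plain similarity search against the FAISS index. A minimal sketch of querying it directly is shown below; the query is made up, and applying `expand_synonyms` to the query here is my choice for illustration, not a requirement:

```python
# Illustrative query; the retriever used later wraps this same similarity search
query = "elderly patient with acute shortness of breath and reduced ejection fraction"
results = vectorstore.similarity_search_with_score(expand_synonyms(query), k=3)
for doc, score in results:
    print(f"{doc.metadata.get('source', 'unknown')}  (distance={score:.3f})")
    print(doc.page_content[:150], "...")
```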
Integrating the Generative Model
The generative model is responsible for producing the final answer to the user’s query, using the retrieved documents as context. I chose Google’s Gemini model for this task, as it is well suited to generating coherent and informative text.
Prompt Engineering
I designed a prompt that combines the user’s query with the content of the retrieved documents. This prompt is then fed into the Gemini model to generate the answer. The prompt includes instructions to the model to be concise, accurate, and to cite the source documents when possible.
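The exact wording I used isn’t reproduced here, but the template below sketches the structure (instructions, then context, then question); a template like this can be passed to the chain via `chain_type_kwargs={"prompt": QA_PROMPT}`:

```python
from langchain.prompts import PromptTemplate

# Illustrative template; the project's actual wording differed, but the structure is the same
QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "You are a clinical question-answering assistant.\n"
        "Answer using ONLY the context below. Be concise and accurate, cite the source\n"
        "documents you relied on, and say you don't know if the context is insufficient.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "Answer:"
    ),
)
```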
RAG Pipeline
I integrated the retriever and generator into a seamless pipeline using LangChain. This pipeline takes the user’s query, retrieves relevant documents, and feeds them into the Gemini model to generate the final answer.
```python
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import RetrievalQA

GOOGLE_LLM_MODEL = "gemini-2.0-flash"

# Initialize Google LLM
llm = ChatGoogleGenerativeAI(
    model=GOOGLE_LLM_MODEL,
    google_api_key=GOOGLE_API_KEY,
    temperature=0.5,  # Adjust creativity
    convert_system_message_to_human=True  # Often needed for chat models in RAG
)

# Create Retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# Create RetrievalQA Chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # Suitable for models handling moderate context lengths
    retriever=retriever,
    return_source_documents=True,
)
```
This code initializes the Google Gemini model and creates a `RetrievalQA` chain. The `chain_type="stuff"` parameter tells the chain to simply stuff all of the retrieved documents into the prompt, and `return_source_documents=True` tells it to return the source documents along with the answer. I set `k=3` so that the top three chunks are retrieved.
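Before wiring the chain into a UI, it can be smoke-tested from a plain script (the question below is illustrative):

```python
# Illustrative question; the result is a dict with "result" and "source_documents"
result = qa_chain.invoke({"query": "What findings suggest heart failure with reduced ejection fraction?"})
print(result["result"])
for doc in result["source_documents"]:
    print("source:", doc.metadata.get("source", "unknown"))
```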
A Streamlit frontend
To make the system user-friendly, I built an interactive frontend using Streamlit. This frontend allows users to enter clinical queries and view the generated responses.
Streamlit App
The Streamlit app provides a simple interface for users to interact with the RAG system. It includes:
- A text input field for entering clinical queries.
- A button to trigger the query processing.
- A display area for the generated answer and the retrieved documents.
````python
import streamlit as st
import time

st.title("🩺 MIMIC-IV Clinical Query Assistant (Google AI)")
st.markdown("""
Enter a clinical question. The system retrieves relevant MIMIC-IV info using Google AI embeddings
and generates an answer using Google's Gemini model.
**Note:** Demo only. Not for clinical use. Handle data ethically.
""")

# --- User Interaction ---
query = st.text_input("Enter your clinical question:", key="query_input")

if st.button("Get Answer", key="submit_button"):
    if query:
        with st.spinner("Retrieving documents and generating answer with Google Gemini..."):
            try:
                start_query_time = time.time()
                result = qa_chain.invoke({"query": query})  # Use invoke
                end_query_time = time.time()

                answer = result.get("result", "No answer generated.")
                source_docs = result.get("source_documents", [])

                st.subheader("Generated Answer")
                st.markdown(answer)
                st.info(f"Query processed in {end_query_time - start_query_time:.2f} seconds.")

                st.subheader(f"Retrieved Documents (Top {len(source_docs)})")
                if source_docs:
                    for i, doc in enumerate(source_docs):
                        with st.expander(f"Source {i+1}: {doc.metadata.get('source', 'Unknown')}"):
                            st.write(f"**Content:**\n```\n{doc.page_content}\n```")
                else:
                    st.write("No relevant documents were retrieved for this query.")
            except Exception as e:
                st.error(f"An error occurred during query processing: {e}")
                st.error("Please check your API keys, network connection, and API quotas (Google).")
    else:
        st.warning("Please enter a question.")
````
This code creates a simple Streamlit app with a text input field for the query and a button to submit it. The app displays the generated answer and the retrieved source documents, and error handling catches any exceptions raised during query processing.
Because Streamlit reruns the script on every interaction, I also added caching so that the vector store and RAG chain are only created once:
```python
@st.cache_resource(show_spinner=False)  # Cache the final vectorstore object across reruns
def get_vectorstore(
    _docs_to_index: List[Document] | None,  # Leading underscore tells Streamlit not to hash this argument
    embedding_model_name: str,
    index_path: str,
    google_api_key: str  # Required for Google embeddings
) -> FAISS | None:
    ...  # Vector store load-or-build logic as shown earlier

@st.cache_resource(show_spinner="Initializing RAG Chain with Google LLM...")
def create_google_rag_chain(
    _vectorstore: FAISS | None,  # Leading underscore tells Streamlit not to hash this argument
    google_api_key: str,  # Required for Google LLM
    llm_model_name: str = GOOGLE_LLM_MODEL
) -> RetrievalQA | None:
    ...  # LLM and RetrievalQA chain creation as shown earlier
```
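At app start-up these pieces are wired together roughly as follows; this is a sketch based on the snippets above, with the preprocessing calls repeated for clarity:

```python
# Start-up flow (sketch): load, preprocess, index (cached), build chain (cached)
docs = load_mimic_finished(DATA_DIR)
for doc in docs:
    doc.page_content = expand_synonyms(doc.page_content)
chunked_docs = chunk_docs(docs)
vectorstore = get_vectorstore(chunked_docs, GOOGLE_EMBEDDING_MODEL, VECTOR_STORE_PATH, GOOGLE_API_KEY)
qa_chain = create_google_rag_chain(vectorstore, GOOGLE_API_KEY)
```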
Ethical and Privacy Considerations
Working with clinical data requires careful consideration of ethical and privacy issues. I took the following steps to address these concerns:
- Data Anonymization: I ensured that the data used in the system was properly anonymized to protect patient privacy.
- Data Use Agreement: I adhered to the terms of the MIMIC-IV data use agreement, which outlines the permitted uses of the data.
- Disclaimer: I included a disclaimer in the Streamlit app stating that the system is for educational and research purposes only and should not be used for clinical decision-making.
Challenges
- Data Preprocessing: Cleaning and preprocessing the clinical notes was a time-consuming but crucial step. The quality of the data directly impacts the accuracy of the retrieval and generation processes.
- Prompt Engineering: Designing effective prompts that elicit the desired responses from the generative model required experimentation and careful consideration of the context.
- Computational Resources: Generating embeddings for a large dataset can be computationally expensive. Leveraging cloud-based services and optimizing the code helped to mitigate this issue.
- API Keys: Remember to always store your API keys securely!
Conclusion
This project was a rewarding experience that provided valuable insights into the RAG approach and the power of large language models. By combining retrieval and generation, I built a system that answers clinical questions coherently and grounds its responses in the retrieved clinical notes.