
When you're building production AI applications, you quickly realize that a single LLM call rarely solves complex problems. Real-world scenarios require orchestrating multiple AI components: retrieving relevant documents, processing multi-step reasoning chains, integrating external APIs, and managing state across long conversations. This is where AI workflow frameworks become essential.
Consider a common enterprise scenario: building an intelligent customer support system that can analyze support tickets, retrieve relevant documentation from multiple sources, consult internal knowledge bases, and generate contextually appropriate responses. This involves document retrieval, semantic search, multi-step reasoning, and careful prompt engineering—all while maintaining performance at scale.
By the end of this lesson, you'll understand how to architect and implement sophisticated AI workflows that go far beyond simple question-answering systems. We'll dive deep into both LangChain and LlamaIndex, examining their architectural philosophies, performance characteristics, and ideal use cases.
What you'll learn: the core building blocks of AI workflows (retrieval, reasoning chains, agent patterns, and memory management), how LangChain composes chains and agents, how LlamaIndex structures indexing and retrieval, and how to cache, monitor, and benchmark these systems for production use.
You should have solid experience with Python, familiarity with transformer models and embeddings, and a basic understanding of vector databases. Prior exposure to OpenAI's API or similar LLM services is helpful but not required.
Before diving into specific frameworks, let's establish what makes AI workflows different from traditional software pipelines. Traditional workflows process deterministic data through predictable transformations. AI workflows deal with probabilistic outputs, context-dependent decisions, and the inherent unpredictability of language models.
The core components of most AI workflows include:
Retrieval Systems: Finding relevant information from large document collections using semantic similarity rather than keyword matching. This involves chunking strategies, embedding generation, and vector database operations.
Reasoning Chains: Sequential processing where each step depends on previous outputs, often requiring the LLM to make intermediate decisions or transformations.
Agent Patterns: Autonomous decision-making where the AI determines which tools to use and when, based on the current context and available options.
Memory Management: Maintaining context across interactions, handling conversation history, and managing the balance between context relevance and token limits.
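Memory management is worth a concrete illustration before we build the larger example. Here is a minimal sketch of token-budget trimming, assuming a rough characters-per-token heuristic (the `trim_history` helper is hypothetical; a real system should count tokens with a proper tokenizer such as tiktoken):
from typing import List, Dict

def trim_history(
    messages: List[Dict[str, str]],
    max_tokens: int = 3000,
    chars_per_token: int = 4
) -> List[Dict[str, str]]:
    """Keep the most recent messages that fit within a rough token budget."""
    budget = max_tokens * chars_per_token  # approximate characters allowed
    kept: List[Dict[str, str]] = []
    used = 0
    # Walk backwards so the newest turns survive when the budget is tight
    for message in reversed(messages):
        cost = len(message["content"])
        if used + cost > budget:
            break
        kept.append(message)
        used += cost
    return list(reversed(kept))

history = [
    {"role": "user", "content": "How do I reset my password?"},
    {"role": "assistant", "content": "Go to Settings > Security and choose Reset."},
]
print(len(trim_history(history, max_tokens=50)))  # both short messages fit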
Let's start by implementing a foundational example that demonstrates these concepts:
import openai  # this lesson targets the legacy (pre-1.0) openai SDK interface
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import asyncio
from abc import ABC, abstractmethod
@dataclass
class WorkflowContext:
"""Carries state and intermediate results through workflow steps"""
user_query: str
    retrieved_docs: Optional[List[Dict[str, Any]]] = None
    reasoning_steps: Optional[List[str]] = None
    final_response: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
class WorkflowStep(ABC):
"""Base class for workflow components"""
@abstractmethod
async def process(self, context: WorkflowContext) -> WorkflowContext:
pass
@abstractmethod
def validate_input(self, context: WorkflowContext) -> bool:
pass
class DocumentRetriever(WorkflowStep):
def __init__(self, embeddings_model: str, vector_db_client, top_k: int = 5):
self.embeddings_model = embeddings_model
self.vector_db = vector_db_client
self.top_k = top_k
async def process(self, context: WorkflowContext) -> WorkflowContext:
if not self.validate_input(context):
raise ValueError("Invalid context for document retrieval")
# Generate query embedding
query_embedding = await self._embed_query(context.user_query)
# Retrieve similar documents
similar_docs = await self.vector_db.similarity_search(
query_embedding,
top_k=self.top_k
)
context.retrieved_docs = similar_docs
context.metadata = context.metadata or {}
context.metadata['retrieval_scores'] = [doc['score'] for doc in similar_docs]
return context
def validate_input(self, context: WorkflowContext) -> bool:
return context.user_query is not None and len(context.user_query.strip()) > 0
async def _embed_query(self, query: str) -> List[float]:
# This would integrate with your embedding service
response = await openai.Embedding.acreate(
model=self.embeddings_model,
input=query
)
return response['data'][0]['embedding']
class ReasoningChain(WorkflowStep):
def __init__(self, llm_model: str, temperature: float = 0.1):
self.llm_model = llm_model
self.temperature = temperature
async def process(self, context: WorkflowContext) -> WorkflowContext:
if not self.validate_input(context):
raise ValueError("Missing retrieved documents for reasoning")
# Construct reasoning prompt with retrieved context
reasoning_prompt = self._build_reasoning_prompt(
context.user_query,
context.retrieved_docs
)
response = await openai.ChatCompletion.acreate(
model=self.llm_model,
messages=[{"role": "user", "content": reasoning_prompt}],
temperature=self.temperature
)
context.final_response = response.choices[0].message.content
# Extract reasoning steps if the model provided them
context.reasoning_steps = self._extract_reasoning_steps(
response.choices[0].message.content
)
return context
def validate_input(self, context: WorkflowContext) -> bool:
return (context.retrieved_docs is not None and
len(context.retrieved_docs) > 0)
def _build_reasoning_prompt(self, query: str, docs: List[Dict]) -> str:
doc_context = "\n\n".join([
f"Document {i+1}: {doc['content']}"
for i, doc in enumerate(docs[:3]) # Limit context size
])
return f"""
Based on the following documents, provide a comprehensive answer to the user's question.
Show your reasoning step by step.
Documents:
{doc_context}
User Question: {query}
Please structure your response as:
1. Analysis of the relevant information
2. Step-by-step reasoning
3. Final answer
"""
def _extract_reasoning_steps(self, response: str) -> List[str]:
# Simple extraction - in production, you'd want more robust parsing
lines = response.split('\n')
steps = []
for line in lines:
if line.strip().startswith(('1.', '2.', '3.', 'Step')):
steps.append(line.strip())
return steps
class AIWorkflow:
def __init__(self, steps: List[WorkflowStep]):
self.steps = steps
async def execute(self, initial_query: str) -> WorkflowContext:
context = WorkflowContext(user_query=initial_query)
for step in self.steps:
try:
context = await step.process(context)
except Exception as e:
# In production, you'd want more sophisticated error handling
context.metadata = context.metadata or {}
context.metadata['errors'] = context.metadata.get('errors', [])
context.metadata['errors'].append(f"{step.__class__.__name__}: {str(e)}")
raise
return context
This foundational architecture demonstrates several key principles that both LangChain and LlamaIndex implement in their own ways:
Composability: Each step is independent and can be combined with others. This makes workflows testable and maintainable.
Context Management: The WorkflowContext object carries state through the pipeline, allowing steps to build on previous work.
Error Handling: Each step validates its inputs and the workflow captures errors with context about where they occurred.
Async Support: Real workflows often need to make multiple API calls concurrently, so async support is essential for performance.
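To make the composability concrete, here is a minimal sketch of wiring and running the workflow above. The `FakeVectorDB` class is a hypothetical stand-in for a real vector database client; an OpenAI API key is assumed, as elsewhere in this lesson:
class FakeVectorDB:
    """Hypothetical stand-in for a real vector database client."""
    async def similarity_search(self, query_embedding, top_k=5):
        return [{"content": "SSO setup guide: enable SAML under Admin...", "score": 0.92}][:top_k]

workflow = AIWorkflow(steps=[
    DocumentRetriever(
        embeddings_model="text-embedding-ada-002",
        vector_db_client=FakeVectorDB(),
        top_k=5,
    ),
    ReasoningChain(llm_model="gpt-3.5-turbo"),
])

context = asyncio.run(workflow.execute("How do I configure SSO for our app?"))
print(context.final_response)
print(context.reasoning_steps)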
LangChain takes an opinionated approach to AI workflow construction, emphasizing flexibility and developer experience. Its architecture centers around several key abstractions that we'll explore in detail.
LangChain's chain concept goes far beyond simple sequential execution. Let's examine how to build sophisticated chains that handle complex reasoning patterns:
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains.sequential import SequentialChain
from langchain.chains.transform import TransformChain
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
import json
from typing import Dict, List
class AdvancedRAGChain:
"""
Sophisticated RAG implementation showing LangChain's composition patterns
"""
def __init__(self, vector_store_path: str, llm_model: str = "gpt-3.5-turbo"):
self.embeddings = OpenAIEmbeddings()
self.vector_store = Chroma(
persist_directory=vector_store_path,
embedding_function=self.embeddings
)
        # gpt-3.5-turbo is a chat model, so use ChatOpenAI rather than the completion-style OpenAI class
        self.llm = ChatOpenAI(model_name=llm_model, temperature=0.1)
# Build the chain components
self.retrieval_chain = self._build_retrieval_chain()
self.analysis_chain = self._build_analysis_chain()
self.synthesis_chain = self._build_synthesis_chain()
# Compose the complete workflow
self.complete_chain = self._build_complete_chain()
def _build_retrieval_chain(self) -> TransformChain:
"""Chain that handles document retrieval and relevance scoring"""
def retrieve_and_score(inputs: Dict) -> Dict:
query = inputs["query"]
# Retrieve documents
docs = self.vector_store.similarity_search_with_score(
query,
k=10 # Retrieve more than we'll use for better selection
)
# Apply relevance threshold and reranking
relevant_docs = [
{"content": doc.page_content, "score": score, "metadata": doc.metadata}
for doc, score in docs
if score < 0.5 # Adjust threshold based on your embedding space
]
# Rerank by combining semantic similarity with metadata features
reranked_docs = self._rerank_documents(relevant_docs, query)
return {
"query": query,
"retrieved_documents": reranked_docs[:5], # Top 5 after reranking
"retrieval_metadata": {
"total_candidates": len(docs),
"relevant_candidates": len(relevant_docs),
"final_count": len(reranked_docs[:5])
}
}
return TransformChain(
input_variables=["query"],
output_variables=["query", "retrieved_documents", "retrieval_metadata"],
transform=retrieve_and_score
)
def _rerank_documents(self, docs: List[Dict], query: str) -> List[Dict]:
"""Advanced reranking using multiple signals"""
# Simple reranking based on document length and recency
# In production, you might use a learned reranking model
for doc in docs:
base_score = doc["score"]
# Prefer moderately sized documents
length_penalty = abs(len(doc["content"]) - 1000) / 1000
# Prefer recent documents if timestamp available
recency_bonus = 0
if "timestamp" in doc["metadata"]:
# Implementation would calculate recency bonus
pass
doc["final_score"] = base_score + (length_penalty * 0.1) - recency_bonus
return sorted(docs, key=lambda x: x["final_score"])
def _build_analysis_chain(self) -> LLMChain:
"""Chain that analyzes retrieved documents for relevance and key insights"""
analysis_prompt = PromptTemplate(
input_variables=["query", "retrieved_documents"],
template="""
Analyze the following documents in relation to the user's query.
For each document, identify:
1. Relevance score (1-10)
2. Key information that addresses the query
3. Any contradictions or uncertainties
Query: {query}
Documents:
{retrieved_documents}
Provide your analysis in the following JSON format:
{{
"document_analyses": [
{{
"doc_index": 0,
"relevance_score": 8,
"key_insights": ["insight 1", "insight 2"],
"contradictions": ["any contradictory info"],
"confidence": 0.85
}}
],
"overall_assessment": "summary of information quality and completeness"
}}
"""
)
        # output_key must be "analysis" so SequentialChain can feed it to the synthesis chain
        return LLMChain(llm=self.llm, prompt=analysis_prompt, output_key="analysis")
def _build_synthesis_chain(self) -> LLMChain:
"""Final synthesis chain that generates the response"""
synthesis_prompt = PromptTemplate(
input_variables=["query", "retrieved_documents", "analysis"],
template="""
Based on the document analysis, provide a comprehensive answer to the user's question.
User Query: {query}
Document Analysis: {analysis}
Original Documents: {retrieved_documents}
Guidelines:
1. Synthesize information from multiple sources
2. Acknowledge any limitations or uncertainties
3. Cite specific documents when making claims
4. If information is insufficient, say so clearly
Provide a well-structured response that directly addresses the user's question.
"""
)
return LLMChain(llm=self.llm, prompt=synthesis_prompt)
def _build_complete_chain(self) -> SequentialChain:
"""Compose all chains into a complete workflow"""
return SequentialChain(
chains=[self.retrieval_chain, self.analysis_chain, self.synthesis_chain],
input_variables=["query"],
output_variables=["text", "analysis", "retrieval_metadata"],
verbose=True # Enable for debugging
)
def query(self, user_question: str) -> Dict:
"""Execute the complete RAG workflow"""
try:
result = self.complete_chain({"query": user_question})
# Parse the analysis JSON if possible
try:
analysis_data = json.loads(result["analysis"])
result["parsed_analysis"] = analysis_data
except (json.JSONDecodeError, KeyError):
result["parsed_analysis"] = None
return result
except Exception as e:
return {
"error": str(e),
"query": user_question,
"success": False
}
# Example usage demonstrating the chain composition
def demonstrate_advanced_rag():
# Initialize the RAG system
rag_system = AdvancedRAGChain(
vector_store_path="./company_knowledge_base",
llm_model="gpt-3.5-turbo"
)
# Test query
result = rag_system.query(
"What are our current policies regarding remote work equipment reimbursement?"
)
if "error" not in result:
print("Generated Response:", result["text"])
print("Retrieval Stats:", result["retrieval_metadata"])
if result["parsed_analysis"]:
print("Document Relevance Scores:",
[doc["relevance_score"] for doc in result["parsed_analysis"]["document_analyses"]])
else:
print("Error:", result["error"])
LangChain's agent framework enables AI systems to make dynamic decisions about which tools to use and when. This is particularly powerful for complex workflows where the exact sequence of operations isn't predetermined:
from langchain.agents import initialize_agent, AgentType
from langchain.tools import BaseTool
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManagerForToolRun
import pandas as pd
from typing import Optional, List, Dict
class DatabaseQueryTool(BaseTool):
"""Custom tool for querying internal databases"""
name = "database_query"
description = """
Use this tool to query internal databases for specific information.
Input should be a SQL-like query description in natural language.
Example: 'Find all customers who purchased product X in the last 30 days'
"""
def _run(
self,
query_description: str,
run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
# In practice, this would connect to your actual database
# and translate natural language to SQL
if "customer" in query_description.lower():
# Simulated database response
return """
Found 156 customers matching criteria:
- Average purchase value: $234.56
- Most common location: California (23%)
- Peak purchase time: Tuesday afternoons
"""
elif "inventory" in query_description.lower():
return """
Current inventory status:
- Low stock items: 12
- Out of stock: 3
- Overstock items: 7
"""
else:
return "No relevant data found for the specified query."
async def _arun(
self,
query_description: str,
run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
# Async version for production use
return self._run(query_description, run_manager)
class APIIntegrationTool(BaseTool):
"""Tool for integrating with external APIs"""
name = "api_integration"
description = """
Use this tool to fetch data from external APIs or services.
Input should specify the service and what information you need.
Example: 'Get current stock price for AAPL' or 'Fetch weather for New York'
"""
def _run(
self,
api_request: str,
run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
# Route to appropriate API based on request
if "stock price" in api_request.lower():
return self._get_stock_price(api_request)
elif "weather" in api_request.lower():
return self._get_weather(api_request)
else:
return "API request not supported or unclear."
def _get_stock_price(self, request: str) -> str:
# Simulated stock API integration
return "AAPL current price: $175.43 (+2.3% today)"
def _get_weather(self, request: str) -> str:
# Simulated weather API integration
return "New York: 72°F, partly cloudy, 10% chance of rain"
class ReportGenerationTool(BaseTool):
"""Tool for generating formatted reports"""
name = "generate_report"
description = """
Use this tool to create formatted reports from collected data.
Input should specify the report type and include the data to be formatted.
"""
def _run(
self,
report_request: str,
run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
# Parse request and generate appropriate report format
if "summary" in report_request.lower():
return """
EXECUTIVE SUMMARY REPORT
========================
Key Metrics:
- Customer satisfaction: 87%
- Revenue growth: +12% QoQ
- System uptime: 99.94%
Recommendations:
1. Focus on customer retention programs
2. Expand successful product lines
3. Investigate intermittent service issues
"""
else:
return "Generated standard report with provided data."
class IntelligentBusinessAgent:
"""
Advanced agent that can handle complex business intelligence queries
"""
    def __init__(self, model_name: str = "gpt-4"):
        # gpt-4 is a chat model, so use ChatOpenAI
        self.llm = ChatOpenAI(model_name=model_name, temperature=0)
# Initialize custom tools
self.tools = [
DatabaseQueryTool(),
APIIntegrationTool(),
ReportGenerationTool(),
]
# Create the agent with custom tools
self.agent = initialize_agent(
tools=self.tools,
llm=self.llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
max_iterations=5, # Prevent infinite loops
early_stopping_method="generate"
)
# Add memory for conversation context
self.conversation_history = []
def query(self, business_question: str) -> Dict:
"""
Process complex business intelligence queries
"""
try:
# Add context from conversation history
contextual_question = self._add_context(business_question)
# Execute the agent
result = self.agent.run(contextual_question)
# Store in conversation history
self.conversation_history.append({
"question": business_question,
"answer": result,
"timestamp": pd.Timestamp.now()
})
return {
"answer": result,
"success": True,
"tools_used": self._extract_tools_used(),
"conversation_id": len(self.conversation_history)
}
except Exception as e:
return {
"error": str(e),
"success": False,
"question": business_question
}
def _add_context(self, question: str) -> str:
"""Add relevant context from conversation history"""
if len(self.conversation_history) == 0:
return question
# Simple context addition - in production, you'd want more sophisticated context management
recent_context = self.conversation_history[-2:] if len(self.conversation_history) > 1 else self.conversation_history
context_summary = "Previous conversation context:\n"
for item in recent_context:
context_summary += f"Q: {item['question']}\nA: {item['answer'][:100]}...\n\n"
return f"{context_summary}Current question: {question}"
def _extract_tools_used(self) -> List[str]:
"""Extract which tools were used in the last execution"""
# This would need to be implemented by capturing agent execution details
# For now, return placeholder
return ["database_query", "generate_report"]
# Example usage showing agent decision-making
def demonstrate_intelligent_agent():
agent = IntelligentBusinessAgent()
# Complex multi-step query
result = agent.query(
"I need to understand our Q3 performance. Can you pull our sales data, "
"compare it to industry benchmarks, and generate an executive summary report?"
)
if result["success"]:
print("Agent Response:", result["answer"])
print("Tools Used:", result["tools_used"])
else:
print("Error:", result["error"])
# Follow-up question that should use context
followup = agent.query(
"What specific recommendations do you have based on that analysis?"
)
print("Follow-up Response:", followup["answer"])
Performance Tip: LangChain agents can become expensive quickly due to multiple LLM calls for planning and execution. Monitor token usage carefully and consider implementing caching for repeated tool operations.
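One low-effort way to act on this tip is to memoize deterministic tool calls. A minimal sketch using an in-memory `lru_cache` follows; the `IntelligentCache` class later in this lesson shows a production-grade alternative:
import functools
import time

@functools.lru_cache(maxsize=256)
def cached_database_query(query_description: str) -> str:
    """Memoized tool core: identical agent retries skip the expensive call."""
    time.sleep(0.5)  # stand-in for a real database round trip
    return f"Results for: {query_description}"

start = time.time()
cached_database_query("customers who purchased product X last month")
cached_database_query("customers who purchased product X last month")  # cache hit
print(f"Two calls took {time.time() - start:.2f}s")  # ~0.5s rather than ~1.0s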
LlamaIndex takes a different architectural approach, optimizing specifically for retrieval-augmented generation scenarios. Its design philosophy centers around efficient indexing, retrieval, and synthesis patterns.
LlamaIndex provides sophisticated indexing capabilities that go beyond simple vector storage:
# Note: this example uses the legacy (0.x) LlamaIndex API
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.indices.composability import ComposableGraph
from llama_index.indices.keyword_table import GPTKeywordTableIndex
from llama_index.indices.list import GPTListIndex
from llama_index.indices.tree import GPTTreeIndex
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings import OpenAIEmbedding
from llama_index.llms import OpenAI
from llama_index.storage.storage_context import StorageContext
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.storage.vector_store import SimpleVectorStore
import os
from typing import Dict, List, Any
class EnterpriseRAGSystem:
"""
Production-ready RAG system using LlamaIndex's advanced features
"""
def __init__(self, data_directory: str, persist_directory: str = "./storage"):
self.data_directory = data_directory
self.persist_directory = persist_directory
# Configure service context with custom settings
self.llm = OpenAI(model="gpt-4", temperature=0.1)
self.embedding_model = OpenAIEmbedding()
# Advanced node parsing configuration
text_splitter = TokenTextSplitter(
separator=" ",
chunk_size=1024,
chunk_overlap=128,
backup_separators=["\n", ".", "!", "?"]
)
node_parser = SimpleNodeParser(
text_splitter=text_splitter,
include_metadata=True,
include_prev_next_rel=True # Enable node relationships
)
self.service_context = ServiceContext.from_defaults(
llm=self.llm,
embed_model=self.embedding_model,
node_parser=node_parser,
chunk_size=1024
)
# Initialize storage components
self.storage_context = StorageContext.from_defaults(
docstore=SimpleDocumentStore(),
vector_store=SimpleVectorStore(),
index_store=SimpleIndexStore()
)
# Build the multi-index system
self.indices = {}
self.composable_graph = None
self._build_indices()
self._create_composable_graph()
def _build_indices(self):
"""Build multiple specialized indices for different query types"""
# Load documents
documents = SimpleDirectoryReader(
self.data_directory,
recursive=True,
required_exts=[".txt", ".md", ".pdf", ".docx"]
).load_data()
print(f"Loaded {len(documents)} documents")
# Vector Index - for semantic similarity
self.indices['vector'] = GPTVectorStoreIndex.from_documents(
documents,
service_context=self.service_context,
storage_context=self.storage_context,
show_progress=True
)
# Keyword Index - for exact term matching
self.indices['keyword'] = GPTKeywordTableIndex.from_documents(
documents,
service_context=self.service_context,
show_progress=True
)
# Tree Index - for hierarchical summarization
self.indices['tree'] = GPTTreeIndex.from_documents(
documents,
service_context=self.service_context,
show_progress=True
)
# List Index - for exhaustive search over smaller datasets
if len(documents) < 100: # Only create for smaller document sets
self.indices['list'] = GPTListIndex.from_documents(
documents,
service_context=self.service_context,
show_progress=True
)
# Persist indices
for name, index in self.indices.items():
index.storage_context.persist(
persist_dir=os.path.join(self.persist_directory, name)
)
def _create_composable_graph(self):
"""Create a composable graph that can route queries to appropriate indices"""
# Define index summaries for routing decisions
index_summaries = {
'vector': "Use for semantic similarity searches, conceptual questions, and finding related content",
'keyword': "Use for exact keyword matches, specific term lookups, and when precise terminology matters",
'tree': "Use for summarization tasks, getting overviews, and hierarchical information",
}
if 'list' in self.indices:
index_summaries['list'] = "Use for comprehensive searches over small document collections"
        # Create composable graph (in the legacy API, the first argument is the
        # index class used for the graph's root, not a string id)
        self.composable_graph = ComposableGraph.from_indices(
            GPTTreeIndex,
            list(self.indices.values()),
            index_summaries=list(index_summaries.values()),
            service_context=self.service_context
        )
def query_with_routing(
self,
query: str,
mode: str = "default",
similarity_top_k: int = 5,
response_mode: str = "compact"
) -> Dict[str, Any]:
"""
Query the system using automatic index routing
"""
query_engine = self.composable_graph.as_query_engine(
response_mode=response_mode,
similarity_top_k=similarity_top_k,
verbose=True
)
try:
response = query_engine.query(query)
return {
"response": str(response),
"source_nodes": [
{
"content": node.node.get_content()[:200] + "...",
"metadata": node.node.metadata,
"score": node.score
}
for node in response.source_nodes
],
"success": True
}
except Exception as e:
return {
"error": str(e),
"success": False
}
def query_specific_index(
self,
query: str,
index_type: str,
**kwargs
) -> Dict[str, Any]:
"""
Query a specific index directly
"""
if index_type not in self.indices:
return {
"error": f"Index type '{index_type}' not available",
"available_indices": list(self.indices.keys()),
"success": False
}
try:
query_engine = self.indices[index_type].as_query_engine(**kwargs)
response = query_engine.query(query)
return {
"response": str(response),
"index_used": index_type,
"source_nodes": [
{
"content": node.node.get_content()[:200] + "...",
"metadata": node.node.metadata,
"score": getattr(node, 'score', None)
}
for node in (response.source_nodes or [])
],
"success": True
}
except Exception as e:
return {
"error": str(e),
"index_used": index_type,
"success": False
}
def compare_indices(self, query: str) -> Dict[str, Any]:
"""
Compare responses across different indices for the same query
"""
results = {}
for index_type in self.indices.keys():
results[index_type] = self.query_specific_index(
query,
index_type,
response_mode="compact",
similarity_top_k=3
)
# Add routed response
results['routed'] = self.query_with_routing(query)
return {
"query": query,
"results_by_index": results,
"comparison_summary": self._generate_comparison_summary(results)
}
def _generate_comparison_summary(self, results: Dict) -> Dict[str, Any]:
"""Generate a summary comparing different index responses"""
successful_indices = [
idx for idx, result in results.items()
if result.get('success', False)
]
response_lengths = {
idx: len(result.get('response', ''))
for idx, result in results.items()
if result.get('success', False)
}
return {
"successful_indices": successful_indices,
"response_length_comparison": response_lengths,
"recommended_index": max(response_lengths, key=response_lengths.get) if response_lengths else None
}
class CustomRetriever:
"""
Custom retriever that implements advanced retrieval strategies
"""
def __init__(self, rag_system: EnterpriseRAGSystem):
self.rag_system = rag_system
self.retrieval_strategies = {
'hybrid': self._hybrid_retrieval,
'mmr': self._maximal_marginal_relevance,
'rerank': self._rerank_retrieval
}
def retrieve(
self,
query: str,
strategy: str = 'hybrid',
top_k: int = 5,
**kwargs
) -> List[Dict[str, Any]]:
"""
Retrieve documents using specified strategy
"""
if strategy not in self.retrieval_strategies:
raise ValueError(f"Strategy '{strategy}' not supported")
return self.retrieval_strategies[strategy](query, top_k, **kwargs)
def _hybrid_retrieval(self, query: str, top_k: int, **kwargs) -> List[Dict]:
"""Combine vector and keyword retrieval"""
# Get results from both vector and keyword indices
vector_results = self.rag_system.query_specific_index(
query, 'vector', similarity_top_k=top_k*2
)
keyword_results = self.rag_system.query_specific_index(
query, 'keyword', similarity_top_k=top_k*2
)
# Combine and deduplicate results
all_nodes = []
if vector_results.get('success'):
all_nodes.extend(vector_results['source_nodes'])
if keyword_results.get('success'):
# Add keyword results that aren't already included
existing_content = {node['content'] for node in all_nodes}
for node in keyword_results['source_nodes']:
if node['content'] not in existing_content:
all_nodes.append(node)
# Score and rank combined results
return self._score_and_rank_nodes(all_nodes, query)[:top_k]
def _maximal_marginal_relevance(self, query: str, top_k: int, **kwargs) -> List[Dict]:
"""Implement MMR to balance relevance and diversity"""
lambda_param = kwargs.get('lambda_mult', 0.5)
# Get initial candidate set
candidates = self._hybrid_retrieval(query, top_k * 3, **kwargs)
if not candidates:
return []
# MMR algorithm
selected = [candidates[0]] # Start with most relevant
candidates = candidates[1:]
for _ in range(min(top_k - 1, len(candidates))):
mmr_scores = []
for candidate in candidates:
# Relevance score
relevance = candidate.get('score', 0)
# Maximum similarity to already selected documents
max_sim = max([
self._compute_similarity(candidate['content'], selected_doc['content'])
for selected_doc in selected
])
# MMR score
mmr_score = lambda_param * relevance - (1 - lambda_param) * max_sim
mmr_scores.append((candidate, mmr_score))
# Select document with highest MMR score
best_candidate, _ = max(mmr_scores, key=lambda x: x[1])
selected.append(best_candidate)
candidates.remove(best_candidate)
return selected
def _rerank_retrieval(self, query: str, top_k: int, **kwargs) -> List[Dict]:
"""Implement learned reranking (simplified version)"""
# Get initial retrieval results
candidates = self._hybrid_retrieval(query, top_k * 2, **kwargs)
# Apply reranking heuristics (in production, you'd use a learned model)
for candidate in candidates:
base_score = candidate.get('score', 0)
# Length normalization
content_length = len(candidate['content'])
length_score = 1.0 - abs(content_length - 1000) / 1000 # Prefer ~1000 chars
# Recency bonus if timestamp available
recency_score = 1.0
if 'timestamp' in candidate.get('metadata', {}):
# Implementation would calculate recency
pass
# Combined reranking score
candidate['rerank_score'] = (
base_score * 0.7 +
length_score * 0.2 +
recency_score * 0.1
)
return sorted(candidates, key=lambda x: x.get('rerank_score', 0), reverse=True)[:top_k]
def _score_and_rank_nodes(self, nodes: List[Dict], query: str) -> List[Dict]:
"""Score and rank nodes based on relevance"""
# Simple scoring based on available information
for node in nodes:
base_score = node.get('score', 0.5)
# Boost score based on metadata relevance
metadata_boost = 0
if 'metadata' in node:
metadata_str = str(node['metadata']).lower()
query_terms = query.lower().split()
metadata_boost = sum(term in metadata_str for term in query_terms) * 0.1
node['final_score'] = base_score + metadata_boost
return sorted(nodes, key=lambda x: x.get('final_score', 0), reverse=True)
def _compute_similarity(self, text1: str, text2: str) -> float:
"""Simple similarity computation (in production, use proper embeddings)"""
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0
# Example usage demonstrating LlamaIndex's capabilities
def demonstrate_llamaindex_system():
# Initialize the system (assumes you have a 'documents' directory)
rag_system = EnterpriseRAGSystem("./documents")
# Test automatic routing
routed_result = rag_system.query_with_routing(
"What are the main principles of our data governance policy?"
)
print("Routed Query Result:", routed_result['response'][:200], "...")
# Compare different indices
comparison = rag_system.compare_indices(
"How do we handle customer data privacy?"
)
print("Index Comparison Summary:", comparison['comparison_summary'])
# Test custom retrieval strategies
retriever = CustomRetriever(rag_system)
mmr_results = retriever.retrieve(
"machine learning model deployment",
strategy='mmr',
top_k=5,
lambda_mult=0.7
)
print(f"MMR Retrieved {len(mmr_results)} diverse results")
Architecture Note: LlamaIndex's composable graph approach allows for sophisticated query routing, but it adds complexity and latency. Use it when you have diverse document types that benefit from different indexing strategies.
Both frameworks require careful optimization for production use. Let's examine key performance considerations and optimization strategies:
import functools
import pickle
import hashlib
import redis
from typing import Optional, Callable, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time
class IntelligentCache:
"""
Production-grade caching system for AI workflows
"""
def __init__(
self,
redis_client: Optional[redis.Redis] = None,
default_ttl: int = 3600,
max_memory_cache: int = 1000
):
self.redis_client = redis_client or redis.Redis(host='localhost', port=6379, db=0)
self.default_ttl = default_ttl
self.memory_cache = {}
self.max_memory_cache = max_memory_cache
self.cache_stats = {
'hits': 0,
'misses': 0,
'redis_hits': 0,
'memory_hits': 0
}
def cache_key(self, func_name: str, args: tuple, kwargs: dict) -> str:
"""Generate consistent cache key for function calls"""
# Create a deterministic hash of the function call
key_data = {
'function': func_name,
'args': args,
'kwargs': sorted(kwargs.items())
}
key_string = pickle.dumps(key_data, protocol=pickle.HIGHEST_PROTOCOL)
return f"ai_workflow:{hashlib.sha256(key_string).hexdigest()}"
def get(self, key: str) -> Optional[Any]:
"""Get value from cache with memory -> Redis fallback"""
# Check memory cache first
if key in self.memory_cache:
self.cache_stats['hits'] += 1
self.cache_stats['memory_hits'] += 1
return self.memory_cache[key]['value']
# Check Redis cache
try:
redis_value = self.redis_client.get(key)
if redis_value:
value = pickle.loads(redis_value)
# Store in memory cache if there's space
if len(self.memory_cache) < self.max_memory_cache:
self.memory_cache[key] = {
'value': value,
'timestamp': time.time()
}
self.cache_stats['hits'] += 1
self.cache_stats['redis_hits'] += 1
return value
except (redis.ConnectionError, pickle.PickleError) as e:
print(f"Redis cache error: {e}")
self.cache_stats['misses'] += 1
return None
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""Set value in both memory and Redis cache"""
ttl = ttl or self.default_ttl
# Store in memory cache
if len(self.memory_cache) >= self.max_memory_cache:
# Remove oldest entry
oldest_key = min(
self.memory_cache.keys(),
key=lambda k: self.memory_cache[k]['timestamp']
)
del self.memory_cache[oldest_key]
self.memory_cache[key] = {
'value': value,
'timestamp': time.time()
}
# Store in Redis
try:
serialized_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
return self.redis_client.setex(key, ttl, serialized_value)
except (redis.ConnectionError, pickle.PickleError) as e:
print(f"Redis set error: {e}")
return False
def invalidate_pattern(self, pattern: str) -> int:
"""Invalidate cache entries matching pattern"""
count = 0
# Clear from memory cache
keys_to_remove = [k for k in self.memory_cache.keys() if pattern in k]
for key in keys_to_remove:
del self.memory_cache[key]
count += 1
# Clear from Redis
try:
redis_keys = self.redis_client.keys(f"*{pattern}*")
if redis_keys:
count += self.redis_client.delete(*redis_keys)
except redis.ConnectionError as e:
print(f"Redis invalidation error: {e}")
return count
def get_stats(self) -> dict:
"""Get cache performance statistics"""
total_requests = self.cache_stats['hits'] + self.cache_stats['misses']
hit_rate = self.cache_stats['hits'] / total_requests if total_requests > 0 else 0
return {
**self.cache_stats,
'hit_rate': hit_rate,
'memory_cache_size': len(self.memory_cache),
'total_requests': total_requests
}
def cached_llm_call(cache: IntelligentCache, ttl: int = 3600):
"""Decorator for caching LLM calls"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
cache_key = cache.cache_key(func.__name__, args, kwargs)
# Try to get from cache
cached_result = cache.get(cache_key)
if cached_result is not None:
return cached_result
# Execute function and cache result
result = func(*args, **kwargs)
cache.set(cache_key, result, ttl)
return result
return wrapper
return decorator
class AsyncWorkflowOptimizer:
"""
Optimize AI workflows with async processing and batching
"""
def __init__(self, max_concurrent: int = 5):
self.max_concurrent = max_concurrent
self.thread_pool = ThreadPoolExecutor(max_workers=max_concurrent)
async def batch_embed_documents(
self,
texts: List[str],
embedding_function: Callable,
batch_size: int = 20
) -> List[List[float]]:
"""Batch process embeddings for efficiency"""
embeddings = []
# Process in batches to avoid rate limits
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
# Create async tasks for the batch
tasks = [
asyncio.create_task(self._embed_single(text, embedding_function))
for text in batch
]
# Wait for batch completion
batch_embeddings = await asyncio.gather(*tasks)
embeddings.extend(batch_embeddings)
# Small delay to respect rate limits
if i + batch_size < len(texts):
await asyncio.sleep(0.1)
return embeddings
async def _embed_single(self, text: str, embedding_function: Callable) -> List[float]:
"""Embed single text with error handling"""
try:
# Run embedding function in thread pool to avoid blocking
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.thread_pool,
embedding_function,
text
)
except Exception as e:
print(f"Embedding error for text: {text[:50]}... Error: {e}")
# Return zero vector as fallback
return [0.0] * 1536 # Adjust dimension as needed
async def parallel_retrieval(
self,
query: str,
retrievers: List[Callable],
max_results_per_retriever: int = 10
) -> Dict[str, List[Any]]:
"""Run multiple retrievers in parallel"""
tasks = []
for i, retriever in enumerate(retrievers):
task = asyncio.create_task(
self._safe_retrieve(f"retriever_{i}", retriever, query, max_results_per_retriever)
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and handle exceptions
processed_results = {}
for i, result in enumerate(results):
if isinstance(result, Exception):
processed_results[f"retriever_{i}"] = {
'error': str(result),
'results': []
}
else:
processed_results[f"retriever_{i}"] = result
return processed_results
async def _safe_retrieve(
self,
name: str,
retriever: Callable,
query: str,
max_results: int
) -> Dict[str, Any]:
"""Safe retrieval with error handling"""
try:
start_time = time.time()
# Run retriever in thread pool
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(
self.thread_pool,
lambda: retriever(query, top_k=max_results)
)
end_time = time.time()
return {
'results': results[:max_results],
'retrieval_time': end_time - start_time,
'result_count': len(results) if results else 0
}
except Exception as e:
return {
'error': str(e),
'results': [],
'retrieval_time': 0,
'result_count': 0
}
# Production monitoring and error handling
class WorkflowMonitor:
"""
Monitor AI workflow performance and errors
"""
def __init__(self, log_file: str = "workflow_monitor.log"):
self.log_file = log_file
self.metrics = {
'total_requests': 0,
'successful_requests': 0,
'failed_requests': 0,
'average_response_time': 0,
'error_types': {}
}
self.request_times = []
def log_request(
self,
request_type: str,
success: bool,
response_time: float,
error: Optional[str] = None,
metadata: Optional[Dict] = None
):
"""Log a workflow request"""
self.metrics['total_requests'] += 1
if success:
self.metrics['successful_requests'] += 1
else:
self.metrics['failed_requests'] += 1
if error:
error_type = type(error).__name__ if isinstance(error, Exception) else str(error)
self.metrics['error_types'][error_type] = \
self.metrics['error_types'].get(error_type, 0) + 1
# Track response times
self.request_times.append(response_time)
if len(self.request_times) > 1000: # Keep last 1000 requests
self.request_times = self.request_times[-1000:]
self.metrics['average_response_time'] = sum(self.request_times) / len(self.request_times)
# Log to file
log_entry = {
'timestamp': time.time(),
'request_type': request_type,
'success': success,
'response_time': response_time,
'error': error,
'metadata': metadata
}
with open(self.log_file, 'a') as f:
f.write(f"{log_entry}\n")
def get_health_metrics(self) -> Dict[str, Any]:
"""Get current system health metrics"""
total = self.metrics['total_requests']
success_rate = self.metrics['successful_requests'] / total if total > 0 else 0
return {
'success_rate': success_rate,
'average_response_time': self.metrics['average_response_time'],
'total_requests': total,
'error_distribution': self.metrics['error_types'],
'health_status': 'healthy' if success_rate > 0.95 else 'degraded' if success_rate > 0.8 else 'unhealthy'
}
# Example production workflow implementation
class ProductionWorkflow:
"""
Production-ready workflow combining optimization techniques
"""
    def __init__(self):
        self.cache = IntelligentCache()
        self.optimizer = AsyncWorkflowOptimizer()
        self.monitor = WorkflowMonitor()
        # Wrap the LLM call with the cache decorator here: a class-level
        # decorator cannot reference the instance's cache at definition time
        self._cached_llm_call = cached_llm_call(self.cache, ttl=1800)(self._llm_call)
    def _llm_call(self, prompt: str, model: str = "gpt-3.5-turbo") -> str:
        """LLM call that gets wrapped with caching in __init__"""
        # This would be your actual LLM call
        # For demo, we'll simulate
        time.sleep(0.1)  # Simulate API latency
        return f"Response to: {prompt[:50]}..."
async def process_complex_query(
self,
query: str,
document_retrievers: List[Callable]
) -> Dict[str, Any]:
"""
Process a complex query with full optimization
"""
start_time = time.time()
try:
# Step 1: Parallel retrieval from multiple sources
retrieval_results = await self.optimizer.parallel_retrieval(
query,
document_retrievers,
max_results_per_retriever=5
)
# Step 2: Process and combine results
combined_context = self._combine_retrieval_results(retrieval_results)
# Step 3: Generate response using cached LLM call
response = self._cached_llm_call(
f"Context: {combined_context}\n\nQuery: {query}"
)
response_time = time.time() - start_time
# Log successful request
self.monitor.log_request(
request_type="complex_query",
success=True,
response_time=response_time,
metadata={
'retrievers_used': len(document_retrievers),
'total_documents': sum(
len(result.get('results', []))
for result in retrieval_results.values()
)
}
)
return {
'response': response,
'retrieval_results': retrieval_results,
'processing_time': response_time,
'cache_stats': self.cache.get_stats(),
'success': True
}
except Exception as e:
response_time = time.time() - start_time
# Log failed request
self.monitor.log_request(
request_type="complex_query",
success=False,
response_time=response_time,
error=str(e)
)
return {
'error': str(e),
'processing_time': response_time,
'success': False
}
def _combine_retrieval_results(self, results: Dict[str, Any]) -> str:
"""Combine results from multiple retrievers"""
combined_text = []
for retriever_name, result in results.items():
if 'results' in result and result['results']:
combined_text.append(f"From {retriever_name}:")
for doc in result['results'][:3]: # Top 3 from each
if isinstance(doc, dict) and 'content' in doc:
combined_text.append(doc['content'][:200] + "...")
else:
combined_text.append(str(doc)[:200] + "...")
return "\n\n".join(combined_text)
def get_system_status(self) -> Dict[str, Any]:
"""Get comprehensive system status"""
return {
'cache_performance': self.cache.get_stats(),
'workflow_health': self.monitor.get_health_metrics(),
'system_timestamp': time.time()
}
# Initialize the production workflow
def initialize_production_system():
    # Caching is wired up in ProductionWorkflow.__init__, so no extra setup is needed
    return ProductionWorkflow()
Production Warning: Always implement circuit breakers and fallback mechanisms for external API calls. LLM services can experience outages or rate limiting that will cascade through your entire system.
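A minimal circuit breaker sketch follows, assuming a simple count-and-cooldown policy; libraries such as pybreaker offer hardened implementations:
import time

class CircuitBreaker:
    """Open the circuit after repeated failures; retry after a cooldown."""
    def __init__(self, failure_threshold: int = 5, reset_timeout: float = 30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.failures = 0
        self.opened_at = None

    def call(self, func, *args, fallback=None, **kwargs):
        # While open, short-circuit to the fallback until the cooldown expires
        if self.opened_at is not None:
            if time.time() - self.opened_at < self.reset_timeout:
                return fallback
            self.opened_at = None  # half-open: allow one trial call
            self.failures = 0
        try:
            result = func(*args, **kwargs)
            self.failures = 0  # any success closes the circuit again
            return result
        except Exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.time()
            if fallback is not None:
                return fallback
            raise

breaker = CircuitBreaker()
response = breaker.call(
    lambda: "LLM response",  # your actual LLM call goes here
    fallback="Service temporarily unavailable, please retry."
)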
After working extensively with both frameworks, here's a detailed comparison to guide your selection:
LangChain adopts a maximalist approach, providing building blocks for almost any conceivable AI workflow. Its strength lies in flexibility and the breadth of its integrations. The framework excels when you need agent-based workflows with dynamic tool selection, a wide range of third-party integrations, or highly customized chain compositions.
However, this flexibility comes with complexity costs. LangChain's abstraction layers can make debugging difficult, and the framework's rapid evolution sometimes breaks backward compatibility.
LlamaIndex takes a focused approach, optimizing specifically for retrieval-augmented generation. Its architecture decisions are opinionated but lead to better performance for RAG use cases:
# LlamaIndex performance comparison example
import time
from typing import List, Dict
import statistics
class FrameworkBenchmark:
"""
Compare performance characteristics of different approaches
"""
def __init__(self, test_queries: List[str], document_count: int = 1000):
self.test_queries = test_queries
self.document_count = document_count
self.results = {
'langchain': {'times': [], 'memory_usage': [], 'accuracy_scores': []},
'llamaindex': {'times': [], 'memory_usage': [], 'accuracy_scores': []}
}
def benchmark_langchain_rag(self, rag_chain) -> Dict[str, float]:
"""Benchmark LangChain RAG implementation"""
start_time = time.time()
memory_before = self._get_memory_usage()
results = []
for query in self.test_queries:
query_start = time.time()
try:
response = rag_chain.run(query)
query_time = time.time() - query_start
results.append({
'query': query,
'response': response,
'time': query_time,
'success': True
})
except Exception as e:
results.append({
'query': query,
'error': str(e),
'time': time.time() - query_start,
'success': False
})
total_time = time.time() - start_time
memory_after = self._get_memory_usage()
successful_queries = [r for r in results if r['success']]
average_query_time = statistics.mean([r['time'] for r in successful_queries]) if successful_queries else 0
return {
'total_time': total_time,
'average_query_time': average_query_time,
'memory_delta': memory_after - memory_before,
'success_rate': len(successful_queries) / len(self.test_queries),
'results': results
}
def benchmark_llamaindex_rag(self, index) -> Dict[str, float]:
"""Benchmark LlamaIndex implementation"""
start_time = time.time()
memory_before = self._get_memory_usage()
query_engine = index.as_query_engine(
response_mode="compact",
similarity_top_k=5
)
results = []
for query in self.test_queries:
query_start = time.time()
try:
response = query_engine.query(query)
query_time = time.time() - query_start
results.append({
'query': query,
'response': str(response),
'time': query_time,
'success': True,
'source_nodes': len(response.source_nodes) if hasattr(response, 'source_nodes') else 0
})
except Exception as e:
results.append({
'query': query,
'error': str(e),
'time': time.time() - query_start,
'success': False
})
total_time = time.time() - start_time
memory_after = self._get_memory_usage()
successful_queries = [r for r in results if r['success']]
average_query_time = statistics.mean([r['time'] for r in successful_queries]) if successful_queries else 0
return {
'total_time': total_time,
'average_query_time': average_query_time,
'memory_delta': memory_after - memory_before,
'success_rate': len(successful_queries) / len(self.test_queries),
'results': results
}
def _get_memory_usage(self) -> float:
"""Get current memory usage in MB"""
import psutil
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def run_comprehensive_benchmark(
self,
langchain_rag,
llamaindex_rag
) -> Dict[str, Any]:
"""Run comprehensive performance comparison"""
print("Benchmarking LangChain implementation...")
langchain_results = self.benchmark_langchain_rag(langchain_rag)
print("Benchmarking LlamaIndex implementation...")
llamaindex_results = self.benchmark_llamaindex_rag(llamaindex_rag)
# Calculate relative performance metrics
comparison = {
'speed_comparison': {
'langchain_avg_time': langchain_results['average_query_time'],
'llamaindex_avg_time': llamaindex_results['average_query_time'],
'llamaindex_speedup': langchain_results['average_query_time'] / llamaindex_results['average_query_time'] if llamaindex_results['average_query_time'] > 0 else float('inf')
},
'memory_comparison': {
'langchain_memory_delta': langchain_results['memory_delta'],
'llamaindex_memory_delta': llamaindex_results['memory_delta'],
'memory_efficiency': langchain_results['memory_delta'] / llamaindex_results['memory_delta'] if llamaindex_results['memory_delta'] > 0 else float('inf')
},
'reliability_comparison': {
'langchain_success_rate': langchain_results['success_rate'],
'llamaindex_success_rate': llamaindex_results['success_rate']
}
}
# Determine recommendation
recommendation = self._generate_recommendation(comparison)
return {
'langchain_results': langchain_results,
'llamaindex_results': llamaindex_results,
'comparison': comparison,
'recommendation': recommendation
}
    def _generate_recommendation(self, comparison: Dict) -> Dict[str, str]:
        """Generate framework recommendation based on benchmarks"""
        speed = comparison['speed_comparison']
        reliability = comparison['reliability_comparison']
        return {
            'speed': 'llamaindex' if speed['llamaindex_speedup'] > 1 else 'langchain',
            'reliability': (
                'llamaindex'
                if reliability['llamaindex_success_rate'] >= reliability['langchain_success_rate']
                else 'langchain'
            )
        }