26 Advanced LLM Serving
Production Techniques for High-Throughput Inference
Your LLM serving system handles 1,000 requests/second with P50 latency of 80ms. The product team wants 10,000 req/s at the same latency. Buying 10× more GPUs would cost $500K/month.
Can you get there without 10× the hardware? This chapter follows one request through a production system, showing why each optimization exists.
Prefix caching / RadixAttention exploits redundancy: many requests share common prefixes whose KV cache entries are identical. Speculative decoding exploits agreement: a cheap draft model's tokens usually match the target model's, so disagreements are sparse and one forward pass can verify many tokens at once. Disaggregated serving exploits separability: prefill and decode have different resource profiles and can be factored onto different hardware.
26.1 The Investigation: Anatomy of a Request
Before cataloging techniques, let’s trace a single request through a production LLM system and identify every bottleneck:
User sends: "Translate to French: Hello, how are you?"
1. Load balancer routes to a GPU server (~1ms network)
2. Tokenizer encodes the prompt (~0.5ms CPU)
3. KV cache lookup for shared prefix (~0.1ms)
4. Prefill: process all prompt tokens (~15ms, compute-bound)
5. Decode: generate tokens one at a time (~8ms/token, memory-bound)
6. Detokenize and stream back (~0.1ms/token)
Total for 20-token response: ~175ms
Target: <200ms P99. Barely making it at low load.
At 10,000 req/s, the system collapses: queueing delays explode. Each technique in this chapter attacks a specific bottleneck in this pipeline. Let's see which ones give us that 10× throughput.
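To see where the time goes, here is a back-of-envelope latency model of the trace above (a minimal sketch; the per-step costs are the illustrative numbers from the trace, not measurements):

def request_latency_ms(output_tokens, prefill_ms=15.0, decode_ms_per_token=8.0,
                       fixed_overhead_ms=1.6, stream_ms_per_token=0.1):
    """Rough end-to-end latency for one request.

    fixed_overhead_ms bundles routing, tokenization, and cache lookup.
    The first output token comes out of prefill; the rest are decode steps.
    """
    decode_steps = max(output_tokens - 1, 0)
    return (fixed_overhead_ms + prefill_ms
            + decode_steps * decode_ms_per_token
            + output_tokens * stream_ms_per_token)

print(request_latency_ms(20))  # ~170 ms; decode accounts for ~150 of it

Decode dominates the budget, which is why most of the techniques below either cut per-token decode cost or keep prefill from blocking decode.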
The Inference chapter covered the fundamentals: KV caching, continuous batching, speculative decoding. This chapter goes deeper with production-grade techniques.
26.2 Prefix Caching and RadixAttention
26.2.1 The Observation
Many requests share common prefixes:
System prompt (shared):
"You are a helpful AI assistant. Be concise and accurate..."
User requests:
Request 1: [system prompt] + "What is the capital of France?"
Request 2: [system prompt] + "Explain quantum computing."
Request 3: [system prompt] + "Write a haiku about coding."
The system prompt KV cache is recomputed for every request.
26.2.2 RadixAttention (SGLang)
Store KV caches in a radix tree, share across requests:
class RadixNode:
    """A node in the radix tree: one edge per token."""
    def __init__(self):
        self.children = {}    # token -> RadixNode
        self.kv_cache = None  # KV entry for this position

class RadixCache:
"""
Radix tree for prefix caching.
Key insight: Common prefixes share KV cache.
"""
def __init__(self):
self.root = RadixNode()
def insert(self, tokens, kv_cache):
"""Insert KV cache for token sequence."""
node = self.root
for i, token in enumerate(tokens):
if token not in node.children:
node.children[token] = RadixNode()
node = node.children[token]
node.kv_cache = kv_cache[i] # Store per-position KV
def lookup(self, tokens):
"""Find longest matching prefix."""
node = self.root
matched = 0
kv_cache = []
for token in tokens:
if token in node.children:
node = node.children[token]
kv_cache.append(node.kv_cache)
matched += 1
else:
break
return matched, kv_cache
# Usage:
cache = RadixCache()
# First request computes full KV cache
kv = model.prefill("You are helpful...")
cache.insert(tokenize("You are helpful..."), kv)
# Second request reuses cached prefix
prefix_len, cached_kv = cache.lookup(tokenize("You are helpful..."))
# Only compute KV for new tokens!

26.2.3 Memory Management
Prefix caching needs smart eviction:
class LRUPrefixCache:
def __init__(self, max_memory_gb):
self.max_memory = max_memory_gb * 1e9
self.cache = OrderedDict() # LRU order
self.current_memory = 0
def get_or_compute(self, prefix_tokens, model):
key = hash(tuple(prefix_tokens))
if key in self.cache:
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
# Compute new KV cache
kv = model.prefill(prefix_tokens)
kv_size = self.compute_size(kv)
        # Evict least-recently-used entries until the new cache fits
        while self.current_memory + kv_size > self.max_memory:
            _, evicted = self.cache.popitem(last=False)  # oldest = LRU
            self.current_memory -= self.compute_size(evicted)
        self.cache[key] = kv
        self.current_memory += kv_size
        return kv

Performance impact:
System prompt: 500 tokens
User query: 50 tokens
Without prefix caching: 550 tokens prefill per request
With prefix caching: 50 tokens prefill per request (first request caches)
Speedup: 11x for time-to-first-token on repeated prefixes
26.3 Chunked Prefill
26.3.1 The Problem
Long prefill blocks decode:
Timeline without chunked prefill:
Request A (long prompt): [==========PREFILL==========][decode][decode]...
Request B (arrives during A's prefill):
waiting... waiting... [PREFILL][decode]...
↑ High latency for B!
26.3.2 The Solution
Split prefill into chunks, interleave with decode:
class ChunkedPrefillScheduler:
    def __init__(self, chunk_size=512, max_batch_tokens=2048, max_decode=64):
        self.chunk_size = chunk_size
        self.max_batch_tokens = max_batch_tokens  # token budget per iteration
        self.max_decode = max_decode              # cap on decode requests per batch
        self.prefill_queue = []
        self.decode_queue = []
def schedule_iteration(self):
"""
Each iteration: some prefill tokens + some decode tokens.
"""
batch = []
        # Add decode requests first (decode latency is the priority)
        for req in self.decode_queue[:self.max_decode]:
            batch.append(('decode', req, 1))  # 1 token per decode step
        # Fill the remaining token budget with prefill chunks
        remaining_capacity = self.max_batch_tokens - len(batch)  # 1 token per decode entry
for req in self.prefill_queue:
if remaining_capacity <= 0:
break
tokens_remaining = req.total_tokens - req.processed_tokens
chunk = min(tokens_remaining, self.chunk_size, remaining_capacity)
batch.append(('prefill', req, chunk))
remaining_capacity -= chunk
        return batch

Timeline with chunked prefill:
Request A: [prefill_chunk][decode B][prefill_chunk][decode B]...
Request B: waiting... [prefill][decode][decode]...
↑ Much lower latency!
26.4 Disaggregated Serving
26.4.1 The Observation
Prefill and decode have different characteristics:
Prefill:
- Compute-bound (matrix multiplies)
- High parallelism (many tokens)
- Batch-friendly
- High memory bandwidth for KV cache write
Decode:
- Memory-bound (reading KV cache)
- Sequential (one token at a time)
- Latency-sensitive
- Low compute intensity
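These profiles fall directly out of arithmetic intensity (FLOPs per byte moved). A sketch, assuming a weight-dominated FP16 matmul, where processing N tokens performs 2·P·N FLOPs against reading P weights once:

def arithmetic_intensity(tokens_per_forward, bytes_per_param=2):
    # 2 * P * N FLOPs / (P * bytes_per_param bytes) = 2N / bytes_per_param
    return 2 * tokens_per_forward / bytes_per_param

RIDGE_A100 = 312e12 / 2.0e12  # ~156 FLOPs/byte (peak FP16 / HBM bandwidth)

print(arithmetic_intensity(512))  # prefill chunk: 512 FLOPs/byte -> compute-bound
print(arithmetic_intensity(1))    # decode step: 1 FLOP/byte -> memory-bound

Even a decode batch of 64 sequences sits at ~64 FLOPs/byte, below the A100 ridge point, and this sketch ignores KV cache reads, which make decode even more bandwidth-hungry.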
26.4.2 Split Prefill and Decode
Run on different GPU pools:
class DisaggregatedServer:
def __init__(self):
self.prefill_workers = GPUPool(gpu_type='A100', count=4)
self.decode_workers = GPUPool(gpu_type='L4', count=16) # Cheaper GPUs
self.kv_cache_store = DistributedKVStore()
async def handle_request(self, prompt):
# 1. Prefill on high-compute GPUs
kv_cache = await self.prefill_workers.prefill(prompt)
# 2. Store KV cache in distributed store
cache_id = await self.kv_cache_store.put(kv_cache)
# 3. Decode on memory-optimized GPUs
output = await self.decode_workers.decode(cache_id)
        return output

Benefits:
- Prefill: fewer, more powerful GPUs (compute-bound)
- Decode: more, cheaper GPUs (memory-bound)
- Better overall cost efficiency
Challenges:
- KV cache transfer latency
- Complexity of distributed state
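The transfer cost is worth estimating before committing to this design. A sketch with assumed shapes (a 70B-class model: 80 layers, 8 KV heads via GQA, head_dim 128, FP16 KV cache, over an assumed 25 GB/s effective interconnect):

def kv_transfer_ms(num_tokens, num_layers=80, num_kv_heads=8, head_dim=128,
                   bytes_per_elem=2, link_gb_per_s=25.0):
    """Estimate time to ship a prompt's KV cache between GPU pools."""
    # K and V, per layer, per KV head, per head dim
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return num_tokens * bytes_per_token / (link_gb_per_s * 1e9) * 1e3

print(kv_transfer_ms(1000))  # ~13 ms for a 1000-token prompt (~330 MB)

A ~13ms hop is tolerable next to a 15ms prefill, but only if the transfer overlaps with scheduling so it does not add directly to time-to-first-token.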
26.5 Speculative Decoding Advances
26.5.1 Beyond Simple Speculation
The Inference chapter covered basic speculative decoding. Advanced techniques:
1. Medusa: Multiple Heads
Add prediction heads that output multiple future tokens:
class MedusaModel(nn.Module):
    def __init__(self, base_model, hidden_dim, vocab_size, num_heads=4):
        super().__init__()
        self.base = base_model
        # Each head predicts a different future position
        self.heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for _ in range(num_heads)
        ])
def forward(self, x):
hidden = self.base(x) # [batch, seq, hidden]
# Main prediction (next token)
logits_0 = self.base.lm_head(hidden[:, -1:])
# Speculative predictions (tokens 2, 3, 4, ...)
speculative_logits = [
head(hidden[:, -1:])
for head in self.heads
]
        return logits_0, speculative_logits

Verification: A single forward pass of the base model scores all speculative tokens at once; the longest agreeing prefix is accepted and generation resumes from the first rejection.
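A minimal greedy acceptance check (a sketch, assuming the base model's verification pass returns its argmax token at each drafted position):

import torch

def accept_drafts(verified: torch.Tensor, drafts: torch.Tensor) -> torch.Tensor:
    """Keep the longest draft prefix the base model agrees with."""
    matches = (verified == drafts)
    # Count positions before the first disagreement
    num_accepted = int(matches.cumprod(dim=-1).sum())
    # The verifier's token at the first mismatch is free to keep
    # (empty if every draft token was accepted)
    correction = verified[num_accepted:num_accepted + 1]
    return torch.cat([drafts[:num_accepted], correction])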
2. EAGLE: Draft with Hidden States
Use hidden states, not tokens, for drafting:
class EAGLEDraft:
"""
Draft using feature-level prediction.
"""
def draft(self, hidden_states, num_tokens):
drafts = []
h = hidden_states
for _ in range(num_tokens):
# Predict next hidden state
h_next = self.feature_predictor(h)
# Decode to token
logits = self.lm_head(h_next)
token = logits.argmax(-1)
drafts.append(token)
h = h_next
        return drafts

Advantage: Drafting at the feature level captures more information than token-level drafting, so more draft tokens survive verification.
3. Lookahead Decoding
Speculate using n-gram patterns from the input:
def lookahead_decode(model, prompt_tokens, n=5, max_new_tokens=256):
    """
    Use n-grams from the prompt to speculate future tokens.
    Assumes extract_ngrams() maps each n-gram to the tokens that
    followed it, and verify() checks a speculation in one forward pass.
    """
    # Build n-gram table from the prompt
    ngrams = extract_ngrams(prompt_tokens, n)
    tokens = list(prompt_tokens)
    target_len = len(tokens) + max_new_tokens
    while len(tokens) < target_len:
        # Get current n-gram
        current_ngram = tuple(tokens[-n:])
        if current_ngram in ngrams:
            # Speculate based on a previously seen pattern
            speculation = ngrams[current_ngram]
            # Verify with a single forward pass
            verified = verify(model, tokens, speculation)
            tokens.extend(verified)
        else:
            # Standard one-token decoding
            next_token = model.decode_one(tokens)
            tokens.append(next_token)
    return tokens

26.6 Batch Scheduling Strategies
Batching improves throughput but hurts latency—a counter-intuitive tradeoff that catches many engineers.
Why it’s counter-intuitive: Bigger batches amortize overhead and improve GPU utilization. More throughput should mean faster responses, right?
Why it hurts latency: A request that arrives while a large batch is processing must wait. If you’re the last request added to a batch of 64, you wait for 63 other requests to complete before seeing your first token.
The hidden variable: Throughput optimization assumes latency tolerance. Interactive applications don’t have it.
This is why production systems need priority scheduling: not all requests can tolerate the batching delay that maximizes throughput. The sketch below makes the tradeoff concrete.
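A toy model of the tradeoff (a sketch under assumed numbers: requests arrive at 500 req/s, and a batch runs in a fixed cost plus a small per-request cost):

def batch_tradeoff(batch_size, arrival_rate_per_s=500.0,
                   fixed_ms=10.0, per_req_ms=1.0):
    service_ms = fixed_ms + per_req_ms * batch_size
    throughput_rps = batch_size / (service_ms / 1e3)
    # A request waits, on average, half the batch-fill time before it runs
    avg_wait_ms = (batch_size / 2) / arrival_rate_per_s * 1e3
    return throughput_rps, avg_wait_ms + service_ms

for b in (1, 8, 64):
    tput, lat = batch_tradeoff(b)
    print(f"batch={b:3d}  throughput={tput:6.0f} req/s  latency={lat:6.1f} ms")

Going from batch 1 to 64 lifts throughput roughly 10× (91 to 865 req/s) but also lifts average latency roughly 10× (12 to 138 ms): exactly the tension priority scheduling has to manage.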
26.6.1 Priority-Based Scheduling
Not all requests are equal:
class PriorityScheduler:
def __init__(self):
self.queues = {
'realtime': deque(), # Interactive chat
'standard': deque(), # API requests
'batch': deque() # Background jobs
}
self.weights = {'realtime': 10, 'standard': 5, 'batch': 1}
def select_batch(self, max_tokens):
"""
Weighted fair scheduling across priority classes.
"""
batch = []
tokens_used = 0
# Round-robin with weights
for priority in ['realtime', 'standard', 'batch']:
weight = self.weights[priority]
queue = self.queues[priority]
for _ in range(weight):
if queue and tokens_used < max_tokens:
req = queue.popleft()
batch.append(req)
tokens_used += req.tokens_needed
        return batch

26.6.2 Preemption
Sometimes you need to preempt running requests:
class PreemptibleScheduler:
    def __init__(self, queues, preemption_threshold_ms=100):
        self.queues = queues  # e.g. a PriorityScheduler's priority queues
        self.threshold = preemption_threshold_ms
def maybe_preempt(self, running_requests, new_request):
"""
Preempt low-priority request if high-priority arrives.
"""
if new_request.priority != 'realtime':
return None # Only preempt for realtime
# Find lowest priority request
for req in sorted(running_requests, key=lambda r: r.priority):
if req.priority == 'batch':
# Save state for later resumption
saved_state = self.save_kv_cache(req)
self.queues['batch'].appendleft((req, saved_state))
return req # Preempt this one
        return None

26.7 Multi-Model Serving
26.7.1 Model Multiplexing
Serve multiple models efficiently:
class MultiModelServer:
def __init__(self, gpu_memory_gb=80):
self.memory_budget = gpu_memory_gb * 1e9
        self.loaded_models = OrderedDict()  # model_id -> model, in LRU order
        self.model_sizes = {}
    def load_model(self, model_id):
        """Load model, evicting least-recently-used models if necessary."""
        size = self.get_model_size(model_id)  # assumed size-lookup helper
        # Evict until we have space
        while sum(self.model_sizes.values()) + size > self.memory_budget:
            evicted_id, _ = self.loaded_models.popitem(last=False)
            del self.model_sizes[evicted_id]
        model = load_from_disk(model_id)  # assumed loading helper
        self.loaded_models[model_id] = model
        self.model_sizes[model_id] = size
async def infer(self, model_id, prompt):
if model_id not in self.loaded_models:
self.load_model(model_id)
        return await self.loaded_models[model_id].generate(prompt)

26.7.2 LoRA Serving
Efficient serving of many LoRA adapters:
class LoRAServer:
def __init__(self, base_model):
self.base = base_model
self.adapters = {} # adapter_id -> (A, B) matrices
def load_adapter(self, adapter_id, adapter_path):
"""Load LoRA adapter (small, can cache many)."""
A, B = load_lora_weights(adapter_path)
self.adapters[adapter_id] = (A.cuda(), B.cuda())
def forward(self, x, adapter_id):
"""Forward with specific adapter."""
A, B = self.adapters[adapter_id]
        # Base model forward
        base_out = self.base(x)
        # LoRA update (shown once at the output here; in a real model the
        # low-rank update is applied inside each adapted projection layer)
        lora_out = (x @ A) @ B
        return base_out + lora_out
def batched_forward(self, requests):
"""
Batch requests with different adapters.
Uses batched matrix multiply with adapter indices.
"""
# Group by adapter
by_adapter = defaultdict(list)
for req in requests:
by_adapter[req.adapter_id].append(req)
# Process each adapter group
outputs = []
for adapter_id, reqs in by_adapter.items():
batch = torch.stack([r.input for r in reqs])
out = self.forward(batch, adapter_id)
outputs.extend(zip(reqs, out))
        return outputs

26.8 Monitoring and Observability
26.8.1 Key Metrics
class ServingMetrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.start_time = time.time()  # window start, for throughput reporting
def record_request(self, req):
# Latency metrics
self.metrics['ttft'].append(req.time_to_first_token)
self.metrics['tpot'].append(req.time_per_output_token)
self.metrics['total_latency'].append(req.total_latency)
# Throughput metrics
self.metrics['input_tokens'].append(req.input_length)
self.metrics['output_tokens'].append(req.output_length)
# Efficiency metrics
self.metrics['batch_size'].append(req.batch_size)
self.metrics['cache_hit_rate'].append(req.prefix_cache_hit)
    def report(self):
        total_time = time.time() - self.start_time
        return {
            'ttft_p50': np.percentile(self.metrics['ttft'], 50),
            'ttft_p99': np.percentile(self.metrics['ttft'], 99),
            'throughput_tps': sum(self.metrics['output_tokens']) / total_time,
            'cache_hit_rate': np.mean(self.metrics['cache_hit_rate']),
            'avg_batch_size': np.mean(self.metrics['batch_size']),
        }

26.8.2 Health Checks
class HealthChecker:
async def check_health(self):
checks = {
'gpu_memory': self.check_gpu_memory(),
'model_loaded': self.check_model_loaded(),
'latency': await self.check_latency(),
'queue_depth': self.check_queue_depth(),
}
healthy = all(checks.values())
return {'healthy': healthy, 'checks': checks}
def check_gpu_memory(self):
used = torch.cuda.memory_allocated()
total = torch.cuda.get_device_properties(0).total_memory
return used / total < 0.95 # Alert at 95%
async def check_latency(self):
start = time.time()
_ = await self.model.generate("test", max_tokens=1)
latency = time.time() - start
        return latency < 0.5  # Health check should complete in < 500ms

26.9 Cost Optimization
26.9.1 Token Budgeting
class TokenBudget:
def __init__(self, daily_budget_tokens):
self.daily_budget = daily_budget_tokens
self.used_today = 0
self.reset_time = self.next_midnight()
def can_serve(self, estimated_tokens):
if time.time() > self.reset_time:
self.used_today = 0
self.reset_time = self.next_midnight()
return self.used_today + estimated_tokens <= self.daily_budget
def record_usage(self, tokens):
self.used_today += tokens
def estimate_request_cost(self, input_len, max_output):
# Pricing model
input_cost = input_len * 0.00001 # $0.01 per 1K input
output_cost = max_output * 0.00003 # $0.03 per 1K output
        return input_cost + output_cost

26.9.2 Model Selection
Route to cheaper models when possible:
class ModelRouter:
def __init__(self):
self.models = {
'small': {'model': 'llama-7b', 'cost': 0.001},
'medium': {'model': 'llama-13b', 'cost': 0.003},
'large': {'model': 'llama-70b', 'cost': 0.01},
}
def select_model(self, request):
"""Select cheapest model that can handle request."""
complexity = self.estimate_complexity(request)
if complexity < 0.3:
return self.models['small']
elif complexity < 0.7:
return self.models['medium']
else:
return self.models['large']
def estimate_complexity(self, request):
"""Estimate request complexity (0-1)."""
# Heuristics: length, keywords, domain
score = 0
if 'code' in request.lower() or 'math' in request.lower():
score += 0.4
if len(request) > 500:
score += 0.3
        return min(score, 1.0)

26.10 vLLM Internals
26.10.1 Architecture Overview
vLLM (Kwon et al., 2023) is the most widely deployed open-source LLM serving engine. Understanding its internals helps optimize production deployments.
vLLM Architecture:
┌─────────────────────────────────────────────────────────┐
│ API Server (FastAPI) │
├─────────────────────────────────────────────────────────┤
│ LLM Engine │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────┐ │
│ │ Scheduler │ │ KV Cache │ │ Model Executor │ │
│ │ │ │ Manager │ │ │ │
│ └─────────────┘ └─────────────┘ └─────────────────┘ │
├─────────────────────────────────────────────────────────┤
│ Worker (per GPU) │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────┐ │
│ │ Model │ │ PagedAttention│ │ Sampler │ │
│ │ Runner │ │ Kernels │ │ │ │
│ └─────────────┘ └─────────────┘ └─────────────────┘ │
└─────────────────────────────────────────────────────────┘
26.10.2 The Scheduler Deep Dive
vLLM’s scheduler makes critical decisions every iteration:
class vLLMScheduler:
"""
Simplified vLLM scheduler logic.
"""
def __init__(self, config):
self.waiting = deque() # Requests waiting to start
self.running = deque() # Requests currently generating
self.swapped = deque() # Requests swapped to CPU
self.block_manager = BlockManager(config)
def schedule(self):
"""
Core scheduling loop, called every iteration.
"""
# Step 1: Try to resume swapped requests
resumed = self._try_resume_swapped()
        # Step 2: Preempt if out of memory (returns the swapped-out groups)
        preempted = self._preempt_requests() if self._is_oom() else []
# Step 3: Admit new requests
admitted = self._admit_waiting()
# Step 4: Build batch for this iteration
batch = SchedulerOutputs(
scheduled_running=list(self.running),
scheduled_prefill=admitted,
blocks_to_swap_in=resumed,
blocks_to_swap_out=preempted,
)
return batch
def _admit_waiting(self):
"""Admit waiting requests if we have memory."""
admitted = []
for seq_group in list(self.waiting):
# Calculate memory needed
num_required_blocks = self._get_required_blocks(seq_group)
if self.block_manager.can_allocate(num_required_blocks):
self.block_manager.allocate(seq_group, num_required_blocks)
self.waiting.remove(seq_group)
self.running.append(seq_group)
admitted.append(seq_group)
else:
break # Can't fit more
return admitted
    def _preempt_requests(self):
        """Preempt lowest-priority requests when out of memory."""
        preempted = []
        # Sort by priority, preempt lowest first
        sorted_running = sorted(self.running, key=lambda x: x.priority)
        for seq_group in sorted_running:
            if not self._is_oom():
                break
            # Swap KV cache to CPU
            self.block_manager.swap_out(seq_group)
            self.running.remove(seq_group)
            self.swapped.append(seq_group)
            preempted.append(seq_group)
        return preempted

26.10.3 PagedAttention Memory Management
vLLM’s block allocator manages GPU memory like an OS manages RAM:
class BlockManager:
"""
Manages KV cache memory in fixed-size blocks.
"""
def __init__(self, block_size=16, num_gpu_blocks=1000, num_cpu_blocks=500):
self.block_size = block_size
# GPU block pool
self.gpu_free_blocks = list(range(num_gpu_blocks))
self.gpu_allocated = {} # seq_id -> [block_ids]
# CPU block pool (for swapping)
self.cpu_free_blocks = list(range(num_cpu_blocks))
self.cpu_allocated = {}
def allocate(self, seq_group, num_blocks):
"""Allocate blocks for a sequence group."""
if len(self.gpu_free_blocks) < num_blocks:
raise OOMError("Not enough GPU blocks")
blocks = [self.gpu_free_blocks.pop() for _ in range(num_blocks)]
self.gpu_allocated[seq_group.id] = blocks
return blocks
def append_slots(self, seq_group, num_tokens):
"""
Allocate additional slots as sequence grows.
Called when generating new tokens.
"""
blocks = self.gpu_allocated[seq_group.id]
current_slots = len(blocks) * self.block_size
if seq_group.num_tokens + num_tokens > current_slots:
# Need new block
if not self.gpu_free_blocks:
return None # Trigger preemption
new_block = self.gpu_free_blocks.pop()
blocks.append(new_block)
return blocks
def swap_out(self, seq_group):
"""Swap KV cache from GPU to CPU."""
gpu_blocks = self.gpu_allocated.pop(seq_group.id)
# Get CPU blocks
cpu_blocks = [self.cpu_free_blocks.pop() for _ in gpu_blocks]
# Actual copy happens in worker
self.cpu_allocated[seq_group.id] = {
'cpu_blocks': cpu_blocks,
'gpu_blocks': gpu_blocks, # Remember for swap_in
}
# Return GPU blocks to pool
self.gpu_free_blocks.extend(gpu_blocks)
return (gpu_blocks, cpu_blocks) # Copy mapping
def swap_in(self, seq_group):
"""Swap KV cache from CPU back to GPU."""
info = self.cpu_allocated.pop(seq_group.id)
cpu_blocks = info['cpu_blocks']
# Get fresh GPU blocks
gpu_blocks = [self.gpu_free_blocks.pop() for _ in cpu_blocks]
self.gpu_allocated[seq_group.id] = gpu_blocks
# Return CPU blocks to pool
self.cpu_free_blocks.extend(cpu_blocks)
        return (cpu_blocks, gpu_blocks)  # Copy mapping

26.10.4 Key vLLM Configuration Options
from vllm import LLM, SamplingParams
# Memory optimization
llm = LLM(
model="meta-llama/Llama-2-70b-chat-hf",
# Memory management
    gpu_memory_utilization=0.90,  # vLLM may use 90% of GPU memory (weights + KV cache)
swap_space=16, # GB of CPU memory for swapping
# Scheduling
max_num_seqs=256, # Max concurrent sequences
max_num_batched_tokens=4096, # Max tokens per iteration
# Parallelism
tensor_parallel_size=4, # Distribute across 4 GPUs
pipeline_parallel_size=1,
# Quantization
quantization="awq", # or "gptq", "squeezellm"
# KV cache optimization
kv_cache_dtype="fp8_e5m2", # FP8 KV cache (Hopper+)
block_size=16,
# Prefix caching
enable_prefix_caching=True,
)

26.11 SGLang: Structured Generation
26.11.1 Beyond Text Generation
SGLang extends LLM serving with structured outputs and control flow:
import sglang as sgl
# Define structured program
@sgl.function
def extract_info(s, text):
s += "Extract information from: " + text + "\n\n"
s += "Name: "
s += sgl.gen("name", stop="\n")
s += "Age: "
s += sgl.gen("age", stop="\n", regex=r"\d+")
s += "Email: "
s += sgl.gen("email", stop="\n", regex=r"[\w\.-]+@[\w\.-]+\.\w+")
return {"name": s["name"], "age": s["age"], "email": s["email"]}
# Execute with RadixAttention
result = extract_info.run(text="John Smith, 35, john@example.com")

26.11.2 Constrained Decoding
SGLang enforces output structure during generation:
import json

@sgl.function
def json_generation(s, schema):
    """Generate JSON matching a schema."""
s += "Generate a JSON object matching this schema:\n"
s += str(schema) + "\n\n"
# Constrained generation - only valid JSON tokens allowed
s += sgl.gen("json",
json_schema=schema,
max_tokens=500
)
return json.loads(s["json"])
# Schema enforcement
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"scores": {"type": "array", "items": {"type": "number"}}
},
"required": ["name", "scores"]
}
result = json_generation.run(schema=schema)
# Guaranteed to be valid JSON matching the schema!

26.11.3 Fork and Join
SGLang supports parallel generation paths:
@sgl.function
def multi_perspective(s, question):
s += f"Question: {question}\n\n"
    # Fork into multiple perspectives; s.fork(n) runs n branches in parallel
    perspectives = ["optimist", "pessimist", "realist"]
    forks = s.fork(len(perspectives))
    for fork, perspective in zip(forks, perspectives):
        fork += f"Answer as a {perspective}:\n"
        fork += sgl.gen("answer", max_tokens=100)
    fork_results = [fork["answer"] for fork in forks]
# Join: Synthesize perspectives
s += "\nSynthesis of all perspectives:\n"
for p, r in zip(perspectives, fork_results):
s += f"- {p}: {r}\n"
s += "\nFinal balanced answer:\n"
s += sgl.gen("final", max_tokens=200)
return s["final"]26.11.4 RadixAttention in Detail
SGLang’s RadixAttention implementation:
class SGLangRadixCache:
"""
Production-grade RadixAttention implementation.
"""
def __init__(self, max_total_tokens=100_000):
self.root = RadixNode()
self.max_tokens = max_total_tokens
self.current_tokens = 0
self.lock = threading.Lock()
def match_prefix(self, input_ids: List[int]) -> Tuple[int, List[KVCache]]:
"""
Find longest matching prefix in O(n) time.
Returns:
matched_length: Number of tokens matched
kv_caches: KV cache for matched prefix
"""
with self.lock:
node = self.root
matched = 0
kv_caches = []
for token in input_ids:
if token in node.children:
child = node.children[token]
node = child
matched += 1
if child.kv_cache is not None:
kv_caches.append(child.kv_cache)
else:
break
# Update access time for LRU
self._update_access(node)
return matched, kv_caches
def insert(self, input_ids: List[int], kv_cache: KVCache):
"""Insert new prefix into cache."""
        with self.lock:
            # Evict LRU prefixes until the new entry fits the token budget
            while self.current_tokens + len(input_ids) > self.max_tokens:
                self._evict_lru()
# Insert
node = self.root
for i, token in enumerate(input_ids):
if token not in node.children:
node.children[token] = RadixNode()
node = node.children[token]
# Store KV cache at final position
if i == len(input_ids) - 1:
node.kv_cache = kv_cache
self.current_tokens += len(input_ids)
def _evict_lru(self):
"""Evict least recently used prefix."""
# Find LRU leaf node
lru_node, lru_parent, lru_token = self._find_lru_leaf()
if lru_node is not None:
tokens_freed = lru_node.depth
self.current_tokens -= tokens_freed
            del lru_parent.children[lru_token]

26.11.5 Performance Benefits
Benchmark: 1000 requests with shared system prompt
Without RadixAttention (vLLM PagedAttention only):
- Prefill per request: 1000 tokens
- Total prefill: 1,000,000 tokens
- Time: 45 seconds
With RadixAttention (SGLang):
- First request prefill: 1000 tokens
- Subsequent prefills: 50 tokens (unique part only)
- Total prefill: 1000 + 999*50 = 50,950 tokens
- Time: 4.2 seconds
Speedup: 10.7x
26.12 Key Takeaways
Prefix caching: Amortize common prefixes across requests for 10x+ speedup.
Chunked prefill: Interleave prefill and decode to avoid blocking.
Disaggregation: Separate prefill (compute) from decode (memory) for cost efficiency.
Advanced speculation: Medusa, EAGLE, lookahead improve on basic speculation.
Priority scheduling: Not all requests are equal; schedule accordingly.
Multi-model: Efficient LoRA serving and model multiplexing.
vLLM internals: Understand PagedAttention block allocation and scheduler for tuning.
SGLang: Structured generation with RadixAttention for maximum prefix reuse.
Observability: Track TTFT, TPOT, throughput, cache hits for optimization.
The accompanying notebook (in progress) will let you:
- Implement prefix caching with a radix tree
- Simulate chunked prefill scheduling
- Compare speculative decoding strategies
- Build a simple multi-model router
Until it ships, implement the patterns locally and evaluate latency/throughput on your own serving stack.
26.13 Further Reading
- Zheng et al. (2023). “SGLang: Efficient Execution of Structured Language Model Programs”
- Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention”
- vLLM Documentation
- SGLang Documentation
- Agrawal et al. (2023). “Sarathi: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills”
- Cai et al. (2024). “Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads”
- Li et al. (2024). “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty”
- Patel et al. (2024). “Splitwise: Efficient Generative LLM Inference Using Phase Splitting”