25  Advanced LLM Serving

Production Techniques for High-Throughput Inference


Serving LLMs at scale is a systems engineering challenge. Latency SLAs, cost efficiency, and reliability all matter.

This chapter covers the techniques that power production inference at companies serving billions of requests.

25.1 Beyond Basic Inference

Chapter 14 covered fundamentals: KV caching, continuous batching, speculative decoding. This chapter goes deeper into production-grade techniques.

Production LLM serving requirements:

Latency:      P99 < 200ms time-to-first-token
Throughput:   10,000+ requests/second
Cost:         < $0.001 per 1K tokens
Reliability:  99.9% uptime
Flexibility:  Multiple models, variable load
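
A quick back-of-envelope check on the cost target above; the GPU price and per-GPU throughput below are illustrative assumptions, not measurements:

# Rough cost per 1K generated tokens for a single GPU.
# Assumptions (illustrative): $2/hour for the GPU, 2,500 sustained output tokens/s.
gpu_cost_per_hour = 2.00
tokens_per_second = 2500

tokens_per_hour = tokens_per_second * 3600
cost_per_1k_tokens = gpu_cost_per_hour / (tokens_per_hour / 1000)

print(f"${cost_per_1k_tokens:.5f} per 1K tokens")  # ~$0.00022 under these assumptions

Hitting the < $0.001 per 1K token target therefore hinges on sustaining high per-GPU throughput, which is what the rest of this chapter is about.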

25.2 Prefix Caching and RadixAttention

25.2.1 The Observation

Many requests share common prefixes:

System prompt (shared):
  "You are a helpful AI assistant. Be concise and accurate..."

User requests:
  Request 1: [system prompt] + "What is the capital of France?"
  Request 2: [system prompt] + "Explain quantum computing."
  Request 3: [system prompt] + "Write a haiku about coding."

Without prefix caching, the KV cache for the shared system prompt is recomputed from scratch for every request.

25.2.2 RadixAttention (SGLang)

Store KV caches in a radix tree, share across requests:

class RadixNode:
    """A prefix-tree node: children keyed by token id, plus this position's KV entry."""
    def __init__(self):
        self.children = {}    # token id -> RadixNode
        self.kv_cache = None  # KV entry for the token ending at this node

class RadixCache:
    """
    Radix tree for prefix caching.

    Key insight: Common prefixes share KV cache.
    (For simplicity this stores one node per token; a true radix tree
    compresses runs of tokens into single edges.)
    """
    def __init__(self):
        self.root = RadixNode()

    def insert(self, tokens, kv_cache):
        """Insert KV cache for token sequence."""
        node = self.root
        for i, token in enumerate(tokens):
            if token not in node.children:
                node.children[token] = RadixNode()
            node = node.children[token]
            node.kv_cache = kv_cache[i]  # Store per-position KV

    def lookup(self, tokens):
        """Find longest matching prefix."""
        node = self.root
        matched = 0
        kv_cache = []

        for token in tokens:
            if token in node.children:
                node = node.children[token]
                kv_cache.append(node.kv_cache)
                matched += 1
            else:
                break

        return matched, kv_cache

# Usage:
cache = RadixCache()

# First request computes full KV cache
kv = model.prefill("You are helpful...")
cache.insert(tokenize("You are helpful..."), kv)

# Second request reuses cached prefix
prefix_len, cached_kv = cache.lookup(tokenize("You are helpful..."))
# Only compute KV for new tokens!

25.2.3 Memory Management

Prefix caching needs smart eviction:

from collections import OrderedDict

class LRUPrefixCache:
    def __init__(self, max_memory_gb):
        self.max_memory = max_memory_gb * 1e9
        self.cache = OrderedDict()  # key -> KV cache, kept in LRU order
        self.sizes = {}             # key -> size in bytes
        self.current_memory = 0

    def get_or_compute(self, prefix_tokens, model):
        key = hash(tuple(prefix_tokens))

        if key in self.cache:
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            return self.cache[key]

        # Compute new KV cache
        kv = model.prefill(prefix_tokens)
        kv_size = self.compute_size(kv)

        # Evict if necessary
        while self.cache and self.current_memory + kv_size > self.max_memory:
            self.evict_oldest()

        self.cache[key] = kv
        self.sizes[key] = kv_size
        self.current_memory += kv_size
        return kv

    def evict_oldest(self):
        """Drop the least recently used entry."""
        key, _ = self.cache.popitem(last=False)
        self.current_memory -= self.sizes.pop(key)

    def compute_size(self, kv):
        """Size of a KV cache in bytes (sum over its tensors)."""
        return sum(t.element_size() * t.nelement() for t in kv)

Performance impact:

System prompt: 500 tokens
User query: 50 tokens

Without prefix caching: 550 tokens prefill per request
With prefix caching: 50 tokens prefill per request (first request caches)

Speedup: 11x for time-to-first-token on repeated prefixes

25.3 Chunked Prefill

25.3.1 The Problem

Long prefill blocks decode:

Timeline without chunked prefill:

Request A (long prompt): [==========PREFILL==========][decode][decode]...
Request B (arrives during A's prefill):
                         waiting... waiting... [PREFILL][decode]...
                         ↑ High latency for B!

25.3.2 The Solution

Split prefill into chunks, interleave with decode:

class ChunkedPrefillScheduler:
    def __init__(self, chunk_size=512, max_decode=64, max_tokens=2048):
        self.chunk_size = chunk_size
        self.max_decode = max_decode   # max decode requests per iteration
        self.max_tokens = max_tokens   # token budget per iteration
        self.prefill_queue = []
        self.decode_queue = []

    def schedule_iteration(self):
        """
        Each iteration: some prefill tokens + some decode tokens.
        """
        batch = []
        tokens_used = 0

        # Add decode requests first (high priority for latency, 1 token each)
        for req in self.decode_queue[:self.max_decode]:
            batch.append(('decode', req, 1))
            tokens_used += 1

        # Fill remaining token budget with prefill chunks
        remaining_capacity = self.max_tokens - tokens_used
        for req in self.prefill_queue:
            if remaining_capacity <= 0:
                break

            tokens_remaining = req.total_tokens - req.processed_tokens
            chunk = min(tokens_remaining, self.chunk_size, remaining_capacity)

            batch.append(('prefill', req, chunk))
            remaining_capacity -= chunk

        return batch

Timeline with chunked prefill:

GPU:       [A prefill chunk][B prefill + decode][A prefill chunk + B decode]...
Request B: short wait... [PREFILL][decode][decode]...
                         ↑ Much lower latency!
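
A minimal usage sketch of the scheduler above; the PrefillRequest dataclass and the decode-queue entries are illustrative stand-ins for a real engine's request objects:

from dataclasses import dataclass

@dataclass
class PrefillRequest:
    total_tokens: int
    processed_tokens: int = 0

scheduler = ChunkedPrefillScheduler(chunk_size=512, max_decode=64, max_tokens=2048)
scheduler.prefill_queue.append(PrefillRequest(total_tokens=8000))  # long prompt (request A)
scheduler.decode_queue.extend(["req_b", "req_c"])                  # already generating

batch = scheduler.schedule_iteration()
# -> [('decode', 'req_b', 1), ('decode', 'req_c', 1), ('prefill', PrefillRequest(...), 512)]
# Decodes keep flowing every iteration while A's 8000-token prefill advances 512 tokens at a time.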

25.4 Disaggregated Serving

25.4.1 The Observation

Prefill and decode have different characteristics:

Prefill:
  - Compute-bound (matrix multiplies)
  - High parallelism (many tokens)
  - Batch-friendly
  - High memory bandwidth for KV cache write

Decode:
  - Memory-bound (reading KV cache)
  - Sequential (one token at a time)
  - Latency-sensitive
  - Low compute intensity
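
A rough arithmetic-intensity estimate makes the contrast concrete. The sketch below assumes a 7B-parameter model in fp16 with weight reads dominating memory traffic; the numbers are order-of-magnitude only:

# Arithmetic intensity (FLOPs per byte of memory traffic) for one forward step.
# Assumptions: 7B params, fp16 weights (2 bytes each), ~2 FLOPs per param per token,
# and weights are read once per step regardless of how many tokens are processed.
params = 7e9
bytes_per_param = 2

def arithmetic_intensity(tokens_per_step):
    flops = 2 * params * tokens_per_step       # matmul work scales with tokens
    bytes_moved = params * bytes_per_param     # weight traffic is roughly constant
    return flops / bytes_moved

print(arithmetic_intensity(2048))  # prefill of a 2048-token prompt: ~2048 FLOPs/byte
print(arithmetic_intensity(1))     # single-token decode:            ~1 FLOP/byte

An A100's compute-to-bandwidth ratio is on the order of 100-200 FLOPs per byte, so a long prefill sits comfortably in the compute-bound regime while decode is starved for memory bandwidth.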

25.4.2 Split Prefill and Decode

Run on different GPU pools:

class DisaggregatedServer:
    def __init__(self):
        self.prefill_workers = GPUPool(gpu_type='A100', count=4)
        self.decode_workers = GPUPool(gpu_type='L4', count=16)  # Cheaper GPUs
        self.kv_cache_store = DistributedKVStore()

    async def handle_request(self, prompt):
        # 1. Prefill on high-compute GPUs
        kv_cache = await self.prefill_workers.prefill(prompt)

        # 2. Store KV cache in distributed store
        cache_id = await self.kv_cache_store.put(kv_cache)

        # 3. Decode on memory-optimized GPUs
        output = await self.decode_workers.decode(cache_id)

        return output

Benefits:
  - Prefill: Use fewer, more powerful GPUs (compute-bound)
  - Decode: Use more, cheaper GPUs (memory-bound)
  - Better overall cost efficiency

Challenges:
  - KV cache transfer latency
  - Complexity of distributed state
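
The dominant challenge is the KV-cache handoff itself. A rough size and transfer-time estimate, assuming Llama-2-7B-like dimensions (32 layers, 32 KV heads, head dim 128, fp16) and round-number interconnect bandwidths:

# KV cache size for one request's prompt, and time to ship it between GPU pools.
layers, kv_heads, head_dim, dtype_bytes = 32, 32, 128, 2
prompt_tokens = 2000

# Factor of 2 for keys and values
kv_bytes = 2 * layers * kv_heads * head_dim * dtype_bytes * prompt_tokens
print(f"KV cache: {kv_bytes / 1e9:.2f} GB")  # ~1.0 GB for a 2000-token prompt

for name, bandwidth_bytes_per_s in [("NVLink (~300 GB/s)", 300e9),
                                    ("100 GbE (~12.5 GB/s)", 12.5e9)]:
    transfer_ms = kv_bytes / bandwidth_bytes_per_s * 1000
    print(f"{name}: {transfer_ms:.1f} ms")   # ~3.5 ms vs ~84 ms

A few milliseconds over NVLink is easy to hide behind decode; tens of milliseconds over Ethernet eats directly into the time-to-first-token budget, which is why disaggregated designs invest heavily in fast KV transfer paths.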

25.5 Speculative Decoding Advances

25.5.1 Beyond Simple Speculation

Chapter 14 covered basic speculative decoding. Advanced techniques:

1. Medusa: Multiple Heads

Add prediction heads that output multiple future tokens:

import torch.nn as nn

class MedusaModel(nn.Module):
    def __init__(self, base_model, hidden_dim, vocab_size, num_heads=4):
        super().__init__()
        self.base = base_model
        # Each head predicts a different future position
        self.heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        hidden = self.base(x)  # [batch, seq, hidden]

        # Main prediction (next token)
        logits_0 = self.base.lm_head(hidden[:, -1:])

        # Speculative predictions (tokens 2, 3, 4, ...)
        speculative_logits = [
            head(hidden[:, -1:])
            for head in self.heads
        ]

        return logits_0, speculative_logits

Verification: Single forward pass verifies all speculative tokens.
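
A minimal sketch of that verification step under greedy decoding. It is simplified relative to Medusa's tree-structured attention: target_logits is assumed to be the base model's output at each drafted position from one forward pass over the draft:

import torch

def verify_greedy(target_logits, draft_tokens):
    """
    Accept drafted tokens left-to-right while they match the base model's
    greedy choice; stop at the first mismatch.

    target_logits: [num_draft, vocab] logits from a single verification pass
    draft_tokens:  [num_draft] tokens proposed by the Medusa heads
    """
    accepted = []
    target_choice = target_logits.argmax(dim=-1)          # [num_draft]
    for target_tok, draft_tok in zip(target_choice.tolist(), draft_tokens.tolist()):
        if target_tok != draft_tok:
            break
        accepted.append(draft_tok)
    return accepted  # every accepted token cost only the one verification pass

In the full algorithm the first rejected position still yields the base model's own token for free, and sampling-based acceptance replaces the strict greedy match.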

2. EAGLE: Draft with Hidden States

Use hidden states, not tokens, for drafting:

class EAGLEDraft:
    """
    Draft using feature-level prediction.
    """
    def __init__(self, feature_predictor, lm_head):
        self.feature_predictor = feature_predictor  # small head that predicts the next hidden state
        self.lm_head = lm_head                      # shared with the target model

    def draft(self, hidden_states, num_tokens):
        drafts = []

        h = hidden_states
        for _ in range(num_tokens):
            # Predict next hidden state
            h_next = self.feature_predictor(h)

            # Decode to token
            logits = self.lm_head(h_next)
            token = logits.argmax(-1)

            drafts.append(token)
            h = h_next

        return drafts

Advantage: Captures more information than token-level drafting.

3. Lookahead Decoding

Speculate using n-gram patterns from the input:

def lookahead_decode(model, prompt, n=5, max_new_tokens=256, eos_token=None):
    """
    Use n-gram continuations seen in the prompt to speculate future tokens.
    """
    # Map each n-gram in the prompt to the tokens that followed it
    ngrams = extract_ngrams(prompt, n)   # dict: n-gram tuple -> continuation tokens

    tokens = list(prompt)
    start_len = len(tokens)
    while len(tokens) - start_len < max_new_tokens and tokens[-1] != eos_token:
        # Get current n-gram
        current_ngram = tuple(tokens[-n:])

        if current_ngram in ngrams:
            # Speculate based on a previously seen continuation
            speculation = ngrams[current_ngram]
            # Verify with the target model; verify() always returns at least
            # one token, so the loop makes progress even on a full rejection
            verified = verify(model, tokens, speculation)
            tokens.extend(verified)
        else:
            # Standard decoding
            next_token = model.decode_one(tokens)
            tokens.append(next_token)

    return tokens
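
The two helpers above are left abstract. Here is a minimal sketch of what they might look like under greedy decoding; note this is a simplified prompt-lookup-style variant, whereas the published lookahead decoding algorithm also refreshes its n-gram pool from the model's own Jacobi-style parallel guesses:

def extract_ngrams(tokens, n, max_continuation=8):
    """Map each n-gram in `tokens` to the run of tokens that followed it."""
    table = {}
    for i in range(len(tokens) - n):
        key = tuple(tokens[i:i + n])
        table.setdefault(key, tokens[i + n:i + n + max_continuation])
    return table

def verify(model, tokens, speculation):
    """
    Greedy verification sketch: accept drafted tokens while they match the
    model's own next-token choice; on the first mismatch keep the model's
    token instead, so the caller always makes progress.
    (A real implementation scores all drafted positions in one forward pass.)
    """
    accepted = []
    context = list(tokens)
    for draft in speculation:
        next_token = model.decode_one(context)
        accepted.append(next_token)
        if next_token != draft:
            break
        context.append(next_token)
    return accepted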

25.6 Batch Scheduling Strategies

25.6.1 Priority-Based Scheduling

Not all requests are equal:

from collections import deque

class PriorityScheduler:
    def __init__(self):
        self.queues = {
            'realtime': deque(),   # Interactive chat
            'standard': deque(),   # API requests
            'batch': deque()       # Background jobs
        }
        self.weights = {'realtime': 10, 'standard': 5, 'batch': 1}

    def select_batch(self, max_tokens):
        """
        Weighted fair scheduling across priority classes.
        """
        batch = []
        tokens_used = 0

        # Round-robin with weights
        for priority in ['realtime', 'standard', 'batch']:
            weight = self.weights[priority]
            queue = self.queues[priority]

            for _ in range(weight):
                if queue and tokens_used < max_tokens:
                    req = queue.popleft()
                    batch.append(req)
                    tokens_used += req.tokens_needed

        return batch

25.6.2 Preemption

Sometimes you need to preempt running requests:

class PreemptibleScheduler:
    def __init__(self, batch_queue, preemption_threshold_ms=100):
        self.threshold = preemption_threshold_ms
        self.batch_queue = batch_queue   # deque of (request, saved_state) pairs to resume later

    def maybe_preempt(self, running_requests, new_request):
        """
        Preempt a low-priority request if a high-priority one arrives.
        """
        if new_request.priority != 'realtime':
            return None  # Only preempt for realtime traffic

        # Look for a batch-priority victim among the running requests
        for req in running_requests:
            if req.priority == 'batch':
                # Save KV cache state so the request can resume later
                saved_state = self.save_kv_cache(req)   # swap-out helper, elided here
                self.batch_queue.appendleft((req, saved_state))
                return req  # Preempt this one

        return None

25.7 Multi-Model Serving

25.7.1 Model Multiplexing

Serve multiple models efficiently:

from collections import OrderedDict

class MultiModelServer:
    def __init__(self, gpu_memory_gb=80):
        self.memory_budget = gpu_memory_gb * 1e9
        self.loaded_models = OrderedDict()  # model_id -> model, kept in LRU order
        self.model_sizes = {}

    def current_memory(self):
        return sum(self.model_sizes[m] for m in self.loaded_models)

    def evict_lru_model(self):
        """Unload the least recently used model."""
        model_id, _ = self.loaded_models.popitem(last=False)
        self.model_sizes.pop(model_id)

    def load_model(self, model_id):
        """Load a model, evicting others if necessary."""
        size = self.get_model_size(model_id)   # GPU footprint, from model metadata

        # Evict until we have space
        while self.loaded_models and self.current_memory() + size > self.memory_budget:
            self.evict_lru_model()

        model = load_from_disk(model_id)       # framework-specific loader
        self.loaded_models[model_id] = model
        self.model_sizes[model_id] = size

    async def infer(self, model_id, prompt):
        if model_id not in self.loaded_models:
            self.load_model(model_id)
        else:
            self.loaded_models.move_to_end(model_id)  # mark as recently used

        return await self.loaded_models[model_id].generate(prompt)

25.7.2 LoRA Serving

Efficient serving of many LoRA adapters:

import torch
from collections import defaultdict

class LoRAServer:
    def __init__(self, base_model):
        # For clarity this example treats the base as a single LoRA-targeted
        # linear layer; a real server patches every targeted projection.
        self.base = base_model
        self.adapters = {}  # adapter_id -> (A, B) low-rank matrices

    def load_adapter(self, adapter_id, adapter_path):
        """Load LoRA adapter weights (small, so many can stay resident)."""
        A, B = load_lora_weights(adapter_path)   # loader elided
        self.adapters[adapter_id] = (A.cuda(), B.cuda())

    def forward(self, x, adapter_id):
        """Forward with a specific adapter."""
        A, B = self.adapters[adapter_id]

        # Base forward
        base_out = self.base(x)

        # LoRA low-rank update: x @ A @ B has the same shape as base_out
        lora_out = (x @ A) @ B

        return base_out + lora_out

    def batched_forward(self, requests):
        """
        Batch requests with different adapters.

        Here each adapter group is processed separately; production kernels
        fuse the groups into one batched matrix multiply indexed by adapter.
        """
        # Group by adapter
        by_adapter = defaultdict(list)
        for req in requests:
            by_adapter[req.adapter_id].append(req)

        # Process each adapter group
        outputs = []
        for adapter_id, reqs in by_adapter.items():
            batch = torch.stack([r.input for r in reqs])
            out = self.forward(batch, adapter_id)
            outputs.extend(zip(reqs, out))

        return outputs

25.8 Monitoring and Observability

25.8.1 Key Metrics

import time
import numpy as np
from collections import defaultdict

class ServingMetrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.start_time = time.time()

    def record_request(self, req):
        # Latency metrics
        self.metrics['ttft'].append(req.time_to_first_token)
        self.metrics['tpot'].append(req.time_per_output_token)
        self.metrics['total_latency'].append(req.total_latency)

        # Throughput metrics
        self.metrics['input_tokens'].append(req.input_length)
        self.metrics['output_tokens'].append(req.output_length)

        # Efficiency metrics
        self.metrics['batch_size'].append(req.batch_size)
        self.metrics['cache_hit_rate'].append(req.prefix_cache_hit)

    def report(self):
        total_time = time.time() - self.start_time
        return {
            'ttft_p50': np.percentile(self.metrics['ttft'], 50),
            'ttft_p99': np.percentile(self.metrics['ttft'], 99),
            'throughput_tps': sum(self.metrics['output_tokens']) / total_time,
            'cache_hit_rate': np.mean(self.metrics['cache_hit_rate']),
            'avg_batch_size': np.mean(self.metrics['batch_size']),
        }

25.8.2 Health Checks

import time
import torch

class HealthChecker:
    def __init__(self, model):
        self.model = model

    async def check_health(self):
        checks = {
            'gpu_memory': self.check_gpu_memory(),
            'model_loaded': self.check_model_loaded(),
            'latency': await self.check_latency(),
            'queue_depth': self.check_queue_depth(),
        }

        healthy = all(checks.values())
        return {'healthy': healthy, 'checks': checks}

    def check_gpu_memory(self):
        used = torch.cuda.memory_allocated()
        total = torch.cuda.get_device_properties(0).total_memory
        return used / total < 0.95  # Alert at 95%

    async def check_latency(self):
        start = time.time()
        _ = await self.model.generate("test", max_tokens=1)
        latency = time.time() - start
        return latency < 0.5  # Health check should be < 500ms

25.9 Cost Optimization

25.9.1 Token Budgeting

import time
from datetime import datetime, timedelta

class TokenBudget:
    def __init__(self, daily_budget_tokens):
        self.daily_budget = daily_budget_tokens
        self.used_today = 0
        self.reset_time = self.next_midnight()

    def next_midnight(self):
        """Unix timestamp of the next midnight (local time)."""
        tomorrow = datetime.now().date() + timedelta(days=1)
        return datetime.combine(tomorrow, datetime.min.time()).timestamp()

    def can_serve(self, estimated_tokens):
        if time.time() > self.reset_time:
            self.used_today = 0
            self.reset_time = self.next_midnight()

        return self.used_today + estimated_tokens <= self.daily_budget

    def record_usage(self, tokens):
        self.used_today += tokens

    def estimate_request_cost(self, input_len, max_output):
        # Pricing model
        input_cost = input_len * 0.00001  # $0.01 per 1K input
        output_cost = max_output * 0.00003  # $0.03 per 1K output
        return input_cost + output_cost

25.9.2 Model Selection

Route to cheaper models when possible:

class ModelRouter:
    def __init__(self):
        self.models = {
            'small': {'model': 'llama-7b', 'cost': 0.001},
            'medium': {'model': 'llama-13b', 'cost': 0.003},
            'large': {'model': 'llama-70b', 'cost': 0.01},
        }

    def select_model(self, request):
        """Select cheapest model that can handle request."""
        complexity = self.estimate_complexity(request)

        if complexity < 0.3:
            return self.models['small']
        elif complexity < 0.7:
            return self.models['medium']
        else:
            return self.models['large']

    def estimate_complexity(self, request):
        """Estimate request complexity (0-1)."""
        # Heuristics: length, keywords, domain
        score = 0
        if 'code' in request.lower() or 'math' in request.lower():
            score += 0.4
        if len(request) > 500:
            score += 0.3
        return min(score, 1.0)
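
A quick usage sketch (the request strings are arbitrary):

router = ModelRouter()
print(router.select_model("What's the capital of France?"))
# -> {'model': 'llama-7b', 'cost': 0.001}
print(router.select_model("Write code to invert a binary tree"))
# -> {'model': 'llama-13b', 'cost': 0.003}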

25.10 vLLM Internals

25.10.1 Architecture Overview

vLLM is one of the most widely deployed open-source LLM serving engines. Understanding its internals helps you tune production deployments.

vLLM Architecture:

┌─────────────────────────────────────────────────────────┐
│                   API Server (FastAPI)                  │
├─────────────────────────────────────────────────────────┤
│                   LLM Engine                            │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────┐  │
│  │  Scheduler  │  │ KV Cache    │  │ Model Executor  │  │
│  │             │  │ Manager     │  │                 │  │
│  └─────────────┘  └─────────────┘  └─────────────────┘  │
├─────────────────────────────────────────────────────────┤
│                   Worker (per GPU)                      │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────┐  │
│  │ Model       │  │ PagedAttn   │  │ Sampler         │  │
│  │ Runner      │  │ Kernels     │  │                 │  │
│  └─────────────┘  └─────────────┘  └─────────────────┘  │
└─────────────────────────────────────────────────────────┘

25.10.2 The Scheduler Deep Dive

vLLM’s scheduler makes critical decisions every iteration:

from collections import deque

class vLLMScheduler:
    """
    Simplified vLLM scheduler logic (helper methods and SchedulerOutputs elided).
    """
    def __init__(self, config):
        self.waiting = deque()      # Requests waiting to start
        self.running = deque()      # Requests currently generating
        self.swapped = deque()      # Requests swapped to CPU
        self.block_manager = BlockManager(config)

    def schedule(self):
        """
        Core scheduling loop, called every iteration.
        """
        # Step 1: Try to resume swapped requests
        resumed = self._try_resume_swapped()

        # Step 2: Preempt if out of memory
        preempted = []
        if self._is_oom():
            preempted = self._preempt_requests()

        # Step 3: Admit new requests
        admitted = self._admit_waiting()

        # Step 4: Build batch for this iteration
        batch = SchedulerOutputs(
            scheduled_running=list(self.running),
            scheduled_prefill=admitted,
            blocks_to_swap_in=resumed,
            blocks_to_swap_out=preempted,
        )

        return batch

    def _admit_waiting(self):
        """Admit waiting requests if we have memory."""
        admitted = []

        for seq_group in list(self.waiting):
            # Calculate memory needed
            num_required_blocks = self._get_required_blocks(seq_group)

            if self.block_manager.can_allocate(num_required_blocks):
                self.block_manager.allocate(seq_group, num_required_blocks)
                self.waiting.remove(seq_group)
                self.running.append(seq_group)
                admitted.append(seq_group)
            else:
                break  # Can't fit more

        return admitted

    def _preempt_requests(self):
        """Preempt lowest priority requests when OOM; return what was preempted."""
        preempted = []

        # Sort by priority, preempt lowest first
        sorted_running = sorted(self.running, key=lambda x: x.priority)

        for seq_group in sorted_running:
            if not self._is_oom():
                break

            # Swap KV cache to CPU
            self.block_manager.swap_out(seq_group)
            self.running.remove(seq_group)
            self.swapped.append(seq_group)
            preempted.append(seq_group)

        return preempted

25.10.3 PagedAttention Memory Management

vLLM’s block allocator manages GPU memory like an OS manages RAM:

class OOMError(Exception):
    """Raised when not enough free GPU blocks remain."""

class BlockManager:
    """
    Manages KV cache memory in fixed-size blocks.
    """
    def __init__(self, block_size=16, num_gpu_blocks=1000, num_cpu_blocks=500):
        self.block_size = block_size

        # GPU block pool
        self.gpu_free_blocks = list(range(num_gpu_blocks))
        self.gpu_allocated = {}  # seq_id -> [block_ids]

        # CPU block pool (for swapping)
        self.cpu_free_blocks = list(range(num_cpu_blocks))
        self.cpu_allocated = {}

    def can_allocate(self, num_blocks):
        """Check whether enough free GPU blocks remain."""
        return len(self.gpu_free_blocks) >= num_blocks

    def allocate(self, seq_group, num_blocks):
        """Allocate blocks for a sequence group."""
        if not self.can_allocate(num_blocks):
            raise OOMError("Not enough GPU blocks")

        blocks = [self.gpu_free_blocks.pop() for _ in range(num_blocks)]
        self.gpu_allocated[seq_group.id] = blocks
        return blocks

    def append_slots(self, seq_group, num_tokens):
        """
        Allocate additional slots as sequence grows.
        Called when generating new tokens.
        """
        blocks = self.gpu_allocated[seq_group.id]
        current_slots = len(blocks) * self.block_size

        if seq_group.num_tokens + num_tokens > current_slots:
            # Need new block
            if not self.gpu_free_blocks:
                return None  # Trigger preemption
            new_block = self.gpu_free_blocks.pop()
            blocks.append(new_block)

        return blocks

    def swap_out(self, seq_group):
        """Swap KV cache from GPU to CPU."""
        gpu_blocks = self.gpu_allocated.pop(seq_group.id)

        # Get CPU blocks
        cpu_blocks = [self.cpu_free_blocks.pop() for _ in gpu_blocks]

        # Actual copy happens in worker
        self.cpu_allocated[seq_group.id] = {
            'cpu_blocks': cpu_blocks,
            'gpu_blocks': gpu_blocks,  # Remember for swap_in
        }

        # Return GPU blocks to pool
        self.gpu_free_blocks.extend(gpu_blocks)

        return (gpu_blocks, cpu_blocks)  # Copy mapping

    def swap_in(self, seq_group):
        """Swap KV cache from CPU back to GPU."""
        info = self.cpu_allocated.pop(seq_group.id)
        cpu_blocks = info['cpu_blocks']

        # Get fresh GPU blocks
        gpu_blocks = [self.gpu_free_blocks.pop() for _ in cpu_blocks]
        self.gpu_allocated[seq_group.id] = gpu_blocks

        # Return CPU blocks to pool
        self.cpu_free_blocks.extend(cpu_blocks)

        return (cpu_blocks, gpu_blocks)  # Copy mapping

25.10.4 Key vLLM Configuration Options

from vllm import LLM, SamplingParams

# Memory optimization
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",

    # Memory management
    gpu_memory_utilization=0.90,  # Use 90% of GPU memory for KV cache
    swap_space=16,  # GB of CPU memory for swapping

    # Scheduling
    max_num_seqs=256,  # Max concurrent sequences
    max_num_batched_tokens=4096,  # Max tokens per iteration

    # Parallelism
    tensor_parallel_size=4,  # Distribute across 4 GPUs
    pipeline_parallel_size=1,

    # Quantization
    quantization="awq",  # or "gptq", "squeezellm"

    # KV cache optimization
    kv_cache_dtype="fp8_e5m2",  # FP8 KV cache (Hopper+)
    block_size=16,

    # Prefix caching
    enable_prefix_caching=True,
)
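
Generation then goes through SamplingParams and llm.generate; a minimal usage sketch (the prompt and sampling values are arbitrary):

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=256,
)

outputs = llm.generate(
    ["Summarize the benefits of prefix caching in two sentences."],
    sampling_params,
)

for output in outputs:
    print(output.outputs[0].text)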

25.11 SGLang: Structured Generation

25.11.1 Beyond Text Generation

SGLang extends LLM serving with structured outputs and control flow:

import sglang as sgl

# Define structured program
@sgl.function
def extract_info(s, text):
    s += "Extract information from: " + text + "\n\n"

    s += "Name: "
    s += sgl.gen("name", stop="\n")

    s += "Age: "
    s += sgl.gen("age", stop="\n", regex=r"\d+")

    s += "Email: "
    s += sgl.gen("email", stop="\n", regex=r"[\w\.-]+@[\w\.-]+\.\w+")

    return {"name": s["name"], "age": s["age"], "email": s["email"]}

# Execute with RadixAttention; generated fields live on the returned program state
state = extract_info.run(text="John Smith, 35, john@example.com")
print(state["name"], state["age"], state["email"])

25.11.2 Constrained Decoding

SGLang enforces output structure during generation:

import json

@sgl.function
def json_generation(s, schema):
    """Generate JSON matching a schema."""
    s += "Generate a JSON object matching this schema:\n"
    s += str(schema) + "\n\n"

    # Constrained generation - only valid JSON tokens allowed
    s += sgl.gen("json",
        json_schema=schema,
        max_tokens=500
    )

    return json.loads(s["json"])

# Schema enforcement
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "scores": {"type": "array", "items": {"type": "number"}}
    },
    "required": ["name", "scores"]
}

state = json_generation.run(schema=schema)
# state["json"] is guaranteed to be valid JSON matching the schema

25.11.3 Fork and Join

SGLang supports parallel generation paths:

@sgl.function
def multi_perspective(s, question):
    s += f"Question: {question}\n\n"

    perspectives = ["optimist", "pessimist", "realist"]

    # Fork into parallel generation branches, one per perspective
    forks = s.fork(len(perspectives))
    for fork, perspective in zip(forks, perspectives):
        fork += f"Answer as a {perspective}:\n"
        fork += sgl.gen("answer", max_tokens=100)

    # Join: Synthesize the forked answers back into the main stream
    s += "\nSynthesis of all perspectives:\n"
    for fork, perspective in zip(forks, perspectives):
        s += f"- {perspective}: {fork['answer']}\n"

    s += "\nFinal balanced answer:\n"
    s += sgl.gen("final", max_tokens=200)

    return s["final"]

25.11.4 RadixAttention in Detail

SGLang’s RadixAttention implementation:

import threading
from typing import Any, List, Tuple

KVCache = Any  # placeholder for the engine's per-prefix KV cache object

class SGLangRadixCache:
    """
    Production-style RadixAttention sketch (thread-safe, token-budgeted).
    """
    def __init__(self, max_total_tokens=100_000):
        self.root = RadixNode()
        self.max_tokens = max_total_tokens
        self.current_tokens = 0
        self.lock = threading.Lock()

    def match_prefix(self, input_ids: List[int]) -> Tuple[int, List[KVCache]]:
        """
        Find longest matching prefix in O(n) time.

        Returns:
            matched_length: Number of tokens matched
            kv_caches: KV cache for matched prefix
        """
        with self.lock:
            node = self.root
            matched = 0
            kv_caches = []

            for token in input_ids:
                if token in node.children:
                    child = node.children[token]
                    node = child
                    matched += 1
                    if child.kv_cache is not None:
                        kv_caches.append(child.kv_cache)
                else:
                    break

            # Update access time for LRU
            self._update_access(node)

            return matched, kv_caches

    def insert(self, input_ids: List[int], kv_cache: KVCache):
        """Insert new prefix into cache."""
        with self.lock:
            # Evict until the new prefix fits within the token budget
            while self.current_tokens + len(input_ids) > self.max_tokens:
                self._evict_lru()

            # Insert
            node = self.root
            for i, token in enumerate(input_ids):
                if token not in node.children:
                    node.children[token] = RadixNode()
                node = node.children[token]

                # Store KV cache at final position
                if i == len(input_ids) - 1:
                    node.kv_cache = kv_cache

            self.current_tokens += len(input_ids)

    def _evict_lru(self):
        """Evict least recently used prefix."""
        # Find LRU leaf node
        lru_node, lru_parent, lru_token = self._find_lru_leaf()

        if lru_node is not None:
            tokens_freed = lru_node.depth
            self.current_tokens -= tokens_freed
            del lru_parent.children[lru_token]

25.11.5 Performance Benefits

Benchmark: 1000 requests with shared system prompt

Without RadixAttention (vLLM PagedAttention only):
  - Prefill per request: 1000 tokens
  - Total prefill: 1,000,000 tokens
  - Time: 45 seconds

With RadixAttention (SGLang):
  - First request prefill: 1000 tokens
  - Subsequent prefills: 50 tokens (unique part only)
  - Total prefill: 1000 + 999*50 = 50,950 tokens
  - Time: 4.2 seconds

Speedup: 10.7x

25.12 Key Takeaways

  1. Prefix caching: Amortize common prefixes across requests for 10x+ speedup.

  2. Chunked prefill: Interleave prefill and decode to avoid blocking.

  3. Disaggregation: Separate prefill (compute) from decode (memory) for cost efficiency.

  4. Advanced speculation: Medusa, EAGLE, lookahead improve on basic speculation.

  5. Priority scheduling: Not all requests are equal; schedule accordingly.

  6. Multi-model: Efficient LoRA serving and model multiplexing.

  7. vLLM internals: Understand PagedAttention block allocation and scheduler for tuning.

  8. SGLang: Structured generation with RadixAttention for maximum prefix reuse.

  9. Observability: Track TTFT, TPOT, throughput, cache hits for optimization.

Note: Try It Yourself

The accompanying notebook lets you:

  • Implement prefix caching with a radix tree
  • Simulate chunked prefill scheduling
  • Compare speculative decoding strategies
  • Build a simple multi-model router


25.13 Further Reading

  • Zheng et al. (2023). “SGLang: Efficient Execution of Structured Language Model Programs”
  • Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention”
  • vLLM Documentation
  • SGLang Documentation
  • Agrawal et al. (2023). “SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills”
  • Cai et al. (2024). “Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads”
  • Li et al. (2024). “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty”
  • Patel et al. (2024). “Splitwise: Efficient Generative LLM Inference Using Phase Splitting”