26 Advanced LLM Serving

Production Techniques for High-Throughput Inference

Your LLM serving system handles 1,000 requests/second with P50 latency of 80ms. The product team wants 10,000 req/s at the same latency. Buying 10× more GPUs would cost $500K/month.

Can you get there without 10× the hardware? This chapter follows one request through a production system, showing why each optimization exists.

Note: Property Spotlight: Redundancy + Sparsity + Separability

Prefix caching / RadixAttention exploits redundancy: many requests share common prefixes whose KV cache entries are identical. Speculative decoding exploits sparsity in draft-target disagreement: most draft tokens are accepted, so the target model only has to correct the occasional mismatch. Disaggregated serving exploits separability: prefill and decode have different resource profiles and can be placed on different hardware.

26.1 The Investigation: Anatomy of a Request

Before cataloging techniques, let’s trace a single request through a production LLM system and identify every bottleneck:

User sends: "Translate to French: Hello, how are you?"

1. Load balancer routes to a GPU server     (~1ms network)
2. Tokenizer encodes the prompt              (~0.5ms CPU)
3. KV cache lookup for shared prefix         (~0.1ms)
4. Prefill: process all prompt tokens        (~15ms, compute-bound)
5. Decode: generate tokens one at a time     (~8ms/token, memory-bound)
6. Detokenize and stream back                (~0.1ms/token)

Total for 20-token response: ~175ms
Target: <200ms P99. Barely making it at low load.
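Adding up the stages makes the bottleneck explicit. The sketch below uses only the per-stage costs from the trace and assumes every one of the 20 output tokens pays a full decode step (slightly pessimistic, since the first token actually comes out of prefill). It lands at about 179 ms, consistent with the estimate above, and shows that decode accounts for roughly 160 ms of it.

```python
# Per-stage costs from the request trace (milliseconds).
FIXED_MS = {"network": 1.0, "tokenize": 0.5, "prefix_lookup": 0.1, "prefill": 15.0}
PER_TOKEN_MS = {"decode": 8.0, "detokenize": 0.1}


def request_latency_ms(output_tokens: int) -> float:
    """Low-load latency: fixed per-request costs plus per-output-token costs."""
    fixed = sum(FIXED_MS.values())
    per_token = sum(PER_TOKEN_MS.values())
    return fixed + output_tokens * per_token


def decode_share(output_tokens: int) -> float:
    """Fraction of total latency spent in decode steps."""
    return output_tokens * PER_TOKEN_MS["decode"] / request_latency_ms(output_tokens)


print(round(request_latency_ms(20), 1))  # 178.6
print(f"{decode_share(20):.0%} of the time is decode")
```

Because decode dominates and is memory-bound, techniques that make each decode step cheaper or produce several tokens per step tend to pay off more than shaving the fixed costs.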

At 10,000 req/s, the system collapses: queueing delays explode (see the section on queueing). Each technique in this chapter attacks a specific bottleneck in this pipeline. Let’s see which ones give us that 10× throughput.
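Why collapse rather than a gentle slowdown? A crude but standard back-of-envelope model helps. Assume, loudly, that the server behaves like a single M/M/1 queue (Poisson arrivals, exponentially distributed service times, one request served at a time). Then mean time in system is W = S / (1 - ρ), where S is the service time and ρ the utilization. Real LLM servers batch requests, so treat this only as an illustration of the shape of the curve.

```python
def mm1_time_in_system_ms(service_ms: float, utilization: float) -> float:
    """
    Mean time in system for an M/M/1 queue: W = S / (1 - rho).
    Returns infinity at or above 100% utilization (the queue grows without bound).
    """
    if utilization >= 1.0:
        return float("inf")
    return service_ms / (1.0 - utilization)


# Service time taken from the ~175 ms request trace above (assumption: M/M/1 behavior).
for rho in (0.1, 0.5, 0.8, 0.95):
    print(f"utilization {rho:.0%}: mean latency {mm1_time_in_system_ms(175, rho):,.0f} ms")
```

Latency doubles at 50% utilization and grows twentyfold at 95%, which is why pushing an almost-saturated system harder breaks its latency target long before it runs out of raw capacity.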

The Inference chapter covered the fundamentals: KV caching, continuous batching, speculative decoding. This chapter goes deeper with production-grade techniques.

26.2 Prefix Caching and RadixAttention

26.2.1 The Observation

Many requests share common prefixes:

System prompt (shared):
  "You are a helpful AI assistant. Be concise and accurate..."

User requests:
  Request 1: [system prompt] + "What is the capital of France?"
  Request 2: [system prompt] + "Explain quantum computing."
  Request 3: [system prompt] + "Write a haiku about coding."

Without prefix caching, the system prompt’s KV cache is recomputed for every request, even though it is identical each time.

26.2.2 RadixAttention (SGLang)

Store KV caches in a radix tree, share across requests:

class RadixNode:
    def __init__(self):
        self.children = {}     # token id -> RadixNode
        self.kv_cache = None   # KV entry for the token at this position

class RadixCache:
    """
    Prefix cache keyed by token sequence (one trie node per token).

    Key insight: Common prefixes share KV cache.
    """
    def __init__(self):
        self.root = RadixNode()

    def insert(self, tokens, kv_cache):
        """Insert per-position KV cache entries for a token sequence."""
        node = self.root
        for i, token in enumerate(tokens):
            if token not in node.children:
                node.children[token] = RadixNode()
            node = node.children[token]
            node.kv_cache = kv_cache[i]  # Store per-position KV

    def lookup(self, tokens):
        """Find longest matching prefix; return (length, KV entries for it)."""
        node = self.root
        matched = 0
        kv_cache = []

        for token in tokens:
            if token in node.children:
                node = node.children[token]
                kv_cache.append(node.kv_cache)
                matched += 1
            else:
                break

        return matched, kv_cache

# Usage (model.prefill and tokenize stand in for your engine's APIs):
cache = RadixCache()

# First request computes the full KV cache for the system prompt
system_tokens = tokenize("You are helpful...")
kv = model.prefill(system_tokens)
cache.insert(system_tokens, kv)

# A later request that starts with the same system prompt reuses it
prompt_tokens = system_tokens + tokenize("What is the capital of France?")
prefix_len, cached_kv = cache.lookup(prompt_tokens)
# Only prompt_tokens[prefix_len:] still needs prefill
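The RadixCache above stores one trie node per token. A radix tree, as described in the SGLang paper, instead labels each edge with a run of tokens and splits an edge only where two sequences diverge, so a long shared system prompt occupies a single edge. The sketch below is illustrative: the class and method names are hypothetical (not SGLang’s API), KV entries are treated as opaque per-token values (strings in the usage), and eviction is omitted.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Tuple


@dataclass
class Node:
    key: Tuple[int, ...] = ()                                  # run of token ids on the edge into this node
    kv: List[object] = field(default_factory=list)             # one KV entry per token in key
    children: Dict[int, "Node"] = field(default_factory=dict)  # first token of child edge -> child


def common_len(a: Sequence[int], b: Sequence[int]) -> int:
    """Length of the shared prefix of two token sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class CompressedRadixCache:
    def __init__(self):
        self.root = Node()

    def match(self, tokens: Sequence[int]) -> Tuple[int, List[object]]:
        """Return (matched length, KV entries for tokens[:matched length])."""
        node, i, kv = self.root, 0, []
        while i < len(tokens):
            child = node.children.get(tokens[i])
            if child is None:
                break
            n = common_len(child.key, tokens[i:])
            kv.extend(child.kv[:n])
            i += n
            if n < len(child.key):  # match ended partway along an edge
                break
            node = child
        return i, kv

    def insert(self, tokens: Sequence[int], kv: Sequence[object]) -> None:
        node, i = self.root, 0
        while i < len(tokens):
            child = node.children.get(tokens[i])
            if child is None:  # nothing shared: add the whole remainder as one edge
                node.children[tokens[i]] = Node(tuple(tokens[i:]), list(kv[i:]))
                return
            n = common_len(child.key, tokens[i:])
            if n < len(child.key):  # diverges mid-edge: split the edge at position n
                mid = Node(child.key[:n], child.kv[:n])
                child.key, child.kv = child.key[n:], child.kv[n:]
                mid.children[child.key[0]] = child
                node.children[tokens[i]] = mid
                child = mid
            node = child
            i += n


cache = CompressedRadixCache()
system = [1, 2, 3, 4]  # stand-in token ids for a shared system prompt
cache.insert(system + [10, 11], ["k1", "k2", "k3", "k4", "k10", "k11"])
cache.insert(system + [20], ["k1", "k2", "k3", "k4", "k20"])
print(cache.match(system + [20, 21]))  # (5, ['k1', 'k2', 'k3', 'k4', 'k20'])
```

After the second insert, the shared prefix `[1, 2, 3, 4]` lives on one edge with two children, instead of four nodes in a per-token trie.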

26.2.3 Memory Management

Prefix caching needs smart eviction:

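from collections import OrderedDict

class LRUPrefixCache:
    def __init__(self, max_memory_gb, kv_size_fn):
        self.max_memory = max_memory_gb * 1e9
        self.kv_size_fn = kv_size_fn    # returns bytes used by a KV cache object
        self.cache = OrderedDict()      # key -> (kv, size), in LRU order
        self.current_memory = 0

    def get_or_compute(self, prefix_tokens, model):
        # Exact-match key; the tuple itself avoids hash collisions between prefixes
        key = tuple(prefix_tokens)

        if key in self.cache:
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            return self.cache[key][0]

        # Compute new KV cache
        kv = model.prefill(prefix_tokens)
        kv_size = self.kv_size_fn(kv)

        # Evict least recently used entries until the new one fits
        while self.cache and self.current_memory + kv_size > self.max_memory:
            self.evict_oldest()

        self.cache[key] = (kv, kv_size)
        self.current_memory += kv_size
        return kv

    def evict_oldest(self):
        _, (_, size) = self.cache.popitem(last=False)  # front = least recently used
        self.current_memory -= size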

Performance impact:

System prompt: 500 tokens
User query: 50 tokens

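Without prefix caching: 550 tokens prefill per request
With prefix caching: 50 tokens prefill per request (once the first request has cached the prefix)

Speedup: up to 11x (550 / 50) in prefill work for time-to-first-token on repeated prefixes; the 50 new tokens still attend over all 550 positions, so the real gain is somewhat below this bound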
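The 11x figure is the best case, where every request hits the cache. A small sketch shows how quickly the gain shrinks with the hit rate, under the simplifying assumption that time-to-first-token scales linearly with the number of tokens that must be prefilled.

```python
def expected_prefill_speedup(prefix_tokens: int, query_tokens: int, hit_rate: float) -> float:
    """
    Ratio of prefill work without caching to expected prefill work with caching.
    Assumes TTFT is proportional to prefilled tokens (ignores attention over cached tokens).
    """
    full = prefix_tokens + query_tokens
    expected = query_tokens + (1.0 - hit_rate) * prefix_tokens
    return full / expected


for h in (1.0, 0.9, 0.5):
    print(f"hit rate {h:.0%}: {expected_prefill_speedup(500, 50, h):.1f}x")
```

With the chapter’s 500-token prompt and 50-token query, a 90% hit rate already halves the benefit to about 5.5x, which is why the cache-hit-rate metric later in this chapter is worth tracking.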

26.3 Chunked Prefill

26.3.1 The Problem

Long prefill blocks decode:

Timeline without chunked prefill:

Request A (long prompt): [==========PREFILL==========][decode][decode]...
Request B (arrives during A's prefill):
                         waiting... waiting... [PREFILL][decode]...
                         ↑ High latency for B!

26.3.2 The Solution

Split prefill into chunks, interleave with decode:

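class ChunkedPrefillScheduler:
    def __init__(self, chunk_size=512, max_decode=256, max_tokens=2048):
        # Illustrative defaults
        self.chunk_size = chunk_size    # max prefill tokens per request per iteration
        self.max_decode = max_decode    # max decode requests per iteration
        self.max_tokens = max_tokens    # total token budget per iteration
        self.prefill_queue = []         # requests with .total_tokens and .processed_tokens
        self.decode_queue = []

    def schedule_iteration(self):
        """
        Each iteration: some prefill tokens + some decode tokens.
        The caller advances req.processed_tokens after the iteration runs.
        """
        batch = []

        # Add decode requests first (high priority for latency)
        for req in self.decode_queue[:self.max_decode]:
            batch.append(('decode', req, 1))  # 1 token per decode

        # Fill the remaining token budget with prefill chunks
        remaining_capacity = self.max_tokens - len(batch)
        for req in self.prefill_queue:
            if remaining_capacity <= 0:
                break

            tokens_remaining = req.total_tokens - req.processed_tokens
            chunk = min(tokens_remaining, self.chunk_size, remaining_capacity)
            if chunk > 0:
                batch.append(('prefill', req, chunk))
                remaining_capacity -= chunk

        return batch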

Timeline with chunked prefill (each iteration mixes work from both requests):

Iteration:  1           2           3           4
Request A:  [chunk 1]   [chunk 2]   [chunk 3]   [chunk 4] ...
Request B:  (arrives)   [prefill]   [decode]    [decode] ...
                        ↑ B waits for one chunk, not for A's whole prefill
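To put numbers on the timelines, here is a back-of-envelope sketch of request B’s time-to-first-token. The assumptions are hypothetical and stated in the docstring: an iteration’s cost is linear in the tokens it processes (0.02 ms/token), B’s short prompt fits in one iteration, and A’s own decode tokens are ignored.

```python
from typing import Optional


def ttft_of_b_ms(a_prompt: int, b_prompt: int, chunk_size: Optional[int] = None,
                 ms_per_token: float = 0.02) -> float:
    """
    TTFT for request B, which arrives just after A's first iteration starts.

    Assumptions: iteration time is linear in tokens processed, B's prompt
    fits in one iteration, and A's decode tokens are ignored.
    """
    # B must wait for the iteration already in flight:
    # all of A's prompt without chunking, a single chunk with it.
    in_flight = a_prompt if chunk_size is None else min(chunk_size, a_prompt)
    # B's prefill runs in the next iteration, sharing it with A's next chunk (if any).
    a_next = 0 if chunk_size is None else min(chunk_size, a_prompt - in_flight)
    return (in_flight + b_prompt + a_next) * ms_per_token


print(ttft_of_b_ms(8000, 200))                  # about 164 ms
print(ttft_of_b_ms(8000, 200, chunk_size=512))  # about 24.5 ms
```

The price is paid by A: its prefill now spans many iterations, so its own TTFT rises a little. Chunk size is the knob that trades one against the other.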

26.4 Disaggregated Serving

26.4.1 The Observation

Prefill and decode have different characteristics:

Prefill:
  - Compute-bound (matrix multiplies)
  - High parallelism (many tokens)
  - Batch-friendly
  - Writes the KV cache for the whole prompt at once

Decode:
  - Memory-bound (reading model weights and the KV cache every step)
  - Sequential (one new token per request per step)
  - Latency-sensitive
  - Low compute intensity

26.4.2 Split Prefill and Decode

Run on different GPU pools:

class DisaggregatedServer:
    """
    Sketch only: GPUPool and DistributedKVStore are placeholder
    abstractions, not a specific library's API.
    """
    def __init__(self):
        self.prefill_workers = GPUPool(gpu_type='A100', count=4)
        # Cheaper GPUs (illustrative); for decode, memory bandwidth and
        # capacity per dollar matter more than peak FLOPs
        self.decode_workers = GPUPool(gpu_type='L4', count=16)
        self.kv_cache_store = DistributedKVStore()

    async def handle_request(self, prompt):
        # 1. Prefill on high-compute GPUs
        kv_cache = await self.prefill_workers.prefill(prompt)

        # 2. Hand the KV cache off (this transfer is the main added cost)
        cache_id = await self.kv_cache_store.put(kv_cache)

        # 3. Decode on GPUs sized for memory bandwidth and capacity
        output = await self.decode_workers.decode(cache_id)

        return output

Benefits:
  • Prefill: Use fewer, more powerful GPUs (compute-bound)
  • Decode: Use more, cheaper GPUs (memory-bound)
  • Better overall cost efficiency

Challenges:
  • KV cache transfer latency
  • Complexity of distributed state
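How large is the KV cache transfer? A back-of-envelope sketch: per token, the cache holds keys and values for every layer and KV head. The example assumes a Llama-2-70B-like shape (80 layers, grouped-query attention with 8 KV heads, head dimension 128, FP16) and a hypothetical 100 Gbit/s link, and ignores link latency and any overlap with computation.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    """Keys + values for one token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def transfer_ms(prompt_tokens: int, per_token_bytes: int, link_gbit_s: float) -> float:
    """Time to ship a prompt's KV cache over a link (bandwidth only)."""
    total_bytes = prompt_tokens * per_token_bytes
    return total_bytes / (link_gbit_s * 1e9 / 8) * 1000


per_tok = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)  # 327,680 bytes
print(per_tok)
print(round(transfer_ms(2000, per_tok, link_gbit_s=100), 1))  # 52.4
```

About 52 ms for a 2,000-token prompt is on the order of several decode steps, so disaggregated systems generally need fast interconnects, or overlap the transfer with prefill, for the split to pay off.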

26.5 Speculative Decoding Advances

26.5.1 Beyond Simple Speculation

The Inference chapter covered basic speculative decoding. Advanced techniques:

1. Medusa: Multiple Heads

Add extra decoding heads on the base model’s final hidden state; each head predicts a token further ahead:

import torch.nn as nn

class MedusaModel(nn.Module):
    def __init__(self, base_model, hidden_dim, vocab_size, num_heads=4):
        super().__init__()
        self.base = base_model  # assumed: returns hidden states and exposes .lm_head
        # Head k predicts the token k+1 positions ahead (lm_head predicts the next one).
        # Simplified: Medusa puts a residual feed-forward block before each projection.
        self.heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        hidden = self.base(x)   # [batch, seq, hidden]
        last = hidden[:, -1:]   # only the final position is needed

        # Main prediction (next token)
        logits_0 = self.base.lm_head(last)

        # Speculative predictions (tokens 2, 3, 4, ...)
        speculative_logits = [head(last) for head in self.heads]

        return logits_0, speculative_logits

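Verification: the base model checks the speculated tokens in a single forward pass. Medusa builds several candidate continuations from the heads’ top predictions and verifies them together using tree attention.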

2. EAGLE: Draft with Hidden States

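Draft in feature space: a lightweight head extrapolates the target model’s hidden states, and the target’s own LM head turns each predicted state into a token: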

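class EAGLEDraft:
    """
    Draft using feature-level prediction (simplified).
    EAGLE's draft head also conditions on the embedding of the token
    sampled at each step; that input is omitted here for brevity.
    """
    def __init__(self, feature_predictor, lm_head):
        self.feature_predictor = feature_predictor  # small network: hidden -> next hidden
        self.lm_head = lm_head                      # reused from the target model

    def draft(self, hidden_states, num_tokens):
        drafts = []

        h = hidden_states
        for _ in range(num_tokens):
            # Predict next hidden state
            h_next = self.feature_predictor(h)

            # Decode to token (greedy)
            logits = self.lm_head(h_next)
            token = logits.argmax(-1)

            drafts.append(token)
            h = h_next

        return drafts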

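Advantage: a hidden state carries more information than the single token sampled from it, which makes the next step easier for a small draft head to predict than with token-level drafting.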

3. Lookahead Decoding

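Speculate using n-gram patterns. The simplified sketch below takes its n-grams from the prompt (the idea behind prompt lookup decoding); Lookahead Decoding proper builds its n-gram pool on the fly using Jacobi iterations: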

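def lookahead_decode(model, prompt_tokens, n=3, k=4, max_new_tokens=128):
    """
    Use n-grams from the prompt to speculate future tokens.
    `verify` and `model.decode_one` stand in for engine calls; EOS handling omitted.
    """
    # Build n-gram table from prompt: each n-gram -> the k tokens that followed it
    ngrams = {}
    for i in range(len(prompt_tokens) - n):
        key = tuple(prompt_tokens[i:i + n])
        ngrams.setdefault(key, prompt_tokens[i + n:i + n + k])

    tokens = list(prompt_tokens)  # copy so the caller's prompt is not mutated
    while len(tokens) - len(prompt_tokens) < max_new_tokens:
        # Get current n-gram
        current_ngram = tuple(tokens[-n:])

        if current_ngram in ngrams:
            # Speculate based on seen pattern
            speculation = ngrams[current_ngram]
            # Verify with single forward pass: returns the accepted draft tokens
            # plus one token from the model, so each step makes progress
            verified = verify(model, tokens, speculation)
            tokens.extend(verified)
        else:
            # Standard decoding
            next_token = model.decode_one(tokens)
            tokens.append(next_token)

    return tokens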
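Medusa, EAGLE, and lookahead differ in how they propose tokens, but all rely on the same verification step. The sketch below shows greedy verification: accept the longest draft prefix that matches the target model’s own choices, then append one target token (the correction at the first mismatch, or a bonus token if everything matched). It also includes the expected-tokens formula for independent per-token acceptance probability α and γ drafted tokens, (1 - α^(γ+1)) / (1 - α). The per-position `target_next` callback is an illustrative stand-in: a real engine gets all these predictions from one forward pass, and sampled (non-greedy) decoding replaces the equality test with a rejection-sampling rule.

```python
from typing import Callable, List, Sequence


def verify_greedy(target_next: Callable[[Sequence[int]], int],
                  context: List[int], draft: List[int]) -> List[int]:
    """
    Accept the longest prefix of `draft` matching the target's greedy choices,
    then append one target token (correction or bonus).
    """
    accepted: List[int] = []
    for tok in draft:
        expected = target_next(context + accepted)
        if tok != expected:
            return accepted + [expected]  # first mismatch: keep target's token, drop the rest
        accepted.append(tok)
    return accepted + [target_next(context + accepted)]  # all accepted: bonus token


def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens per target forward pass, assuming i.i.d. acceptance with probability alpha."""
    if alpha >= 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def next_tok(seq: Sequence[int]) -> int:
    """Toy target model: the next token is always last + 1."""
    return seq[-1] + 1


print(verify_greedy(next_tok, [1, 2, 3], [4, 5, 9, 10]))  # [4, 5, 6]
print(round(expected_tokens_per_pass(0.8, 4), 2))         # 3.36
```

With an 80% acceptance rate and four drafted tokens, each expensive target pass yields about 3.4 tokens instead of one, which is the source of the speedup.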

26.6 Batch Scheduling Strategies

Warning: The Batching Paradox

Batching improves throughput but hurts latency, a counter-intuitive tradeoff that catches many engineers.

Why it’s counter-intuitive: Bigger batches amortize overhead and improve GPU utilization. More throughput should mean faster responses, right?

Why it hurts latency: Each decode step for a batch of 64 sequences takes longer than a step for one, so every request’s time per output token rises. With static batching it is worse: a request that arrives while a batch is running must wait for that whole batch to finish before it can start. Continuous batching shortens that wait but does not remove the per-step slowdown.

The hidden variable: Throughput optimization assumes latency tolerance. Interactive applications don’t have it.

This is why production systems need priority scheduling: not all requests can tolerate the batching delay that maximizes throughput.
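Both effects show up in a toy cost model (an assumption, not a measurement): each decode step pays a fixed cost to stream the model weights, independent of batch size, plus a small per-sequence cost for KV reads and compute. The parameter values below are hypothetical.

```python
def decode_step_ms(batch_size: int, weight_read_ms: float = 20.0, per_seq_ms: float = 0.5) -> float:
    """Toy decode-step cost: fixed weight-streaming time plus linear per-sequence work."""
    return weight_read_ms + per_seq_ms * batch_size


def throughput_tok_s(batch_size: int) -> float:
    """Tokens generated per second across the whole batch."""
    return batch_size * 1000 / decode_step_ms(batch_size)


for b in (1, 8, 64):
    print(f"batch={b:3d}  per-token latency={decode_step_ms(b):5.1f} ms  "
          f"throughput={throughput_tok_s(b):7.1f} tok/s")
```

In this model, going from batch 1 to batch 64 raises per-token latency about 2.5x (20.5 ms to 52 ms) while throughput grows about 25x. That asymmetry is why throughput-oriented batching is so tempting, and why latency-sensitive traffic needs protection from it.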

26.6.1 Priority-Based Scheduling

Not all requests are equal:

from collections import deque

class PriorityScheduler:
    def __init__(self):
        self.queues = {
            'realtime': deque(),   # Interactive chat
            'standard': deque(),   # API requests
            'batch': deque()       # Background jobs
        }
        # Max requests each class may contribute per batch
        self.weights = {'realtime': 10, 'standard': 5, 'batch': 1}

    def select_batch(self, max_tokens):
        """
        Weighted selection across priority classes: one pass in priority
        order, each class contributing up to `weight` requests while the
        token budget allows.
        """
        batch = []
        tokens_used = 0

        for priority in ['realtime', 'standard', 'batch']:
            weight = self.weights[priority]
            queue = self.queues[priority]

            for _ in range(weight):
                # Stop if the queue is empty or the next request would exceed the budget
                if not queue or tokens_used + queue[0].tokens_needed > max_tokens:
                    break
                req = queue.popleft()
                batch.append(req)
                tokens_used += req.tokens_needed

        return batch

26.6.2 Preemption

Sometimes you need to preempt running requests:

from collections import deque

# Lower rank = preempted first
PRIORITY_RANK = {'batch': 0, 'standard': 1, 'realtime': 2}

class PreemptibleScheduler:
    def __init__(self):
        self.queues = {p: deque() for p in PRIORITY_RANK}

    def maybe_preempt(self, running_requests, new_request):
        """
        Preempt low-priority request if high-priority arrives.
        """
        if new_request.priority != 'realtime':
            return None  # Only preempt for realtime

        # Scan from lowest to highest priority (sorting the names
        # themselves would order them alphabetically, not by priority)
        for req in sorted(running_requests, key=lambda r: PRIORITY_RANK[r.priority]):
            if req.priority == 'batch':
                # Save state for later resumption (engine-specific,
                # e.g. swap KV blocks to CPU)
                saved_state = self.save_kv_cache(req)
                self.queues['batch'].appendleft((req, saved_state))  # resume first
                return req  # Preempt this one

        return None

26.7 Multi-Model Serving

26.7.1 Model Multiplexing

Serve multiple models efficiently:

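from collections import OrderedDict

class MultiModelServer:
    def __init__(self, gpu_memory_gb=80):
        self.memory_budget = gpu_memory_gb * 1e9
        self.loaded_models = OrderedDict()  # model_id -> model, in LRU order
        self.model_sizes = {}               # model_id -> bytes

    def current_memory(self):
        return sum(self.model_sizes[m] for m in self.loaded_models)

    def evict_lru_model(self):
        model_id, _ = self.loaded_models.popitem(last=False)  # least recently used
        del self.model_sizes[model_id]
        # A real server must also release the model's GPU memory here

    def load_model(self, model_id):
        """Load model, evicting least recently used models if necessary."""
        size = self.get_model_size(model_id)  # hypothetical: from model metadata

        # Evict until we have space
        while self.loaded_models and self.current_memory() + size > self.memory_budget:
            self.evict_lru_model()

        self.loaded_models[model_id] = load_from_disk(model_id)  # hypothetical loader
        self.model_sizes[model_id] = size

    async def infer(self, model_id, prompt):
        if model_id not in self.loaded_models:
            self.load_model(model_id)  # blocking; cold loads can take a long time
        self.loaded_models.move_to_end(model_id)  # mark as most recently used

        return await self.loaded_models[model_id].generate(prompt)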

26.7.2 LoRA Serving

Efficient serving of many LoRA adapters:

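from collections import defaultdict

import torch

class LoRAServer:
    """
    Simplified to one shared linear layer. In a real model the same
    low-rank update is applied inside every adapted layer (e.g. the
    attention projections), not to the model's final output.
    """
    def __init__(self, base_layer, scaling=1.0):
        self.base = base_layer   # e.g. nn.Linear(d_in, d_out), shared by all adapters
        self.scaling = scaling   # LoRA alpha / rank
        self.adapters = {}       # adapter_id -> (A [d_in, r], B [r, d_out])

    def load_adapter(self, adapter_id, adapter_path):
        """Load LoRA adapter (small, can cache many)."""
        A, B = load_lora_weights(adapter_path)  # hypothetical loader
        self.adapters[adapter_id] = (A.cuda(), B.cuda())

    def forward(self, x, adapter_id):
        """Forward with specific adapter: base(x) + scaling * (x A) B."""
        A, B = self.adapters[adapter_id]
        return self.base(x) + self.scaling * ((x @ A) @ B)

    def batched_forward(self, requests):
        """
        Batch requests with different adapters.

        The base layer runs once for the whole batch; the LoRA delta is
        added per adapter group. Systems such as Punica and S-LoRA use
        custom kernels that apply per-row adapters without this Python loop.
        """
        x = torch.stack([r.input for r in requests])  # assumes equal-shaped inputs
        out = self.base(x)

        # Group row indices by adapter
        by_adapter = defaultdict(list)
        for i, req in enumerate(requests):
            by_adapter[req.adapter_id].append(i)

        for adapter_id, rows in by_adapter.items():
            A, B = self.adapters[adapter_id]
            idx = torch.tensor(rows, device=x.device)
            out[idx] += self.scaling * ((x[idx] @ A) @ B)

        return list(zip(requests, out))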
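Why can one server keep many adapters resident? A LoRA update for a d_out × d_in weight stores two thin matrices, so its parameter count is rank × (d_in + d_out) instead of d_in × d_out. The sketch below works this out for a single hypothetical 4096 × 4096 projection at rank 16.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA update (A: d_in x r, B: r x d_out)."""
    return rank * (d_in + d_out)


full = 4096 * 4096                       # one dense 4096 x 4096 projection
delta = lora_params(4096, 4096, rank=16)
print(delta, f"{delta / full:.2%}")      # 131072 0.78%
```

At under 1% of the dense weight per adapted layer, hundreds of adapters can share the memory of a single extra copy of the base model.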

26.8 Monitoring and Observability

26.8.1 Key Metrics

import time
from collections import defaultdict

import numpy as np

class ServingMetrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.start_time = time.time()  # throughput window starts here

    def record_request(self, req):
        # Latency metrics
        self.metrics['ttft'].append(req.time_to_first_token)    # time to first token
        self.metrics['tpot'].append(req.time_per_output_token)  # time per output token
        self.metrics['total_latency'].append(req.total_latency)

        # Throughput metrics
        self.metrics['input_tokens'].append(req.input_length)
        self.metrics['output_tokens'].append(req.output_length)

        # Efficiency metrics
        self.metrics['batch_size'].append(req.batch_size)
        self.metrics['cache_hit_rate'].append(req.prefix_cache_hit)  # 1 on hit, 0 on miss

    def report(self):
        total_time = time.time() - self.start_time
        return {
            'ttft_p50': np.percentile(self.metrics['ttft'], 50),
            'ttft_p99': np.percentile(self.metrics['ttft'], 99),
            'throughput_tps': sum(self.metrics['output_tokens']) / total_time,
            'cache_hit_rate': np.mean(self.metrics['cache_hit_rate']),
            'avg_batch_size': np.mean(self.metrics['batch_size']),
        }

26.8.2 Health Checks

import time

import torch

class HealthChecker:
    def __init__(self, model):
        self.model = model  # engine handle; generate(prompt, max_tokens=...) is assumed

    async def check_health(self):
        checks = {
            'gpu_memory': self.check_gpu_memory(),
            'model_loaded': self.check_model_loaded(),  # engine-specific, not shown
            'latency': await self.check_latency(),
            'queue_depth': self.check_queue_depth(),    # engine-specific, not shown
        }

        healthy = all(checks.values())
        return {'healthy': healthy, 'checks': checks}

    def check_gpu_memory(self):
        used = torch.cuda.memory_allocated()
        total = torch.cuda.get_device_properties(0).total_memory
        return used / total < 0.95  # Unhealthy above 95%

    async def check_latency(self):
        start = time.perf_counter()  # monotonic clock, suited to timing intervals
        _ = await self.model.generate("test", max_tokens=1)
        latency = time.perf_counter() - start
        return latency < 0.5  # Health check should be < 500ms

26.9 Cost Optimization

26.9.1 Token Budgeting

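import time
from datetime import datetime, timedelta

class TokenBudget:
    def __init__(self, daily_budget_tokens):
        self.daily_budget = daily_budget_tokens
        self.used_today = 0
        self.reset_time = self.next_midnight()

    @staticmethod
    def next_midnight():
        """Unix timestamp of the next local midnight."""
        tomorrow = datetime.now().date() + timedelta(days=1)
        return datetime.combine(tomorrow, datetime.min.time()).timestamp()

    def can_serve(self, estimated_tokens):
        if time.time() > self.reset_time:
            self.used_today = 0
            self.reset_time = self.next_midnight()

        return self.used_today + estimated_tokens <= self.daily_budget

    def record_usage(self, tokens):
        self.used_today += tokens

    def estimate_request_cost(self, input_len, max_output):
        # Illustrative pricing in dollars; worst case assumes all max_output tokens are generated
        input_cost = input_len * 0.00001  # $0.01 per 1K input tokens
        output_cost = max_output * 0.00003  # $0.03 per 1K output tokens
        return input_cost + output_cost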

26.9.2 Model Selection

Route to cheaper models when possible:

class ModelRouter:
    def __init__(self):
        # 'cost' values are illustrative relative prices per request
        self.models = {
            'small': {'model': 'llama-7b', 'cost': 0.001},
            'medium': {'model': 'llama-13b', 'cost': 0.003},
            'large': {'model': 'llama-70b', 'cost': 0.01},
        }

    def select_model(self, request):
        """Select cheapest model that can handle request."""
        complexity = self.estimate_complexity(request)

        if complexity < 0.3:
            return self.models['small']
        elif complexity < 0.7:
            return self.models['medium']
        else:
            return self.models['large']

    def estimate_complexity(self, request):
        """Estimate request complexity (0-1)."""
        # Heuristics: length, keywords, domain
        score = 0
        if 'code' in request.lower() or 'math' in request.lower():
            score += 0.4
        if len(request) > 500:
            score += 0.3
        return min(score, 1.0)
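The savings from routing depend on the traffic mix. A sketch of the blended cost per request, reusing the illustrative per-tier costs from ModelRouter and a hypothetical split where 60% of requests go to the small model, 30% to the medium one, and 10% to the large one:

```python
from typing import Dict


def blended_cost(traffic_mix: Dict[str, float], cost: Dict[str, float]) -> float:
    """Average cost per request given the fraction of traffic routed to each tier."""
    if abs(sum(traffic_mix.values()) - 1.0) > 1e-9:
        raise ValueError("traffic fractions must sum to 1")
    return sum(frac * cost[tier] for tier, frac in traffic_mix.items())


COST = {"small": 0.001, "medium": 0.003, "large": 0.01}  # from ModelRouter
mix = {"small": 0.6, "medium": 0.3, "large": 0.1}        # hypothetical traffic split
print(round(blended_cost(mix, COST), 4), "vs", COST["large"])  # 0.0025 vs 0.01
```

A 4x cost reduction under these assumptions, but only if the complexity estimate is good: every hard request misrouted to the small model saves money at the expense of answer quality.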

26.10 vLLM Internals

26.10.1 Architecture Overview

vLLM [1] is one of the most widely deployed open-source LLM serving engines. Understanding its internals helps when tuning production deployments.

vLLM Architecture:

┌─────────────────────────────────────────────────────────┐
│                   API Server (FastAPI)                  │
├─────────────────────────────────────────────────────────┤
│                   LLM Engine                            │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────┐  │
│  │  Scheduler  │  │ KV Cache    │  │ Model Executor  │  │
│  │             │  │ Manager     │  │                 │  │
│  └─────────────┘  └─────────────┘  └─────────────────┘  │
├─────────────────────────────────────────────────────────┤
│                   Worker (per GPU)                      │
│  ┌─────────────┐  ┌────────────────┐  ┌──────────────┐  │
│  │ Model       │  │ PagedAttention │  │ Sampler      │  │
│  │ Runner      │  │ Kernels        │  │              │  │
│  └─────────────┘  └────────────────┘  └──────────────┘  │
└─────────────────────────────────────────────────────────┘

26.10.2 The Scheduler Deep Dive

vLLM’s scheduler makes critical decisions every iteration:

from collections import deque

class vLLMScheduler:
    """
    Simplified vLLM scheduler logic. The real policy (which requests are
    preempted, and whether they are swapped or recomputed) varies across
    vLLM versions. Helpers such as _is_oom, _try_resume_swapped and
    _get_required_blocks are omitted.
    """
    def __init__(self, config):
        self.waiting = deque()      # Requests waiting to start
        self.running = deque()      # Requests currently generating
        self.swapped = deque()      # Requests swapped to CPU
        self.block_manager = BlockManager(**config)  # e.g. {'block_size': 16, ...}

    def schedule(self):
        """
        Core scheduling loop, called every iteration.
        """
        # Step 1: Try to resume swapped requests
        resumed = self._try_resume_swapped()

        # Step 2: Preempt if out of memory
        preempted = self._preempt_requests() if self._is_oom() else []

        # Step 3: Admit new requests
        admitted = self._admit_waiting()

        # Step 4: Build batch for this iteration
        return SchedulerOutputs(
            scheduled_running=list(self.running),
            scheduled_prefill=admitted,
            blocks_to_swap_in=resumed,
            blocks_to_swap_out=preempted,
        )

    def _admit_waiting(self):
        """Admit waiting requests (in arrival order) while memory allows."""
        admitted = []

        for seq_group in list(self.waiting):
            # Calculate memory needed
            num_required_blocks = self._get_required_blocks(seq_group)

            if self.block_manager.can_allocate(num_required_blocks):
                self.block_manager.allocate(seq_group, num_required_blocks)
                self.waiting.remove(seq_group)
                self.running.append(seq_group)
                admitted.append(seq_group)
            else:
                break  # Can't fit more

        return admitted

    def _preempt_requests(self):
        """Swap out lowest-priority running requests until memory fits."""
        preempted = []
        # Sort by priority (lower value = lower priority), preempt lowest first
        for seq_group in sorted(self.running, key=lambda x: x.priority):
            if not self._is_oom():
                break

            # Swap KV cache to CPU; the worker performs the actual copy
            preempted.append(self.block_manager.swap_out(seq_group))
            self.running.remove(seq_group)
            self.swapped.append(seq_group)

        return preempted

26.10.3 PagedAttention Memory Management

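vLLM’s block allocator manages GPU memory much as an OS manages virtual memory: each sequence’s KV cache lives in fixed-size blocks that need not be contiguous, tracked by a per-sequence block table: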

class BlockManager:
    """
    Manages KV cache memory in fixed-size blocks (simplified).
    Assumes each seq_group exposes .id and .num_tokens.
    """
    def __init__(self, block_size=16, num_gpu_blocks=1000, num_cpu_blocks=500):
        self.block_size = block_size

        # GPU block pool
        self.gpu_free_blocks = list(range(num_gpu_blocks))
        self.gpu_allocated = {}  # seq_id -> [block_ids] (the block table)

        # CPU block pool (for swapping)
        self.cpu_free_blocks = list(range(num_cpu_blocks))
        self.cpu_allocated = {}  # seq_id -> [cpu block_ids]

    def can_allocate(self, num_blocks):
        return len(self.gpu_free_blocks) >= num_blocks

    def allocate(self, seq_group, num_blocks):
        """Allocate blocks for a sequence group."""
        if not self.can_allocate(num_blocks):
            raise MemoryError("Not enough GPU blocks")

        blocks = [self.gpu_free_blocks.pop() for _ in range(num_blocks)]
        self.gpu_allocated[seq_group.id] = blocks
        return blocks

    def append_slots(self, seq_group, num_tokens=1):
        """
        Allocate additional blocks as the sequence grows.
        Called when generating new tokens; returns None if GPU blocks run out.
        """
        blocks = self.gpu_allocated[seq_group.id]
        # Blocks needed to hold all tokens, rounded up
        needed = -(-(seq_group.num_tokens + num_tokens) // self.block_size)

        while len(blocks) < needed:
            if not self.gpu_free_blocks:
                return None  # Trigger preemption
            blocks.append(self.gpu_free_blocks.pop())

        return blocks

    def swap_out(self, seq_group):
        """Swap KV cache from GPU to CPU."""
        gpu_blocks = self.gpu_allocated[seq_group.id]
        if len(self.cpu_free_blocks) < len(gpu_blocks):
            raise MemoryError("Not enough CPU swap blocks")
        del self.gpu_allocated[seq_group.id]

        # Get CPU blocks; the actual copy happens in the worker
        cpu_blocks = [self.cpu_free_blocks.pop() for _ in gpu_blocks]
        self.cpu_allocated[seq_group.id] = cpu_blocks

        # Return GPU blocks to pool
        self.gpu_free_blocks.extend(gpu_blocks)

        return list(zip(gpu_blocks, cpu_blocks))  # (src GPU, dst CPU) copy mapping

    def swap_in(self, seq_group):
        """Swap KV cache from CPU back to GPU."""
        cpu_blocks = self.cpu_allocated[seq_group.id]
        if len(self.gpu_free_blocks) < len(cpu_blocks):
            raise MemoryError("Not enough GPU blocks to swap in")
        del self.cpu_allocated[seq_group.id]

        # Get fresh GPU blocks (need not be the ones freed at swap-out)
        gpu_blocks = [self.gpu_free_blocks.pop() for _ in cpu_blocks]
        self.gpu_allocated[seq_group.id] = gpu_blocks

        # Return CPU blocks to pool
        self.cpu_free_blocks.extend(cpu_blocks)

        return list(zip(cpu_blocks, gpu_blocks))  # (src CPU, dst GPU) copy mapping
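The payoff of fixed-size blocks is easy to quantify. With on-demand blocks, a sequence wastes at most block_size - 1 slots (in its last block); a contiguous allocator that reserves the maximum length up front wastes everything the sequence never uses. The sketch below compares the two for a 300-token sequence, assuming block_size=16 as above and a hypothetical 2,048-token reservation.

```python
import math


def paged_waste(seq_len: int, block_size: int = 16) -> int:
    """Unused KV slots when a sequence gets whole fixed-size blocks on demand."""
    return math.ceil(seq_len / block_size) * block_size - seq_len


def contiguous_waste(seq_len: int, max_len: int = 2048) -> int:
    """Unused KV slots when every sequence reserves max_len slots up front."""
    return max_len - seq_len


print(paged_waste(300), contiguous_waste(300))  # 4 1748
```

Those reclaimed slots become room for more concurrent sequences, which is where PagedAttention’s throughput gains come from.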

26.10.4 Key vLLM Configuration Options

from vllm import LLM, SamplingParams

# Memory optimization (option names and defaults vary across vLLM versions)
llm = LLM(
    model="meta-llama/Llama-2-70b-chat-hf",

    # Memory management
    gpu_memory_utilization=0.90,  # Fraction of each GPU vLLM may use (weights, activations, KV cache)
    swap_space=16,  # GiB of CPU memory per GPU for swapped-out KV blocks

    # Scheduling
    max_num_seqs=256,  # Max concurrent sequences
    max_num_batched_tokens=4096,  # Max tokens per iteration

    # Parallelism
    tensor_parallel_size=4,  # Distribute across 4 GPUs
    pipeline_parallel_size=1,

    # Quantization: only valid with a checkpoint quantized by that method
    # quantization="awq",  # or "gptq"

    # KV cache optimization
    kv_cache_dtype="fp8_e5m2",  # FP8 KV cache; supported dtypes and GPUs depend on version
    block_size=16,

    # Prefix caching
    enable_prefix_caching=True,
)
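How much KV cache does `gpu_memory_utilization` actually buy? A rough estimate, under stated assumptions: four 80 GB GPUs (treating 80 GB as 80e9 bytes), FP16 weights for a 70B model (about 140 GB), a Llama-2-70B-like shape (80 layers, 8 KV heads, head dimension 128) with an FP8 KV cache at 1 byte per element, and a hypothetical 8 GB allowance for activations. vLLM measures the real numbers itself at startup, so treat this only as a sanity check.

```python
def kv_capacity_tokens(gpu_mem_gb: float, num_gpus: int, utilization: float,
                       weights_gb: float, kv_bytes_per_token: int,
                       reserve_gb: float = 8.0) -> int:
    """
    Rough number of tokens whose KV cache fits in the memory the engine may use.
    reserve_gb is an assumed allowance for activations and workspace.
    """
    usable_bytes = (gpu_mem_gb * num_gpus * utilization - weights_gb - reserve_gb) * 1e9
    return max(0, int(usable_bytes // kv_bytes_per_token))


fp8_kv_per_token = 2 * 80 * 8 * 128 * 1  # K and V, 80 layers, 8 KV heads, dim 128, 1 byte
tokens = kv_capacity_tokens(gpu_mem_gb=80, num_gpus=4, utilization=0.90,
                            weights_gb=140, kv_bytes_per_token=fp8_kv_per_token)
print(tokens, tokens // 4096)  # 854492 208
```

Roughly 850K cached tokens, or about 200 full 4,096-token sequences, which is in the same range as `max_num_seqs=256`. Raising `gpu_memory_utilization` or shrinking the weights (quantization) directly increases this number.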

26.11 SGLang: Structured Generation

26.11.1 Beyond Text Generation

SGLang extends LLM serving with structured outputs and control flow:

import sglang as sgl

# Assumes an SGLang server is already running locally (default port 30000)
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# Define structured program
@sgl.function
def extract_info(s, text):
    s += "Extract information from: " + text + "\n\n"

    s += "Name: "
    s += sgl.gen("name", stop="\n")

    s += "Age: "
    s += sgl.gen("age", stop="\n", regex=r"\d+")

    s += "Email: "
    s += sgl.gen("email", stop="\n", regex=r"[\w\.-]+@[\w\.-]+\.\w+")

# Execute; shared prompt prefixes are reused via RadixAttention
state = extract_info.run(text="John Smith, 35, john@example.com")
print(state["name"], state["age"], state["email"])

26.11.2 Constrained Decoding

SGLang enforces output structure during generation:

import json

@sgl.function
def json_generation(s, schema):
    """Generate JSON matching a schema."""
    s += "Generate a JSON object matching this schema:\n"
    s += json.dumps(schema) + "\n\n"

    # Constrained generation: tokens that would violate the schema are masked out
    s += sgl.gen("json",
        json_schema=json.dumps(schema),  # passed as a JSON string
        max_tokens=500
    )

# Schema enforcement
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "scores": {"type": "array", "items": {"type": "number"}}
    },
    "required": ["name", "scores"]
}

state = json_generation.run(schema=schema)
result = json.loads(state["json"])
# Valid JSON matching the schema, unless generation is cut off by max_tokens
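Under the hood, constrained decoding is logit masking: at every step, tokens the constraint rejects get a score of negative infinity, so they can never be chosen. Engines such as SGLang compile the regex or JSON schema into a state machine that computes the allowed set; the sketch below hand-writes a tiny constraint (digits, then a newline stop, like the `age` field earlier) and a toy "model" over a four-token vocabulary. All names here are hypothetical.

```python
import math
from typing import Callable, Dict


def constrained_greedy(logits_fn: Callable[[str], Dict[str, float]],
                       allowed: Callable[[str, str], bool],
                       max_steps: int = 10, stop: str = "\n") -> str:
    """Greedy decoding where disallowed tokens are masked to -inf before argmax."""
    out = ""
    for _ in range(max_steps):
        scores = logits_fn(out)
        masked = {t: (s if allowed(out, t) else -math.inf) for t, s in scores.items()}
        tok = max(masked, key=masked.get)
        if masked[tok] == -math.inf or tok == stop:
            break  # no legal continuation, or the constraint allows stopping here
        out += tok
    return out


def digits_then_newline(prefix: str, tok: str) -> bool:
    """Hand-written constraint for the regex \\d+ followed by a newline stop."""
    if tok == "\n":
        return prefix != ""  # require at least one digit before stopping
    return tok.isdigit()


def toy_logits(prefix: str) -> Dict[str, float]:
    """Toy model that always prefers the letters 'abc'."""
    if prefix == "":
        return {"abc": 3.0, "3": 1.0, "\n": 0.5}
    if prefix == "3":
        return {"abc": 3.0, "5": 1.0, "\n": 0.5}
    return {"abc": 3.0, "\n": 1.0, "7": 0.2}


print(constrained_greedy(toy_logits, digits_then_newline))  # 35
```

The unconstrained model would emit "abc" every time; the mask forces the highest-scoring legal token instead, which is why the output is guaranteed to match the pattern.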

26.11.3 Fork and Join

SGLang supports parallel generation paths:

@sgl.function
def multi_perspective(s, question):
    s += f"Question: {question}\n\n"

    perspectives = ["optimist", "pessimist", "realist"]

    # Fork: each branch shares the prompt so far (and its cached KV)
    # and the branches run in parallel
    forks = s.fork(len(perspectives))
    for f, perspective in zip(forks, perspectives):
        f += f"Answer as a {perspective}:\n"
        f += sgl.gen("answer", max_tokens=100)

    # Join: read each branch's answer back into the main program
    s += "\nSynthesis of all perspectives:\n"
    for p, f in zip(perspectives, forks):
        s += f"- {p}: {f['answer']}\n"

    s += "\nFinal balanced answer:\n"
    s += sgl.gen("final", max_tokens=200)

# state = multi_perspective.run(question="..."); state["final"] holds the answer

26.11.4 RadixAttention in Detail

A simplified sketch of SGLang’s RadixAttention cache:

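import threading
from typing import Any, List, Optional, Tuple

KVCache = Any  # placeholder for one token's KV entry (e.g. an index into a KV pool)

class RadixNode:
    def __init__(self, parent=None, token=None):
        self.children = {}                       # token id -> RadixNode
        self.parent = parent
        self.token = token                       # token on the edge from parent
        self.kv_cache: Optional[KVCache] = None
        self.last_access = 0                     # logical clock for LRU

class SGLangRadixCache:
    """
    Simplified RadixAttention-style cache: one node per token, LRU leaf
    eviction. SGLang's actual cache compresses token runs into single
    edges and stores indices into a paged KV pool.
    """
    def __init__(self, max_total_tokens=100_000):
        self.root = RadixNode()
        self.max_tokens = max_total_tokens
        self.current_tokens = 0      # cached token positions (non-root nodes)
        self.clock = 0
        self.lock = threading.Lock()

    def _touch(self, node):
        self.clock += 1
        node.last_access = self.clock

    def match_prefix(self, input_ids: List[int]) -> Tuple[int, List[KVCache]]:
        """
        Find longest matching prefix in O(n) time.

        Returns:
            matched_length: Number of tokens matched
            kv_caches: KV cache for matched prefix (one entry per token)
        """
        with self.lock:
            node = self.root
            kv_caches = []

            for token in input_ids:
                child = node.children.get(token)
                if child is None:
                    break
                node = child
                self._touch(node)  # refresh LRU along the matched path
                kv_caches.append(node.kv_cache)

            return len(kv_caches), kv_caches

    def insert(self, input_ids: List[int], kv_cache: List[KVCache]):
        """Insert per-token KV entries; an already-cached prefix is stored once."""
        with self.lock:
            node = self.root
            for token, kv in zip(input_ids, kv_cache):
                child = node.children.get(token)
                if child is None:
                    child = RadixNode(parent=node, token=token)
                    child.kv_cache = kv
                    node.children[token] = child
                    self.current_tokens += 1  # count only new positions
                node = child
                self._touch(node)

            # Evict after inserting: the new path was touched last, so it goes last
            while self.current_tokens > self.max_tokens:
                self._evict_lru()

    def _evict_lru(self):
        """Evict the least recently used leaf (frees one token position)."""
        lru_node = self._find_lru_leaf()
        if lru_node is None:
            return
        del lru_node.parent.children[lru_node.token]
        self.current_tokens -= 1

    def _find_lru_leaf(self):
        # Linear scan for clarity; a real cache keeps evictable leaves in a heap
        best, stack = None, list(self.root.children.values())
        while stack:
            node = stack.pop()
            if node.children:
                stack.extend(node.children.values())
            elif best is None or node.last_access < best.last_access:
                best = node
        return best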

26.11.5 Performance Benefits

Benchmark: 1000 requests with shared system prompt

Without RadixAttention (vLLM PagedAttention, prefix caching disabled):
  - Prefill per request: 1000 tokens
  - Total prefill: 1,000,000 tokens
  - Time: 45 seconds

With RadixAttention (SGLang):
  - First request prefill: 1000 tokens
  - Subsequent prefills: 50 tokens (unique part only)
  - Total prefill: 1000 + 999*50 = 50,950 tokens
  - Time: 4.2 seconds

Speedup: 10.7x

26.12 Key Takeaways

  1. Prefix caching: Amortize common prefixes across requests; when a shared prefix dominates the prompt, prefill work and TTFT can drop by 10x or more.

  2. Chunked prefill: Interleave prefill and decode to avoid blocking.

  3. Disaggregation: Separate prefill (compute) from decode (memory) for cost efficiency.

  4. Advanced speculation: Medusa, EAGLE, lookahead improve on basic speculation.

  5. Priority scheduling: Not all requests are equal; schedule accordingly.

  6. Multi-model: Efficient LoRA serving and model multiplexing.

  7. vLLM internals: Understand PagedAttention block allocation and scheduler for tuning.

  8. SGLang: Structured generation with RadixAttention for maximum prefix reuse.

  9. Observability: Track TTFT, TPOT, throughput, cache hits for optimization.

Note: Try It Yourself

The accompanying notebook lets you:

  • Implement prefix caching with a radix tree
  • Simulate chunked prefill scheduling
  • Compare speculative decoding strategies
  • Build a simple multi-model router

Notebook support for this chapter is in progress. For now, implement the patterns locally and evaluate latency/throughput on your serving stack.

26.13 Further Reading

  • Zheng et al. (2023). “SGLang: Efficient Execution of Structured Language Model Programs”
  • Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention”
  • vLLM Documentation
  • SGLang Documentation
  • Agrawal et al. (2024). “Sarathi: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills”
  • Cai et al. (2024). “Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads”
  • Li et al. (2024). “EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty”
  • Patel et al. (2024). “Splitwise: Efficient Generative LLM Inference Using Phase Splitting”
[1] W. Kwon et al., “Efficient memory management for large language model serving with PagedAttention,” arXiv preprint arXiv:2309.06180, 2023.