"""
Continuous batching = iteration-level scheduling + ragged (packed) batching.

Two approaches are compared (each runs BATCH_SIZE sequences concurrently, so the
comparison is slot-for-slot fair):

1. Static batching (baseline):
   Prompts are processed BATCH_SIZE at a time. Every wave is padded to a
   common length and run together until the LONGEST request in that wave
   finishes; a hard "batch barrier" then has to clear before the next wave
   starts. Short requests sit idle behind the barrier.

2. Continuous batching (production-style):
   Two ideas combine to keep the GPU busy.

   (a) Iteration-level scheduling: the moment a sequence finishes it frees
       its slot, and the next queued prompt is admitted on the SAME step --
       no waiting for the rest of the batch.

   (b) Ragged / packed batching -- the part that makes it truly "continuous":
       instead of padding every sequence into a rectangular [B, max_len]
       tensor, ALL in-flight tokens are concatenated into a single unpadded
       [1, total_tokens] row and run in ONE forward pass. A block-diagonal
       causal attention mask stops tokens from attending across sequence
       boundaries, so packing is mathematically identical to running each
       sequence on its own (verified: greedy output matches per-prompt
       generation token-for-token).

       Because attention is governed only by the mask, a newly admitted
       prompt's multi-token PREFILL rides along in the same forward pass as
       every other sequence's single-token DECODE step. Prefill and decode are
       fused: no padding, no separate prefill pass.

KV cache: each sequence keeps its own DynamicCache; every step the caches
are concatenated along the time axis into one packed cache, and the newly
computed KV is scattered back per sequence. (Real engines store the
cache in fixed-size pages -- "paged attention" -- to avoid this per-step
reassembly, but the attention/masking logic is exactly what you see here.)
"""
import time
from dataclasses import dataclass, field
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from transformers.cache_utils import DynamicLayer
MODEL_ID = "openai-community/gpt2"  # swap for any causal LM
BATCH_SIZE = 3  # max concurrent sequences (slots) for BOTH batching strategies
def _device_sync(mannequin) -> None:
“”“Block till queued GPU work finishes, so timings are correct.”“”
if mannequin.system.kind == “cuda”:
torch.cuda.synchronize()
elif mannequin.system.kind == “mps”:
torch.mps.synchronize()
def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Baseline: process requests BATCH_SIZE at a time behind a batch barrier.

    Each wave is left-padded to a common length and decoded together until its
    LONGEST request finishes (``max_new_tokens`` is the wave's largest cap);
    only then does the next wave start.

    Downside: short requests in a wave idle until the wave's longest is done,
    and no slot can be refilled until the whole wave clears the barrier.

    Args:
        requests: ``(prompt, max_new_tokens)`` pairs, in request order.
        tokenizer: HF tokenizer; its ``padding_side`` is set to ``"left"``.
        model: causal LM exposing ``generate`` and ``device``.

    Returns:
        One string per request (prompt + generated text), in request order.
    """
    if not requests:
        return []

    # Left padding keeps every row's generated tokens starting at the same column.
    tokenizer.padding_side = "left"
    results: dict[int, str] = {}
    indexed = list(enumerate(requests))  # (req_id, (prompt, cap))

    for wave_start in range(0, len(indexed), BATCH_SIZE):
        wave = indexed[wave_start: wave_start + BATCH_SIZE]
        wave_max = max(cap for _, (_, cap) in wave)

        # Show which request occupies each slot in this wave.
        for slot, (req_id, (prompt, cap)) in enumerate(wave):
            print(f"  ++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}", flush=True)

        prompts = [p for _, (p, _) in wave]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=wave_max,  # whole wave decodes to the longest
                pad_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        # Generated tokens begin right after the (padded) prompt width.
        width = inputs.input_ids.shape[1]
        print(
            f"  *** batch barrier: all {len(wave)} slots wait for the longest "
            f"({wave_max} tokens) ***",
            flush=True,
        )
        for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
            # Keep only this request's own cap; anything past it was wasted work.
            text = prompt + tokenizer.decode(row[width:width + cap], skip_special_tokens=True)
            results[req_id] = text
            print(
                f"  -- slot {slot} done req {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
                flush=True,
            )

    return [results[k] for k in sorted(results)]
@dataclass
class Sequence:
    """State for a single in-flight sequence in the continuous-batching loop."""

    req_id: int  # original request index (used to order the results)
    prompt: str
    max_new_tokens: int  # per-request cap so short requests finish early
    # Tokens to feed on the NEXT step: the whole prompt right after admission
    # (prefill), then a single token per step (decode).
    pending_ids: list[int]
    # Per-sequence KV cache; None until this sequence has run once.
    kv_cache: Optional["DynamicCache"] = None
    kv_len: int = 0  # number of cached tokens (prompt + generated)
    tokens_generated: int = 0
    # default_factory gives every instance its own list (no shared mutable default).
    output_ids: list[int] = field(default_factory=list)
def _make_cache(layers_kv: list[tuple[torch.Tensor, torch.Tensor]]) -> DynamicCache:
    """Build a DynamicCache from explicit per-layer (keys, values) tensors.

    We SET the tensors directly instead of calling DynamicLayer.update() (which
    would append), because the caches are assembled from scratch every step.

    Args:
        layers_kv: one ``(keys, values)`` pair per model layer, in layer order.

    Returns:
        A new DynamicCache whose layer ``i`` holds ``layers_kv[i]``.

    NOTE(review): this relies on DynamicLayer internals (``lazy_initialization``,
    ``.keys`` / ``.values``), which may change between transformers releases;
    confirm against the installed version.
    """
    cache = DynamicCache()
    for k, v in layers_kv:
        layer = DynamicLayer()
        # Presumably sets up the layer's internal state; the tensors are then
        # overwritten with the ones we want.
        layer.lazy_initialization(k, v)
        layer.keys = k
        layer.values = v
        cache.layers.append(layer)
    return cache
def _ragged_step(seqs: "list[Sequence]", model, device, dtype) -> list[int]:
    """Run ONE packed forward pass over every active sequence.

    All sequences are flattened into a single row (batch dim = 1):
        input_ids        [1, total_q]                         every sequence's pending tokens
        position_ids     [1, total_q]                         each token's position in ITS sequence
        attention_mask   [1, 1, total_q, total_kv + total_q]  block-diagonal causal
        past_key_values  packed cache, per layer [1, H, total_kv, D]

    total_q  = sum of pending tokens (1 per decoding seq, prompt_len per new seq)
    total_kv = sum of already-cached tokens across sequences

    Side effects: each sequence's ``kv_cache`` and ``kv_len`` are extended with
    the keys/values computed for its pending tokens in this pass.

    Returns the next greedy token for each sequence (same order as ``seqs``).
    """
    q_lens = [len(s.pending_ids) for s in seqs]
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    # Packed inputs: concatenate every sequence's pending tokens into one row.
    flat_ids = [t for s in seqs for t in s.pending_ids]
    input_ids = torch.tensor([flat_ids], dtype=torch.long, device=device)

    # Tag every KEY and every QUERY token with (sequence index, position-in-sequence).
    # Key space is laid out as [ cached tokens | this step's new tokens ], matching
    # how the model appends new KV to the end of the packed cache.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached block
        key_seq += [si] * s.kv_len
        key_pos += range(s.kv_len)
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new block (these are also the queries)
        q_seq += [si] * len(s.pending_ids)
        q_pos += range(s.kv_len, s.kv_len + len(s.pending_ids))
    key_seq += q_seq
    key_pos += q_pos

    q_seq_t = torch.tensor(q_seq, device=device)
    q_pos_t = torch.tensor(q_pos, device=device)
    key_seq_t = torch.tensor(key_seq, device=device)
    key_pos_t = torch.tensor(key_pos, device=device)

    # Each token's positional embedding uses its own sequence position, not its
    # offset in the packed row.
    position_ids = q_pos_t.unsqueeze(0)  # [1, total_q]

    # Block-diagonal causal mask: a query may attend to a key only if they belong
    # to the SAME sequence (block-diagonal) and the key is not in the future
    # (causal). This is the whole trick -- it makes packing equivalent to running
    # each sequence separately. 0.0 = attend, large-negative = blocked (additive).
    same = q_seq_t[:, None] == key_seq_t[None, :]
    causal = key_pos_t[None, :] <= q_pos_t[:, None]
    allowed = same & causal  # [total_q, total_kv + total_q]
    attn_mask = torch.zeros(1, 1, total_q, total_kv + total_q, dtype=dtype, device=device)
    attn_mask.masked_fill_(~allowed[None, None], torch.finfo(dtype).min)

    # Packed KV cache: concatenate every sequence's cache along the time axis.
    # Freshly admitted sequences (kv_len == 0) contribute nothing here, which
    # matches the cached block of the key layout above.
    cached = [s for s in seqs if s.kv_len > 0]
    if cached:
        num_layers = len(cached[0].kv_cache.layers)
        past = _make_cache([
            (
                torch.cat([s.kv_cache.layers[li].keys for s in cached], dim=2),
                torch.cat([s.kv_cache.layers[li].values for s in cached], dim=2),
            )
            for li in range(num_layers)
        ])
    else:
        past = DynamicCache()

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )

    # Greedy next token for each sequence: read the logits at its LAST pending
    # token (for a prefilling sequence that is the final prompt token).
    logits = out.logits[0]  # [total_q, vocab]
    offsets, last_idx, off = [], [], 0
    for ql in q_lens:
        offsets.append(off)
        last_idx.append(off + ql - 1)
        off += ql
    # One batched argmax + one device->host transfer instead of one per sequence.
    next_tokens = logits[last_idx].argmax(dim=-1).tolist()

    # Scatter the newly computed KV back to each sequence. The output cache is
    # [ old packed block | new packed block ]; slice this step's new block per
    # sequence and append it to that sequence's own cache.
    out_kv = out.past_key_values
    num_layers = len(out_kv.layers)
    for si, s in enumerate(seqs):
        start = total_kv + offsets[si]
        stop = start + q_lens[si]
        layers_kv = []
        for li in range(num_layers):
            k_new = out_kv.layers[li].keys[:, :, start:stop, :]
            v_new = out_kv.layers[li].values[:, :, start:stop, :]
            if s.kv_cache is None:
                layers_kv.append((k_new, v_new))
            else:
                layers_kv.append((
                    torch.cat([s.kv_cache.layers[li].keys, k_new], dim=2),
                    torch.cat([s.kv_cache.layers[li].values, v_new], dim=2),
                ))
        s.kv_cache = _make_cache(layers_kv)
        s.kv_len += q_lens[si]
    return next_tokens
def visualize_ragged_step(seqs: "list[Sequence]", tokenizer, title: str, slot_ids: list[int]) -> None:
    """Print ONE packed step: the concatenated input row and the block-diagonal
    causal attention mask.

    This mirrors the masking logic in ``_ragged_step`` (recomputed here as a
    boolean grid purely for display) so you can SEE that sequences are packed
    together yet isolated by the mask. Each sequence gets a letter A, B, C, ...
    derived from its ``req_id``.

        #  = query may attend to that key      .  = blocked

    Args:
        seqs: active sequences (reads ``req_id``, ``pending_ids``, ``kv_len``).
        tokenizer: used only to ``decode`` pending tokens for display.
        title: heading printed above the diagram.
        slot_ids: slot index for each entry of ``seqs``.
    """
    labels = [chr(ord("A") + s.req_id) for s in seqs]
    q_lens = [len(s.pending_ids) for s in seqs]
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)
    print(f"\n{'=' * 72}\n  {title}")
    print(f"  total_q={total_q} tokens fed this step | total_kv={total_kv} cached")
    print(f"  {len(seqs)} sequences packed into ONE unpadded row of shape [1, {total_q}]:\n")

    # The concatenated tokens, grouped per sequence (this is the "ragged" row).
    for i, s in enumerate(seqs):
        kind = f"PREFILL({q_lens[i]})" if s.kv_len == 0 else f"decode({q_lens[i]})"
        toks = " ".join(repr(tokenizer.decode([t])) for t in s.pending_ids)
        if len(toks) > 66:
            toks = toks[:63] + "..."
        print(f"    {labels[i]} = slot {slot_ids[i]}  {kind:<11} {toks}")

    # Rebuild the block-diagonal causal mask as a boolean grid for display.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached keys
        key_seq += [si] * s.kv_len
        key_pos += list(range(s.kv_len))
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new keys / queries
        for j in range(q_lens[si]):
            q_seq.append(si)
            q_pos.append(s.kv_len + j)
    key_seq += q_seq
    key_pos += q_pos
    q_seq_t, q_pos_t = torch.tensor(q_seq), torch.tensor(q_pos)
    key_seq_t, key_pos_t = torch.tensor(key_seq), torch.tensor(key_pos)
    allowed = (q_seq_t[:, None] == key_seq_t[None, :]) & (key_pos_t[None, :] <= q_pos_t[:, None])
    # Convert once to nested Python lists instead of indexing the tensor per cell.
    grid = allowed.tolist()
    K = len(key_seq)

    def row_str(cells):
        # Space between sequence groups; " | " at the cached -> new-tokens split.
        out = []
        for ki in range(K):
            if total_kv > 0 and ki == total_kv:
                out.append(" | ")
            elif ki > 0 and key_seq[ki] != key_seq[ki - 1]:
                out.append(" ")
            out.append(cells[ki])
        return "".join(out)

    def line(left, cells):
        return f"{left:>7} " + row_str(cells)

    print("\n  block-diagonal causal mask (row = query, col = key)   # attend  . blocked")
    if total_kv > 0:
        print("  key layout: [ cached KV | this step's new tokens ]")
    print(line("keys:", [labels[key_seq[ki]] for ki in range(K)]))
    for qi in range(total_q):
        cells = ["#" if grid[qi][ki] else "." for ki in range(K)]
        print(line(f"{labels[q_seq[qi]]} p{q_pos[qi]}", cells))
def continuous_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Ragged continuous batching: iteration-level scheduling + packed prefill/decode.

    Scheduling policy:
      - Up to BATCH_SIZE sequences run concurrently.
      - A newly admitted sequence is queued with its full prompt as the next
        tokens to feed; its prefill then happens packed into the next step
        alongside everyone else's decode.
      - Every step runs ONE packed forward pass across all active slots.
      - When a sequence finishes it is immediately replaced by the next prompt.

    The admission log shows slots being reused (iteration-level scheduling).
    Two representative steps are visualized: the first step (all prompts being
    prefilled at once) and the first step that fuses a new prompt's prefill with
    other sequences' decode tokens.

    Args:
        requests: ``(prompt, max_new_tokens)`` pairs, in request order.
        tokenizer: HF tokenizer used to encode prompts and decode outputs.
        model: causal LM exposing ``device`` and ``parameters()``.

    Returns:
        One string per request (prompt + generated text), in request order.

    Raises:
        ValueError: if a prompt tokenizes to zero tokens (nothing to prefill).
    """
    device = model.device
    dtype = next(model.parameters()).dtype
    queue = list(enumerate(requests))  # (req_id, (prompt, max_new_tokens))
    slots: list[Optional[Sequence]] = [None] * BATCH_SIZE
    results: dict[int, str] = {}

    def _admit(slot_idx: int) -> None:
        """Move the next queued request into ``slot_idx``, or empty the slot."""
        if not queue:
            slots[slot_idx] = None
            return
        req_id, (prompt, max_new_tokens) = queue.pop(0)
        prompt_ids = tokenizer(prompt)["input_ids"]
        if not prompt_ids:
            # With zero pending tokens, _ragged_step's "last token" index for
            # this sequence would point at a neighbour's logits.
            raise ValueError(f"request {req_id} has an empty prompt; nothing to prefill")
        slots[slot_idx] = Sequence(
            req_id,
            prompt,
            max_new_tokens=max_new_tokens,
            pending_ids=list(prompt_ids),  # prefill rides along on the next step
        )
        print(
            f"  ++ [step {step:3d}] slot {slot_idx} <- admit req {req_id} "
            f"({max_new_tokens} tok cap): {prompt!r}",
            flush=True,
        )

    # Fill the pool with the first batch of prompts (step 0 = before any decode).
    step = 0
    for i in range(BATCH_SIZE):
        _admit(i)

    printed_mixed = False
    while any(s is not None for s in slots):
        step += 1
        active = [(i, s) for i, s in enumerate(slots) if s is not None]
        seqs = [s for _, s in active]
        slot_ids = [i for i, _ in active]

        # Visualize a couple of representative steps so the packing is visible
        # (printing every step would be far too much output).
        mixed = any(s.kv_len == 0 for s in seqs) and any(s.kv_len > 0 for s in seqs)
        if step == 1:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} -- prompts packed together (all PREFILL)", slot_ids)
        elif mixed and not printed_mixed:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} -- PREFILL + DECODE fused in one pass", slot_ids)
            printed_mixed = True

        # ONE packed forward pass (prefill + decode fused, no padding).
        next_tokens = _ragged_step(seqs, model, device, dtype)

        for (slot_idx, seq), tok in zip(active, next_tokens):
            seq.output_ids.append(tok)
            seq.tokens_generated += 1
            seq.pending_ids = [tok]  # next step: a single decode token
            if tok == tokenizer.eos_token_id or seq.tokens_generated >= seq.max_new_tokens:
                # req_id is the index into `requests`, so this recovers the prompt text.
                prompt = requests[seq.req_id][0]
                result_text = prompt + tokenizer.decode(seq.output_ids, skip_special_tokens=True)
                results[seq.req_id] = result_text
                print(
                    f"  -- [step {step:3d}] slot {slot_idx} done req {seq.req_id} "
                    f"({seq.tokens_generated}/{seq.max_new_tokens} tokens): {result_text[:90]}",
                    flush=True,
                )
                # Iteration-level scheduling: refill the freed slot right away.
                _admit(slot_idx)

    return [results[k] for k in sorted(results)]
def main():
    """Load the model, then time static vs. continuous batching on the same requests."""
    print(f"Loading {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Static batching pads prompts, so a pad token is required; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

    # Pick the fastest available device. On Apple Silicon (M1/M2/...) this is
    # the MPS GPU. We keep float32 on MPS on purpose: float16 there flips a few
    # greedy ties, which would break the "static == continuous, token-for-token"
    # property this demo relies on.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    elif torch.backends.mps.is_available():
        device, dtype = "mps", torch.float32
    else:
        device, dtype = "cpu", torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=dtype,
        attn_implementation="eager",  # use our custom 4D mask directly
    )
    model.eval()
    model.to(device)
    print(f"Running on {device} ({dtype})\n")

    # Untimed warm-up so one-off start-up costs (e.g. first kernel launches,
    # allocator growth) are not charged to whichever method runs first.
    with torch.no_grad():
        model(**tokenizer("Hello", return_tensors="pt").to(model.device))
    _device_sync(model)

    requests = [
        ("The capital of France is", 6),
        ("Today's weather is so", 50),
        ("In machine learning, a transformer is", 300),
        ("Once upon a time in a land far away,", 30),
        ("Quantum computing differs from classical computing because", 180),
        ("The history of the Roman Empire began", 45),
    ]

    print("=== Static batching ===")
    _device_sync(model)
    start = time.perf_counter()
    static_out = static_batching(requests, tokenizer, model)
    _device_sync(model)
    static_elapsed = time.perf_counter() - start
    print(f"\nStatic batching elapsed: {static_elapsed:.2f}s\n")

    print("=== Continuous batching (ragged) ===")
    _device_sync(model)
    start = time.perf_counter()
    continuous_out = continuous_batching(requests, tokenizer, model)
    _device_sync(model)
    continuous_elapsed = time.perf_counter() - start
    print(f"\nContinuous batching elapsed: {continuous_elapsed:.2f}s")

    # Packing should change scheduling, not results: actually check that claim.
    mismatched = [i for i, (a, b) in enumerate(zip(static_out, continuous_out)) if a != b]
    if mismatched:
        print(f"Outputs differ for requests {mismatched} (e.g. low-precision greedy ties)")
    else:
        print("Outputs match between static and continuous batching.")
    if continuous_elapsed > 0:
        print(f"Speedup: {static_elapsed / continuous_elapsed:.2f}x")


if __name__ == "__main__":
    main()

