Skip to content

vllm.v1.attention.backends.flex_attention

Attention layer with FlexAttention.

BlockSparsityHint

Bases: NamedTuple

This prunes KV blocks from the BlockMask before the flex_attention kernel is invoked, so that blocks that are fully masked never get loaded. Use this with custom mask_mods that are sparse to avoid the kernel iterating over all KV blocks unnecessarily.

Attributes:

Name Type Description
hint_fn _block_sparsity_hint_signature

(q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks], block_size int) -> bool Tensor [num_tokens, num_kv_blocks]. Returns True for block pairs that may contain non-masked elements.

Source code in vllm/v1/attention/backends/flex_attention.py
class BlockSparsityHint(NamedTuple):
    """Block-level sparsity hint that lets custom mask_mods skip KV blocks.

    When a layer carries a hint, the direct BlockMask construction asks it which
    (query block, KV block) pairs can be relevant. It then removes the other KV
    blocks before flex_attention runs, so blocks that are fully masked are never
    loaded. This pays off for sparse custom mask_mods, where the kernel would
    otherwise walk every KV block.

    Attributes:
        hint_fn: invoked as ``hint_fn(q_block_idx, kv_block_idx, block_size)``.
            ``q_block_idx`` has shape [num_tokens, 1] (logical query position
            // block_size), ``kv_block_idx`` has shape [1, num_kv_blocks] and
            ``block_size`` is an int. It must return a bool tensor of shape
            [num_tokens, num_kv_blocks] that is True for every pair that may
            contain non-masked elements. The hint must be conservative.
    """

    # Broadcasting block-pair predicate; contract described in the docstring.
    hint_fn: _block_sparsity_hint_signature

FlexAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/flex_attention.py
class FlexAttentionBackend(AttentionBackend):
    """Attention backend built on PyTorch FlexAttention.

    Accepts fp16/bf16/fp32 activations and unquantized ("auto", fp16, bf16)
    KV caches, and hands out FlexAttentionImpl / FlexAttentionMetadataBuilder.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "float16",
        "bfloat16",
    ]

    # forward() only reads the KV cache; the cache write happens separately
    # (FlexAttentionImpl.do_kv_cache_update).
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        """Identifier under which this backend is registered."""
        name = "FLEX_ATTENTION"
        return name

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """FlexAttention supports both decoder and encoder-only attention."""
        accepted = (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
        return attn_type in accepted

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """FlexAttention supports full attention for image tokens."""
        return True

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        """Attention implementation class paired with this backend."""
        impl_cls = FlexAttentionImpl
        return impl_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Shape of the combined K/V cache tensor.

        The leading dimension of size 2 separates the key cache (index 0)
        from the value cache (index 1). ``cache_dtype_str`` does not influence
        the shape.
        """
        kv_planes = 2  # key cache + value cache
        return (kv_planes, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
        """Metadata builder class paired with this backend."""
        builder_cls = FlexAttentionMetadataBuilder
        return builder_cls

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used with this backend."""
        return False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return an empty list.

        NOTE(review): presumably interpreted by callers as "no head-size
        restriction" -- confirm against AttentionBackend.
        """
        return []

supports_attn_type classmethod

supports_attn_type(attn_type: str) -> bool

FlexAttention supports both decoder and encoder-only attention.

Source code in vllm/v1/attention/backends/flex_attention.py
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """FlexAttention supports both decoder and encoder-only attention."""
    accepted = (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
    return attn_type in accepted

supports_mm_prefix classmethod

supports_mm_prefix() -> bool

FlexAttention supports full attention for image tokens.

Source code in vllm/v1/attention/backends/flex_attention.py
@classmethod
def supports_mm_prefix(cls) -> bool:
    """FlexAttention supports full attention for image tokens.

    Always returns True; the answer does not depend on the class.
    """
    supported = True
    return supported

FlexAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/flex_attention.py
class FlexAttentionImpl(AttentionImpl):
    """AttentionImpl that evaluates attention with a compiled flex_attention call.

    Decoder (causal) layers attend over the paged KV cache. Encoder-only
    (non-causal) layers attend over the key/value tensors passed to forward().
    The constructor rejects ALiBi slopes, logits soft-capping, KV sharing and
    quantized KV caches.
    """

    sliding_window: int | None
    alibi_slopes: torch.Tensor | None
    logits_soft_cap: float | None
    # Last mm_prefix_range seen by forward(). forward() compares it with each
    # incoming metadata to decide whether mask_mod / block mask need a rebuild.
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
    # NOTE(review): the two attributes below are not read by this class;
    # forward() consults the same-named attributes on `layer` instead.
    logical_mask_mod: _mask_mod_signature | None = None
    block_sparsity_hint: BlockSparsityHint | None = None

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        **kwargs,
    ) -> None:
        """Store the layer configuration and reject unsupported features.

        Raises:
            NotImplementedError: for attention types other than DECODER /
                ENCODER_ONLY, non-None ``alibi_slopes``, non-None
                ``logits_soft_cap``, KV sharing, or a quantized KV-cache dtype.
            AssertionError: if ``num_heads`` is not a multiple of
                ``num_kv_heads``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.attn_type = attn_type

        if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER):
            raise NotImplementedError(
                f"FlexAttention does not support {attn_type} attention"
            )

        if alibi_slopes is not None:
            raise NotImplementedError(
                "FlexAttention does not support alibi slopes yet."
            )
        else:
            self.alibi_slopes = None

        self.sliding_window = sliding_window

        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        if self.logits_soft_cap is not None:
            raise NotImplementedError(
                "FlexAttention does not support logits soft cap yet."
            )

        # Grouped-query attention: each KV head serves num_queries_per_kv
        # query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("FlexAttention does not support kv sharing yet.")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlexAttention does not support quantized kv-cache. Yet"
            )

    @staticmethod
    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
        """View a 3d tensor as 4D.

        A 4D input is returned unchanged. A 3D input gets a leading
        size-1 dimension. Any other rank fails the assertion.
        """
        if tensor.ndim == 4:
            return tensor
        assert tensor.ndim == 3
        return tensor[None, :, :, :]

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Write this step's key/value into the paged KV cache.

        Encoder-only layers never read the cache (forward() uses key/value
        directly), so they return without writing.
        """
        if self.attn_type == AttentionType.ENCODER_ONLY:
            return

        # kv_cache packs the key cache (index 0) and value cache (index 1).
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlexAttention.

        Args:
            layer: the attention layer module. Its optional
                ``logical_mask_mod`` and ``block_sparsity_hint`` attributes
                are pushed onto ``attn_metadata``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]; only read for
                encoder-only (non-causal) attention.
            value: shape = [num_tokens, num_kv_heads, head_size]; only read
                for encoder-only (non-causal) attention.
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]; only
                read for decoder (causal) attention.
            attn_metadata: Metadata for attention, or None for a profiling
                run (``output`` is then zero-filled and returned).
            output: preallocated result buffer; required.
            output_scale, output_block_scale: must be None.
        Returns:
            ``output``, with its first ``num_actual_tokens`` rows overwritten;
            shape = [num_tokens, num_heads, head_size]
        Raises:
            NotImplementedError: if fused output quantization is requested.
        """
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlexAttentionImpl"
            )

        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)
            # query = self.view_as_4d(query).permute(0, 2, 1, 3)
            # return torch.empty_like(query)

        num_actual_tokens = attn_metadata.num_actual_tokens

        # attn_metadata is presumably shared between layers. Each layer pushes
        # its own sliding window, multimodal prefix ranges, mask_mod and
        # sparsity hint onto it, and requests a block-mask rebuild on change.
        needs_rebuild_block_mask = False
        if attn_metadata.sliding_window != self.sliding_window:
            attn_metadata.sliding_window = self.sliding_window
            if attn_metadata.direct_build:
                # update mask mod in attention metadata
                # (the non-direct build_block_mask() recomputes it itself)
                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            needs_rebuild_block_mask = True

        if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
            self.mm_prefix_range = attn_metadata.mm_prefix_range
            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            needs_rebuild_block_mask = True

        # NOTE(review): if this layer has no logical_mask_mod (or, below, no
        # block_sparsity_hint), a value left on attn_metadata by an earlier
        # layer stays in effect -- confirm that is intended.
        layer_mask_mod = getattr(layer, "logical_mask_mod", None)
        if (
            layer_mask_mod is not None
            and attn_metadata.logical_mask_mod is not layer_mask_mod
        ):
            attn_metadata.logical_mask_mod = layer_mask_mod
            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            needs_rebuild_block_mask = True

        layer_hint = getattr(layer, "block_sparsity_hint", None)
        if (
            layer_hint is not None
            and attn_metadata.block_sparsity_hint is not layer_hint
        ):
            attn_metadata.block_sparsity_hint = layer_hint
            needs_rebuild_block_mask = True

        if needs_rebuild_block_mask or attn_metadata.block_mask is None:
            if attn_metadata.direct_build:
                attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
            else:
                attn_metadata.block_mask = attn_metadata.build_block_mask()

        if not attn_metadata.causal:
            # Encoder-only: attend over this step's key/value (no KV cache).
            # [T, H, D] -> [1, H, T, D], the layout flex_attention takes.
            assert self.attn_type == AttentionType.ENCODER_ONLY

            query, key_tensor, value_tensor = map(
                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
                (query, key, value),
            )

            query = query[:, :, :num_actual_tokens, :]
            if (key_tensor.size(-2) > num_actual_tokens) or (
                value_tensor.size(-2) > num_actual_tokens
            ):
                # In the encoder-only model with torch.compile,
                # qkv might be padded, which might cause exception.
                # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
                key_tensor = key_tensor[:, :, :num_actual_tokens, :]
                value_tensor = value_tensor[:, :, :num_actual_tokens, :]

        else:
            # Decoder: attend over the entire paged cache, flattened to
            # [num_blocks * block_size, num_kv_heads, head_size]. The block
            # mask / mask_mod map physical cache slots to each request's
            # logical positions.
            assert self.attn_type == AttentionType.DECODER
            key_cache, value_cache = kv_cache.unbind(0)

            # View out the block_size dim
            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
            value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
            query, key_tensor, value_tensor = map(
                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
                (query, key_cache, value_cache),
            )

            query = query[:, :, :num_actual_tokens, :]

        # Doesn't work for now -> constraint violation
        # torch._dynamo.try_mark_dynamic(query, 2)

        assert attn_metadata.block_mask is not None
        block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE

        kernel_options = get_kernel_options(
            query, block_m, block_n, attn_metadata.direct_build
        )
        out = flex_attention_compiled(
            query,
            key_tensor,
            value_tensor,
            attn_metadata.transformed_score_mod,
            attn_metadata.block_mask,
            self.scale,
            enable_gqa=enable_gqa,
            kernel_options=kernel_options,
        )

        # Flex doesn't have an out variant today, rely on epilogue fusion
        # [1, H, T, D] -> [T, H, D]; rows past num_actual_tokens are untouched.
        out = out.permute(0, 2, 1, 3).squeeze(0)
        output[:num_actual_tokens, :, :].copy_(out)
        return output

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlexAttentionMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor

Forward pass with FlexAttention.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads, head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
kv_cache Tensor

shape = [2, num_blocks, block_size, num_kv_heads, head_size]

required
attn_metadata FlexAttentionMetadata

Metadata for attention.

required

Returns: the `output` buffer, with its first `num_actual_tokens` rows overwritten; shape = [num_tokens, num_heads, head_size]

Source code in vllm/v1/attention/backends/flex_attention.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlexAttentionMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass with FlexAttention.

    Args:
        layer: the attention layer module. Its optional ``logical_mask_mod``
            and ``block_sparsity_hint`` attributes are pushed onto
            ``attn_metadata``.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]; only read for
            encoder-only (non-causal) attention.
        value: shape = [num_tokens, num_kv_heads, head_size]; only read for
            encoder-only (non-causal) attention.
        kv_cache: shape =
            [2, num_blocks, block_size, num_kv_heads, head_size]; only read
            for decoder (causal) attention.
        attn_metadata: Metadata for attention, or None for a profiling run
            (``output`` is then zero-filled and returned).
        output: preallocated result buffer; required.
        output_scale, output_block_scale: must be None.
    Returns:
        ``output``, with its first ``num_actual_tokens`` rows overwritten;
        shape = [num_tokens, num_heads, head_size]
    Raises:
        NotImplementedError: if fused output quantization is requested.
    """
    assert output is not None, "Output tensor must be provided."
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlexAttentionImpl"
        )

    enable_gqa = self.num_kv_heads != self.num_heads

    if attn_metadata is None:
        # Profiling run.
        return output.fill_(0)
        # query = self.view_as_4d(query).permute(0, 2, 1, 3)
        # return torch.empty_like(query)

    num_actual_tokens = attn_metadata.num_actual_tokens

    # attn_metadata is presumably shared between layers. Each layer pushes its
    # own sliding window, multimodal prefix ranges, mask_mod and sparsity hint
    # onto it, and requests a block-mask rebuild on change.
    needs_rebuild_block_mask = False
    if attn_metadata.sliding_window != self.sliding_window:
        attn_metadata.sliding_window = self.sliding_window
        if attn_metadata.direct_build:
            # update mask mod in attention metadata
            # (the non-direct build_block_mask() recomputes it itself)
            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
        needs_rebuild_block_mask = True

    if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
        self.mm_prefix_range = attn_metadata.mm_prefix_range
        attn_metadata.mask_mod = attn_metadata.get_mask_mod()
        needs_rebuild_block_mask = True

    # NOTE(review): if this layer has no logical_mask_mod (or, below, no
    # block_sparsity_hint), a value left on attn_metadata by an earlier layer
    # stays in effect -- confirm that is intended.
    layer_mask_mod = getattr(layer, "logical_mask_mod", None)
    if (
        layer_mask_mod is not None
        and attn_metadata.logical_mask_mod is not layer_mask_mod
    ):
        attn_metadata.logical_mask_mod = layer_mask_mod
        attn_metadata.mask_mod = attn_metadata.get_mask_mod()
        needs_rebuild_block_mask = True

    layer_hint = getattr(layer, "block_sparsity_hint", None)
    if (
        layer_hint is not None
        and attn_metadata.block_sparsity_hint is not layer_hint
    ):
        attn_metadata.block_sparsity_hint = layer_hint
        needs_rebuild_block_mask = True

    if needs_rebuild_block_mask or attn_metadata.block_mask is None:
        if attn_metadata.direct_build:
            attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
        else:
            attn_metadata.block_mask = attn_metadata.build_block_mask()

    if not attn_metadata.causal:
        # Encoder-only: attend over this step's key/value (no KV cache).
        # [T, H, D] -> [1, H, T, D], the layout flex_attention takes.
        assert self.attn_type == AttentionType.ENCODER_ONLY

        query, key_tensor, value_tensor = map(
            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
            (query, key, value),
        )

        query = query[:, :, :num_actual_tokens, :]
        if (key_tensor.size(-2) > num_actual_tokens) or (
            value_tensor.size(-2) > num_actual_tokens
        ):
            # In the encoder-only model with torch.compile,
            # qkv might be padded, which might cause exception.
            # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
            key_tensor = key_tensor[:, :, :num_actual_tokens, :]
            value_tensor = value_tensor[:, :, :num_actual_tokens, :]

    else:
        # Decoder: attend over the entire paged cache, flattened to
        # [num_blocks * block_size, num_kv_heads, head_size]. The block mask /
        # mask_mod map physical cache slots to each request's logical positions.
        assert self.attn_type == AttentionType.DECODER
        key_cache, value_cache = kv_cache.unbind(0)

        # View out the block_size dim
        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
        query, key_tensor, value_tensor = map(
            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
            (query, key_cache, value_cache),
        )

        query = query[:, :, :num_actual_tokens, :]

    # Doesn't work for now -> constraint violation
    # torch._dynamo.try_mark_dynamic(query, 2)

    assert attn_metadata.block_mask is not None
    block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE

    kernel_options = get_kernel_options(
        query, block_m, block_n, attn_metadata.direct_build
    )
    out = flex_attention_compiled(
        query,
        key_tensor,
        value_tensor,
        attn_metadata.transformed_score_mod,
        attn_metadata.block_mask,
        self.scale,
        enable_gqa=enable_gqa,
        kernel_options=kernel_options,
    )

    # Flex doesn't have an out variant today, rely on epilogue fusion
    # [1, H, T, D] -> [T, H, D]; rows past num_actual_tokens are untouched.
    out = out.permute(0, 2, 1, 3).squeeze(0)
    output[:num_actual_tokens, :, :].copy_(out)
    return output

view_as_4d staticmethod

view_as_4d(tensor: Tensor) -> Tensor

View a 3d tensor as 4D.

Source code in vllm/v1/attention/backends/flex_attention.py
@staticmethod
def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` as a 4D view.

    4D inputs pass through untouched. 3D inputs gain a leading size-1
    dimension (a view, no copy). Every other rank trips the assertion.
    """
    if tensor.ndim != 4:
        assert tensor.ndim == 3
        tensor = tensor.unsqueeze(0)
    return tensor

FlexAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/flex_attention.py
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
@dataclass
class FlexAttentionMetadata:
    causal: bool
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # Block info
    total_cache_tokens: int
    block_size: int
    max_possible_sequence_length: int
    num_reqs: int
    physical_to_logical: torch.Tensor
    decode_offset: torch.Tensor
    num_blocks_per_seq: torch.Tensor

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # Flex Metadata
    num_blocks = 0
    block_mask: BlockMask | None = None
    score_mod: _score_mod_signature | None = None
    logical_mask_mod: _mask_mod_signature = causal_mask_mod
    doc_ids: torch.Tensor | None = None
    direct_build: bool = True
    q_block_size: int = 16
    kv_block_size: int = 16
    transformed_score_mod: _score_mod_signature | None = None
    sliding_window: int | None = None
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
    block_sparsity_hint: BlockSparsityHint | None = None

    @cached_property
    def logical_block_ids(self):
        """Logical block indices ``0 .. cdiv(max_seq_len, block_size) - 1``.

        The result is an int64 tensor on the block table's device, computed
        once per metadata instance.
        """
        num_logical_blocks = cdiv(self.max_seq_len, self.block_size)
        target_device = self.block_table.device
        return torch.arange(num_logical_blocks, dtype=torch.long, device=target_device)

    def _convert_physical_to_logical(
        self,
        request_lookup: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Translate packed query / paged KV indices into per-request positions.

        Args:
            request_lookup: maps a packed (physical) query index to its
                request index.
            q_idx: packed query indices.
            physical_kv_idx: token indices into the flattened paged KV cache.

        Returns:
            ``(is_valid, logical_q_idx, logical_kv_idx)``. ``is_valid`` is True
            only where the physical page is mapped for the request (logical
            block id >= 0) and the logical KV position lies in
            ``[0, seq_lens[request])``. ``logical_q_idx`` is the query's offset
            within its request plus ``decode_offset[request]``.
        """
        req = request_lookup[q_idx]

        # Split the physical KV index into (page, offset within page) and find
        # which logical page of this request that physical page holds.
        kv_page = physical_kv_idx // self.block_size
        kv_offset = physical_kv_idx % self.block_size
        logical_page = self.physical_to_logical[req, kv_page]
        logical_kv_idx = logical_page * self.block_size + kv_offset

        is_valid = (
            (logical_page >= 0)
            & (logical_kv_idx < self.seq_lens[req])
            & (logical_kv_idx >= 0)
        )

        logical_q_idx = q_idx - self.query_start_loc[req] + self.decode_offset[req]

        return is_valid, logical_q_idx, logical_kv_idx

    def get_paged_mask_mod(self) -> _mask_mod_signature:
        """Wrap ``logical_mask_mod`` so it can run on paged KV-cache indices.

        The returned mask_mod receives packed query indices and physical
        KV-cache slot indices. It converts both to per-request logical
        positions (the query position includes the decode offset) and then
        defers to ``self.logical_mask_mod``. Logical mask mods can therefore
        ignore the physical layout of the query and key/value tensors. Slots
        that do not map to a live position of the query's request are always
        masked out.
        """
        assert self.doc_ids is not None

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            converted = self._convert_physical_to_logical(
                self.doc_ids, q_idx, physical_kv_idx
            )
            valid, logical_q, logical_kv = converted
            user_mask = self.logical_mask_mod(b, h, logical_q, logical_kv)
            return torch.where(valid, user_mask, False)

        return final_mask_mod

    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
        """Build the encoder (bidirectional) mask_mod over packed sequences.

        Encoder attention runs without a KV cache, so query and key indices
        both address the same packed token stream. A pair is allowed exactly
        when both tokens belong to the same request.
        """
        token_to_request = _offsets_to_doc_ids_tensor(self.query_start_loc)

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            same_request = token_to_request[q_idx] == token_to_request[kv_idx]
            return same_request

        return final_mask_mod

    def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
        """Build the sliding-window mask_mod.

        The window test itself is symmetric (``|q - kv| < sliding_window``).
        The caller combines it with the causal or bidirectional base mask. For
        causal (paged) attention the indices are first converted to logical
        positions, and unmapped cache slots are masked. For non-causal
        attention the plain window test is returned.

        Raises:
            ValueError: if ``sliding_window`` is None.
        """
        if self.sliding_window is None:
            raise ValueError("sliding_window must be set for sliding window attention")

        def sliding_window_mask_mod(
            b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
        ):
            distance = torch.abs(q_idx - kv_idx)
            return distance < self.sliding_window

        if not self.causal:
            return sliding_window_mask_mod

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            valid, logical_q, logical_kv = self._convert_physical_to_logical(
                self.doc_ids, q_idx, physical_kv_idx
            )
            in_window = sliding_window_mask_mod(b, h, logical_q, logical_kv)
            return torch.where(valid, in_window, False)

        return final_mask_mod

    def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
        """Build the prefix-LM mask_mod for multimodal prefix ranges.

        For every request listed in ``self.mm_prefix_range``, a query and a
        key that both fall inside the same ``(start, end)`` range (compared
        with ``>= start`` and ``<= end`` on logical positions) may attend to
        each other regardless of order. Unmapped cache slots are masked. The
        caller ORs this mask with the base mask.

        NOTE(review): ``end`` is treated as inclusive -- confirm this matches
        how mm_prefix_range is produced.
        """
        assert self.doc_ids is not None
        token_to_request = self.doc_ids

        def prefix_lm_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            cu_q_idx: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ):
            mask = torch.zeros_like(q_idx, dtype=torch.bool)
            ranges_by_request = self.mm_prefix_range or {}
            for req, doc_range_lst in ranges_by_request.items():
                in_request = token_to_request[cu_q_idx] == req
                for start, end in doc_range_lst:
                    q_inside = (q_idx >= start) & (q_idx <= end)
                    kv_inside = (kv_idx >= start) & (kv_idx <= end)
                    mask = mask | (in_request & q_inside & kv_inside)
            return mask

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            valid, logical_q, logical_kv = self._convert_physical_to_logical(
                self.doc_ids, q_idx, physical_kv_idx
            )
            # The packed (physical) q_idx selects the request; logical indices
            # are compared against the prefix ranges.
            in_prefix = prefix_lm_mask_mod(b, h, q_idx, logical_q, logical_kv)
            return torch.where(valid, in_prefix, False)

        return final_mask_mod

    def get_mask_mod(self):
        """Compose the full mask_mod for the current metadata.

        Base mask: the paged mask (wrapping ``logical_mask_mod``) when the
        attention is causal or a custom logical mask is installed. Otherwise
        the bidirectional encoder mask. The sliding-window mask is then ANDed
        in when ``sliding_window`` is set, and the prefix-LM mask is ORed in
        when ``mm_prefix_range`` is non-empty.
        """
        uses_paged_base = self.causal or self.logical_mask_mod is not causal_mask_mod
        mask_mod = (
            self.get_paged_mask_mod()
            if uses_paged_base
            else self.get_bidirectional_mask_mod()
        )

        if self.sliding_window is not None:
            mask_mod = and_masks(mask_mod, self.get_sliding_window_mask_mod())

        if self.mm_prefix_range:
            mask_mod = or_masks(mask_mod, self.get_prefix_lm_mask_mod())

        return mask_mod

    def get_transformed_score_mod(self) -> _score_mod_signature | None:
        """Adapt the user's score_mod to packed queries and paged KV indices.

        Returns None when no ``score_mod`` is configured. Otherwise the
        returned function converts indices to logical positions before
        calling the user's score_mod, which also receives the packed query
        index as ``physical_q``. Scores for unmapped cache slots become -inf.
        """
        user_score_mod = self.score_mod
        if user_score_mod is None:
            return None

        token_to_request = _offsets_to_doc_ids_tensor(self.query_start_loc)

        def transformed_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            valid, logical_q, logical_kv = self._convert_physical_to_logical(
                token_to_request, q_idx, physical_kv_idx
            )
            modified = user_score_mod(
                score, b, h, logical_q, logical_kv, physical_q=q_idx
            )
            return torch.where(valid, modified, -float("inf"))

        return transformed_score_mod

    def _build_block_mask_direct(self) -> BlockMask:
        """Direct block mask construction for standard causal attention.

        This method constructs the block mask directly using
        BlockMask.from_kv_blocks, which is much more efficient than the
        generic create_block_mask approach.

        The direct path works as follows:
        1. For each query token, fetch blocks from block_table using
           max_seq_len. Drop blocks outside the sliding window if needed.
           (This fetches more blocks than needed for shorter sequences.)
        2. If a BlockSparsityHint is set, drop the (query, KV block) pairs it
           rejects.
        3. Group query tokens into chunks of q_block_size.
        4. For each group, deduplicate the blocks using unique_static_unsorted.
        5. Create the BlockMask from the deduplicated block indices.

        Over-estimation occurs when a group of q_block_size tokens contains
        multiple sequence IDs (doc_ids). In that case we fetch ALL blocks for
        each sequence represented in the group, even though individual query
        tokens may only need a subset of them given causal masking and their
        position. The element-wise ``self.mask_mod`` resolves the exact mask.

        Returns:
            A BlockMask whose KV block indices are KV-cache page ids, with
            BLOCK_SIZE = (q_block_size, kv_block_size).

        Raises:
            ValueError: if ``kv_block_size`` differs from the cache
                ``block_size``. The KV indices are cache page ids, so each
                flex KV block must cover exactly one page.
        """
        # Exact comparison: the former `kv_block_size // block_size != 1`
        # test let mismatches such as kv_block_size=24, block_size=16 slip
        # through and silently produce a mask over the wrong token ranges.
        if self.kv_block_size != self.block_size:
            raise ValueError(
                f"FlexAttention currently requires the cache block size "
                f"({self.block_size}) to be equal to the kv_block_size "
                f"({self.kv_block_size}). Please check your model's "
                f"configuration."
            )

        assert self.doc_ids is not None
        # Row i holds the cache pages of the request that query token i
        # belongs to. Indexing with a tensor yields a copy, so the in-place
        # masked_fill_ calls below never modify self.block_table. Pruned
        # entries are set to 0, presumably the id unique_static_unsorted
        # ignores -- TODO confirm.
        used_pages = self.block_table[
            self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
        ]

        hint = self.block_sparsity_hint

        if self.sliding_window or hint is not None:
            device = used_pages.device
            token_indices = torch.arange(
                self.doc_ids.shape[0], device=device, dtype=torch.long
            )
            # Logical position of each query token within its request.
            logical_q_idx = (
                token_indices
                - self.query_start_loc[self.doc_ids]
                + self.decode_offset[self.doc_ids]
            )

            if self.sliding_window:
                # Earliest KV position still inside each query's window.
                min_kv_idx = torch.clamp(
                    logical_q_idx - (self.sliding_window - 1), min=0
                )
                min_block_idx = min_kv_idx // self.block_size
                sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
                used_pages.masked_fill_(~sliding_mask, 0)
            if hint is not None:
                q_block_idx = logical_q_idx // self.block_size
                hint_mask = hint.hint_fn(
                    q_block_idx[:, None],
                    self.logical_block_ids[None, :],
                    self.block_size,
                )
                used_pages.masked_fill_(~hint_mask, 0)

        # One row per group of q_block_size consecutive query tokens,
        # holding the pages of every token in the group.
        used_pages_padded = pad_to_multiple(
            used_pages, multiple=self.q_block_size, dim=0
        )
        used_pages_padded = used_pages_padded.reshape(
            used_pages_padded.shape[0] // self.q_block_size, -1
        )
        kv_indices = unique_static_unsorted(
            used_pages_padded.long(), M=self.num_blocks
        ).to(torch.int32)

        # Only non-negative entries of kv_indices count as real blocks.
        kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
        block_mask_kwargs = {
            "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
            "kv_num_blocks": kv_num_blocks[None, None],
            "kv_indices": kv_indices[None, None],
            "full_kv_num_blocks": None,
            "full_kv_indices": None,
            "BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
            "mask_mod": self.mask_mod,
        }

        # compute_q_blocks parameter is available in PyTorch 2.9+
        if is_torch_equal_or_newer("2.9.0.dev0"):
            block_mask_kwargs["compute_q_blocks"] = False
        return BlockMask.from_kv_blocks(**block_mask_kwargs)

    def build_block_mask(self) -> BlockMask:
        """Build the BlockMask through the generic (compiled) mask builder.

        The query length is the number of packed query tokens. When attention
        is causal the KV length spans ``total_cache_tokens``; otherwise the KV
        length equals the number of query tokens.
        """
        combined_mask_mod = self.get_mask_mod()
        q_len = self.num_actual_tokens
        if self.causal:
            kv_len = self.total_cache_tokens
        else:
            kv_len = q_len
        # The two ``None`` arguments are the batch and head dimensions.
        # NOTE(review): presumably mirrors torch's create_block_mask, where
        # None means "broadcast" — confirm against create_block_mask_compiled.
        return create_block_mask_compiled(
            combined_mask_mod,
            None,
            None,
            q_len,
            kv_len,
            device=self.block_table.device,
            BLOCK_SIZE=(self.q_block_size, self.kv_block_size),
        )

    def __post_init__(self) -> None:
        """Reject unsupported options and derive per-batch lookup state.

        Cascade attention and prefix/suffix KV splitting are not implemented
        for this backend, so the corresponding fields must be at their
        "disabled" values. Afterwards this computes the query-token -> request
        lookup (``doc_ids``), the number of KV-cache blocks (``num_blocks``),
        and the mask_mod / transformed score_mod used by the kernel.
        """
        # NOTE(review): these are `assert`s, so they are skipped under
        # `python -O`; raising NotImplementedError would be more robust.
        assert self.use_cascade is False, "Not implemented yet."
        assert self.common_prefix_len == 0, "Not implemented yet."
        assert self.cu_prefix_query_lens is None, "Not implemented yet."
        assert self.prefix_kv_lens is None, "Not implemented yet."
        assert self.suffix_kv_lens is None, "Not implemented yet."
        # Create a lookup mapping from query indices -> request number
        self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
        # Block count of the KV cache; also used as the value bound `M` for
        # unique_static_unsorted in the direct block-mask path.
        self.num_blocks = self.total_cache_tokens // self.block_size

        # Must come after doc_ids is set: mask_mod builders such as
        # get_paged_mask_mod / get_prefix_lm_mask_mod assert it is not None.
        self.mask_mod = self.get_mask_mod()
        self.transformed_score_mod = self.get_transformed_score_mod()

_build_block_mask_direct

_build_block_mask_direct() -> BlockMask

Direct block mask construction for standard causal attention.

This method constructs the block mask directly using BlockMask.from_kv_blocks which is much more efficient than the generic create_block_mask approach.

The direct path works as follows:

1. For each query token, fetch blocks from block_table using max_seq_len, and exclude out-of-sliding-window blocks (and blocks rejected by the block-sparsity hint, if one is set) when needed. This fetches more blocks than needed for shorter sequences.
2. Group query tokens into chunks of q_block_size.
3. For each group, deduplicate the blocks using unique_static_unsorted.
4. Create the BlockMask using the deduplicated block indices.

Over-estimation occurs when a group of q_block_size tokens contains multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for each sequence represented in the group, even though individual query tokens may only need a subset of those blocks based on causal masking and their position.

Source code in vllm/v1/attention/backends/flex_attention.py
def _build_block_mask_direct(self) -> BlockMask:
    """Direct block mask construction for standard causal attention.

    This method constructs the block mask directly using
    BlockMask.from_kv_blocks which is much more efficient than the
    generic create_block_mask approach.

    The direct path works as follows:
    1. For each query token, fetch blocks from block_table using max_seq_len
       and exclude out of sliding window blocks if needed.
       (this fetches more blocks than needed for shorter sequences)
    2. Group query tokens into chunks of q_block_size
    3. For each group, deduplicate the blocks using unique_static_unsorted
    4. Create BlockMask using the deduplicated block indices

    Over-estimation occurs when a group of q_block_size tokens contains
    multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
    each sequence represented in the group, even though individual query
    tokens may only need a subset of those blocks based on causal masking
    and their position.

    """
    page_to_block_ratio = self.kv_block_size // self.block_size
    if page_to_block_ratio != 1:
        raise ValueError(
            f"FlexAttention currently requires the cache block size "
            f"({self.block_size}) to be equal to the kv_block_size "
            f"({self.kv_block_size}). Please check your model's "
            f"configuration."
        )

    used_pages = self.block_table[
        self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
    ]

    custom_hint = self.block_sparsity_hint is not None

    if self.sliding_window or custom_hint:
        device = used_pages.device
        assert self.doc_ids is not None
        token_indices = torch.arange(
            self.doc_ids.shape[0], device=device, dtype=torch.long
        )
        logical_q_idx = (
            token_indices
            - self.query_start_loc[self.doc_ids]
            + self.decode_offset[self.doc_ids]
        )

        if self.sliding_window:
            assert self.sliding_window is not None
            min_kv_idx = torch.clamp(
                logical_q_idx - (self.sliding_window - 1), min=0
            )
            min_block_idx = min_kv_idx // self.block_size
            sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
            used_pages.masked_fill_(~sliding_mask, 0)
        if custom_hint:
            assert self.block_sparsity_hint is not None
            q_block_idx = logical_q_idx // self.block_size
            hint_mask = self.block_sparsity_hint.hint_fn(
                q_block_idx[:, None],
                self.logical_block_ids[None, :],
                self.block_size,
            )
            used_pages.masked_fill_(~hint_mask, 0)

    used_pages_padded = pad_to_multiple(
        used_pages, multiple=self.q_block_size, dim=0
    )
    used_pages_padded = used_pages_padded.reshape(
        used_pages_padded.shape[0] // self.q_block_size, -1
    )
    used_pages_padded = used_pages_padded // page_to_block_ratio
    kv_indices = unique_static_unsorted(
        (used_pages_padded.long()), M=self.num_blocks
    ).to(torch.int32)

    kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
    block_mask_kwargs = {
        "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
        "kv_num_blocks": kv_num_blocks[None, None],
        "kv_indices": kv_indices[None, None],
        "full_kv_num_blocks": None,
        "full_kv_indices": None,
        "BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
        "mask_mod": self.mask_mod,
    }

    # compute_q_blocks parameter is available in PyTorch 2.9+
    if is_torch_equal_or_newer("2.9.0.dev0"):
        block_mask_kwargs["compute_q_blocks"] = False
    return BlockMask.from_kv_blocks(**block_mask_kwargs)

_convert_physical_to_logical

_convert_physical_to_logical(
    request_lookup: Tensor,
    q_idx: Tensor,
    physical_kv_idx: Tensor,
) -> tuple[Tensor, Tensor, Tensor]

Convert physical indices to logical indices for both query and kv.

NB is_within_lower_bound: do sequences start on block_boundaries?

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

tuple of (is_valid, logical_q_idx, logical_kv_idx)

Source code in vllm/v1/attention/backends/flex_attention.py
def _convert_physical_to_logical(
    self,
    request_lookup: torch.Tensor,
    q_idx: torch.Tensor,
    physical_kv_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Translate packed query / physical KV indices into per-request indices.

    Each query index is mapped to its request via ``request_lookup``. The
    physical KV index is split into (physical block, offset). The block is
    then mapped through ``self.physical_to_logical`` for that request to
    obtain a logical key position.

    Returns:
        tuple of (is_valid, logical_q_idx, logical_kv_idx). ``is_valid`` is
        False where the physical block has no logical block for the request
        (-1 entry) or the logical key position lies outside
        ``[0, seq_lens[request])``.
    """
    req = request_lookup[q_idx]
    bs = self.block_size

    # Physical block id and the offset of the key inside that block.
    block_id, in_block = physical_kv_idx // bs, physical_kv_idx % bs
    logical_block = self.physical_to_logical[req, block_id]
    logical_kv_idx = logical_block * bs + in_block

    is_valid = (
        (logical_block >= 0)
        & (logical_kv_idx < self.seq_lens[req])
        & (logical_kv_idx >= 0)
    )

    # Position within the request's query span, shifted by its decode offset.
    logical_q_idx = q_idx - self.query_start_loc[req] + self.decode_offset[req]
    return is_valid, logical_q_idx, logical_kv_idx

get_bidirectional_mask_mod

get_bidirectional_mask_mod() -> _mask_mod_signature

Creates the encoder mask_mod function for FlexAttention.

Since the encoder bidirectional attention doesn't run with KV cache, this function creates a mask based on the packed query sequences.

Source code in vllm/v1/attention/backends/flex_attention.py
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
    """Build the encoder (bidirectional) mask_mod.

    Encoder bidirectional attention does not run with the KV cache, so the
    mask is defined over the packed query sequences. A (query, key) pair is
    attendable exactly when both packed tokens belong to the same request.
    """
    # Packed token index -> request number, computed once per batch.
    token_to_request = _offsets_to_doc_ids_tensor(self.query_start_loc)

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        q_request = token_to_request[q_idx]
        kv_request = token_to_request[kv_idx]
        return q_request == kv_request

    return final_mask_mod

get_paged_mask_mod

get_paged_mask_mod() -> _mask_mod_signature

Creates the mask_mod function for FlexAttention.

This function creates the combined mask mod function that handles
  1. The paged attention block mapping
  2. The mapping from packed query sequences to logical query entries

By default, it also adds the decoding offset to the query indices. With this information we create the "logical" indices that are passed to mask_mod functions. This allows mask_mod functions to be agnostic to the layout of the query and key/value tensors.

Source code in vllm/v1/attention/backends/flex_attention.py
def get_paged_mask_mod(self) -> _mask_mod_signature:
    """Creates the mask_mod function for FlexAttention.

    This function creates the combined mask mod function that handles:
        1. The paged attention block mapping
        2. The mapping from packed query sequences to logical query entries

    It also by defaults adds the decoding offset to the query indices.
    With this info we create the "logical" indices that are passed to
    mask_mod functions. This allows mask mod functions to be agnostic to
    layout of the query and key/value tensors.
    """
    assert self.doc_ids is not None

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        (is_valid, logical_q_idx, logical_kv_idx) = (
            self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
        )
        # Apply mask modification only for valid indices
        return torch.where(
            is_valid,
            self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
            False,
        )

    return final_mask_mod

get_prefix_lm_mask_mod

get_prefix_lm_mask_mod() -> _mask_mod_signature

Creates the prefix LM mask_mod function for FlexAttention.

Source code in vllm/v1/attention/backends/flex_attention.py
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
    """Creates the prefix LM mask_mod function for FlexAttention."""

    assert self.doc_ids is not None
    request_lookup = self.doc_ids

    def prefix_lm_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        cu_q_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ):
        mask = torch.zeros_like(q_idx, dtype=torch.bool)
        for req, doc_range_lst in (self.mm_prefix_range or {}).items():
            req_mask = request_lookup[cu_q_idx] == req
            for start, end in doc_range_lst:
                doc_mask_q = (q_idx >= start) & (q_idx <= end)
                doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
                mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
        return mask

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        (is_valid, logical_q_idx, logical_kv_idx) = (
            self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
        )
        return torch.where(
            is_valid,
            prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
            False,
        )

    return final_mask_mod

get_sliding_window_mask_mod

get_sliding_window_mask_mod() -> _mask_mod_signature

Creates the sliding window mask_mod function for FlexAttention.

Note that the sliding window mask here is bidirectional; it must be combined with the bidirectional mask (encoder) or the causal mask (decoder).

Source code in vllm/v1/attention/backends/flex_attention.py
def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
    """Creates the sliding window mask_mod function for FlexAttention.

    Note that the sliding window mask here is bidirectional, we need
    to mask it with the bidirectional/causal mask for encoder/decoder.
    """

    if self.sliding_window is None:
        raise ValueError("sliding_window must be set for sliding window attention")

    def sliding_window_mask_mod(
        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
    ):
        return torch.abs(q_idx - kv_idx) < self.sliding_window

    def final_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        (is_valid, logical_q_idx, logical_kv_idx) = (
            self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
        )
        return torch.where(
            is_valid,
            sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
            False,
        )

    return final_mask_mod if self.causal else sliding_window_mask_mod

get_transformed_score_mod

get_transformed_score_mod() -> _score_mod_signature | None

Creates the transformed score_mod function for FlexAttention.

This function wraps the user's score_mod to handle physical-to-logical index conversion, similar to how get_mask_mod works for mask functions.

Source code in vllm/v1/attention/backends/flex_attention.py
def get_transformed_score_mod(self) -> _score_mod_signature | None:
    """Creates the transformed score_mod function for FlexAttention.

    This function wraps the user's score_mod to handle physical-to-logical
    index conversion, similar to how get_mask_mod works for mask functions.
    """
    if self.score_mod is None:
        return None

    # Create a lookup mapping from query indices -> request number
    request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
    user_score_mod = self.score_mod

    def transformed_score_mod(
        score: torch.Tensor,
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        (is_valid, logical_q_idx, logical_kv_idx) = (
            self._convert_physical_to_logical(
                request_lookup, q_idx, physical_kv_idx
            )
        )

        return torch.where(
            is_valid,
            user_score_mod(
                score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx
            ),
            -float("inf"),
        )

    return transformed_score_mod

physical_to_logical_mapping

physical_to_logical_mapping(
    block_table: Tensor,
    seq_lens: Tensor,
    block_size: int,
    total_blocks: int,
) -> Tensor

Creates an inverse mapping from physical block locations to logical indices.

The original block_table maps from logical blocks to physical locations:

Logical to Physical (original block_table):

    Request 0:
    Logical Blocks:  0  1  2  3  4  5  6  7
                     |  |  |  |  |  |  |  |
                     v  v  v  v  v  v  v  v
    Physical Blocks: 3  5  1  7  4  2  0  6

This function creates the inverse mapping:

Physical to Logical (inverse mapping):

    Request 0:
    Physical Blocks: 0  1  2  3  4  5  6  7
                     |  |  |  |  |  |  |  |
                     v  v  v  v  v  v  v  v
    Logical Blocks:  6  2  5  0  4  1  7  3

If multiple logical blocks map to the same physical block, this function returns the latest (maximum) logical block index.

If a physical block is not mapped to by any logical block, its value in the result will be -1.

IMPORTANT: Garbage value protection. The block_table tensor may contain garbage values in unused positions (beyond the actual sequence length). For example, if a sequence only needs 3 blocks but the table has space for 8:

block_table[0] = [10, 25, 7, 999, 1234, 888, ...]
                            ^^^^^^^^^^^^^^^^^^^^
                            garbage values

These garbage values can cause issues because:

1. They may map to valid physical blocks by coincidence.
2. The scatter_reduce_ operation would assign them logical indices.
3. Later attention computations may then incorrectly access these blocks.

To prevent this, we use seq_lens and block_size to mask out unused entries, ensuring only valid block references are processed.

IMPORTANT: Reused physical blocks (sliding-window / hybrid attention). For some attention types, physical cache blocks can be reused over time. The same physical block id can therefore appear multiple times in a row of block_table, at different logical block indices. In that case, only the latest logical block index corresponds to the current contents of that physical block. The inverse mapping must therefore pick the maximum logical block index for each physical block id.

Parameters:

Name Type Description Default
block_table Tensor

Tensor of shape [max_reqs, max_num_blocks] mapping logical blocks to physical locations. May contain garbage values in unused positions.

required
seq_lens Tensor

Tensor of sequence lengths for each request. Used to determine how many blocks are actually needed per sequence.

required
block_size int

Size of each block in tokens. Used with seq_lens to compute the number of valid blocks per sequence.

required
total_blocks int

Total number of physical blocks available

required

Returns:

Type Description
Tensor

A tensor of shape [max_reqs, total_blocks] where each entry physical_to_logical[req_id, physical_block] contains the logical block index for that physical block, or -1 if unused.

Source code in vllm/v1/attention/backends/flex_attention.py
def physical_to_logical_mapping(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    total_blocks: int,
) -> torch.Tensor:
    """
    Invert ``block_table`` per request: physical block id -> logical block id.

    ``block_table`` maps each request's logical blocks to physical blocks:

        Request 0:
        Logical Blocks:  0  1  2  3  4  5  6  7
                         |  |  |  |  |  |  |  |
                         v  v  v  v  v  v  v  v
        Physical Blocks: 3  5  1  7  4  2  0  6

    and this function produces the inverse table:

        Request 0:
        Physical Blocks: 0  1  2  3  4  5  6  7
                         |  |  |  |  |  |  |  |
                         v  v  v  v  v  v  v  v
        Logical Blocks:  6  2  5  0  4  1  7  3

    Physical blocks that no logical block points to are reported as -1.
    Physical block 0 is always reported as -1.

    Garbage values: positions of ``block_table`` past a sequence's last
    needed block (``ceil(seq_lens / block_size)`` blocks) may hold stale
    ids, e.g. ``[10, 25, 7, 999, 1234, 888, ...]`` for a 3-block sequence.
    Those positions are ignored so that stale ids can never receive a
    logical index.

    Reused physical blocks: with sliding-window / hybrid attention the same
    physical id may occur at several logical positions in one row. Only the
    most recent (largest) logical index reflects the block's current
    contents, so the largest index wins.

    Args:
        block_table: [max_reqs, max_num_blocks] logical -> physical table;
            may contain garbage past each sequence's used blocks.
        seq_lens: Per-request sequence lengths, used to count used blocks.
        block_size: Tokens per block.
        total_blocks: Number of physical blocks (width of the result).

    Returns:
        A [max_reqs, total_blocks] tensor whose entry
        ``[req_id, physical_block]`` is the logical block index, or -1 if
        the physical block is unused by that request.
    """
    num_reqs, table_width = block_table.shape
    device = block_table.device

    logical_ids = torch.arange(table_width, device=device)[None, :]
    # Ceiling division: number of blocks actually backing each sequence.
    blocks_in_use = -(seq_lens // -block_size)
    in_use = logical_ids < blocks_in_use[:, None]

    # Unused slots are redirected to physical block 0 with logical index 0;
    # block 0 is reset below, so these placeholders never reach the output.
    physical_ids = torch.where(in_use, block_table, 0).to(torch.int64)
    logical_src = torch.where(in_use, logical_ids, 0)

    inverse = torch.full(
        (num_reqs, total_blocks), -1, dtype=torch.long, device=device
    )
    # amax keeps the latest logical index when a physical block is reused.
    inverse.scatter_reduce_(-1, physical_ids, logical_src, reduce="amax")
    # Block 0 is treated as empty (it also absorbs the placeholders above).
    inverse[:, 0] = -1
    return inverse

unique_static_unsorted

unique_static_unsorted(
    x: Tensor,
    *,
    M: int,
    dim: int = -1,
    ignored_val: int = 0,
    pad_val: int = -1,
) -> Tensor
  • Keeps the first occurrence of each value other than ignored_val while preserving order, then left-packs those uniques and fills the rest with pad_val.
  • Returns the packed tensor, which has the same shape as x.
  • Requires that all values be in the range [0, M].
  • Skips ignored_val.

Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.

Example: x = [3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]

Source code in vllm/v1/attention/backends/flex_attention.py
def unique_static_unsorted(
    x: torch.Tensor,
    *,
    M: int,  # maximum positive value (0 is “skip me”)
    dim: int = -1,  # axis along which to deduplicate
    ignored_val: int = 0,  # value to ignore
    pad_val: int = -1,  # sentinel for unused slots
) -> torch.Tensor:
    """
    Deduplicate ``x`` along ``dim``, keeping first occurrences in order.

    - Keeps the first occurrence of each value other than ``ignored_val``
      while preserving order, then left-packs those uniques and fills the
      rest with `pad_val`.
    - Returns the packed tensor, with the *same shape and dtype* as `x`.
    - Requires integer values in the range [0, M] (``ignored_val`` included);
      any integer dtype is accepted.
    - An empty deduplication axis yields an empty result of the same shape.

    Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.

    Example:
    x = [3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]

    Raises:
        ValueError: if ``pad_val`` is outside [-1, M].
    """
    if not (-1 <= pad_val <= M):
        raise ValueError("`pad_val` must lie in [-1, M]")

    # ── move `dim` to the end so we can treat tensor as [B, N] ──────────
    dim = dim % x.ndim
    x_perm = x.movedim(dim, -1)  # shape [..., N]
    N = x_perm.shape[-1]
    # Product of the leading dims; unlike numel() // N this does not divide
    # by zero when the deduplication axis is empty.
    B = x_perm.shape[:-1].numel()
    x_flat = x_perm.reshape(B, N)  # [B, N]
    # scatter/gather need int64 indices; x_flat keeps the caller's dtype so
    # the output dtype matches the input.
    index = x_flat.long()

    device = x.device
    idx = torch.arange(N, device=device).expand(B, N)  # per-row indices

    # ── build first-occurrence table for every v ∈ [0, M] ───────────────
    first_idx = torch.full((B, M + 1), N, dtype=torch.long, device=device)  # “∞”
    # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i
    first_idx.scatter_reduce_(1, index, idx, reduce="amin")

    # ── keep mask: first occurrence *and* value ≠ ignored_val ───────────
    keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, index))  # [B, N]

    # ── left-pack uniques into a fresh tensor ───────────────────────────
    dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1  # where to go
    packed_flat = torch.full_like(x_flat, pad_val)

    rows, src_cols = torch.nonzero(keep, as_tuple=True)
    packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols]

    # ── restore original layout ─────────────────────────────────────────
    return packed_flat.reshape(x_perm.shape).movedim(-1, dim)