class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
    """
    Monolithic-interface wrapper around the unquantized BF16 TRTLLM-Gen
    fused-MoE kernels provided by FlashInfer.

    The kernel fuses routing and both expert GEMMs into a single launch,
    so this class only caches the routing/shape parameters at construction
    time and forwards them on every call to ``apply``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        # Cache everything the kernel launch needs so apply() does not have
        # to walk the config objects on the hot path.
        self.routing_method_type = moe_config.routing_method
        self.topk = moe_config.experts_per_token
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.hidden_dim = moe_config.hidden_dim
        self.local_num_experts = moe_config.num_local_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Kernel consumes activations in the standard (non-batched) layout."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Only Blackwell-family CUDA GPUs with the FlashInfer TRTLLM MoE op."""
        platform = current_platform
        if not platform.is_cuda():
            return False
        if not platform.is_device_capability_family(100):
            return False
        return has_flashinfer_trtllm_fused_moe()

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """The BF16 kernels only handle gated (act-and-mul) MoE layers."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only fully unquantized weights and activations."""
        return weight_key is None and activation_key is None

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """SiLU is the sole supported gating activation."""
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_routing_method(
        routing_method: RoutingMethodType,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Allowlist of routing methods the TRTLLM kernel handles correctly."""
        # NOTE: the TRTLLM kernel has an issue with the Qwen3.5 router, so
        # Renormalize / RenormalizeNaive are excluded for now. Re-enable once
        # https://github.com/vllm-project/vllm/issues/37591 is resolved.
        supported = (
            RoutingMethodType.Default,
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Llama4,
        )
        return routing_method in supported

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        """Monolithic kernel: naive DP/EP (AG+RS) and TP only; no EPLB."""
        if moe_parallel_config.enable_eplb:
            return False
        if not moe_parallel_config.use_all2all_kernels:
            return True
        return moe_parallel_config.use_ag_rs_all2all_kernels

    @staticmethod
    def _supports_router_logits_dtype(
        router_logits_dtype: torch.dtype | None,
        routing_method: RoutingMethodType,
    ) -> bool:
        """No restriction on the router-logits dtype."""
        return True

    def supports_chunking(self) -> bool:
        """The fused kernel processes the whole batch in one launch."""
        return False

    def supports_expert_map(self) -> bool:
        """Expert placement is derived from ep_rank, not an explicit map."""
        return False

    @property
    def expects_unquantized_inputs(self) -> bool:
        return True

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        router_logits: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        num_expert_group: int | None = None,
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
    ) -> torch.Tensor:
        """
        Run the fused BF16 TRTLLM-Gen MoE kernel (routing + both GEMMs).

        Several parameters (``activation``, ``expert_map``, ``a1q_scale``,
        ``apply_router_weight_on_input``, ``routed_scaling_factor``) are not
        forwarded — presumably the ``_supports_*`` predicates guarantee they
        hold their only supported values here; confirm against the caller.
        """
        import flashinfer

        # This rank owns a contiguous slab of experts starting at this offset.
        expert_offset = self.ep_rank * self.local_num_experts

        return flashinfer.fused_moe.trtllm_bf16_moe(
            routing_logits=router_logits,
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states,
            gemm1_weights=w1,
            gemm2_weights=w2,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=num_expert_group,
            topk_group=topk_group,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=expert_offset,
            local_num_experts=self.local_num_experts,
            routing_method_type=self.routing_method_type,
        )