Skip to content

vllm.distributed.device_communicators.flashinfer_all_reduce

FlashInferAllReduce

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
class FlashInferAllReduce:
    def __init__(
        self,
        group: ProcessGroup,
        device: int | str | torch.device,
    ):
        self.disabled = True

        if not fi_ar_available:
            logger.info(
                "FlashInfer All Reduce is disabled because flashinfer is not available"
            )
            return

        if not current_platform.is_cuda():
            logger.info(
                "FlashInfer All Reduce is disabled because it requires CUDA platform"
            )
            return

        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group)
        self.device = device
        if self.world_size == 1:
            return

        # Use the same threshold as the allreduce-rms fusion pass
        # TODO: tune the threshold
        MiB = 1024 * 1024
        max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
            self.world_size, None
        )
        if not max_workspace_size:
            logger.warning(
                "FlashInfer All Reduce is disabled because it "
                "is not supported for world_size=%d.",
                self.world_size,
            )
            return
        self.max_workspace_size = max_workspace_size * MiB
        self.max_num_tokens = 0
        self.disabled = False

    def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
        """Ensure the all reduce workspace is initialized."""
        if self.max_num_tokens == 0:
            element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
            self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
        workspace = get_fi_ar_workspace(
            world_size=self.world_size,
            rank=self.rank,
            max_token_num=self.max_num_tokens,
            hidden_dim=hidden_dim,
            dtype=dtype,
            group=self.group,
        )
        if workspace is None:
            self.disabled = True
            return False
        return True

    def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
        if self.disabled:
            return False

        if not input_tensor.is_cuda:
            return False

        if not input_tensor.is_contiguous():
            return False

        if len(input_tensor.shape) != 2:
            return False

        num_tokens, hidden_dim = input_tensor.shape
        if not self.max_num_tokens:
            element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
            self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)

        if num_tokens > self.max_num_tokens:
            return False

        return self._ensure_workspace(hidden_dim, input_tensor.dtype)

    def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
        _, hidden_dim = input_tensor.shape
        workspace = get_fi_ar_workspace(
            world_size=self.world_size,
            rank=self.rank,
            max_token_num=self.max_num_tokens,
            hidden_dim=hidden_dim,
            dtype=input_tensor.dtype,
            group=self.group,
        )
        return flashinfer_comm.allreduce_fusion(
            input=input_tensor,
            workspace=workspace,
            pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
        )

    def destroy(self):
        if not self.disabled:
            destroy_fi_ar_workspace()

_ensure_workspace

_ensure_workspace(hidden_dim: int, dtype: torch.dtype) -> bool

Ensure the all reduce workspace is initialized.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
    """Lazily size and fetch the shared all-reduce workspace.

    On the first call, ``self.max_num_tokens`` is derived from
    ``self.max_workspace_size`` and the given ``hidden_dim`` and ``dtype``.
    Returns True when a workspace is available. Otherwise it marks the
    instance disabled and returns False.
    """
    if self.max_num_tokens == 0:
        # Query the per-element byte size on CPU to avoid touching any device.
        bytes_per_elem = torch.tensor([], dtype=dtype, device="cpu").element_size()
        row_bytes = hidden_dim * bytes_per_elem
        self.max_num_tokens = self.max_workspace_size // row_bytes

    ws = get_fi_ar_workspace(
        world_size=self.world_size,
        rank=self.rank,
        max_token_num=self.max_num_tokens,
        hidden_dim=hidden_dim,
        dtype=dtype,
        group=self.group,
    )
    available = ws is not None
    if not available:
        self.disabled = True
    return available

_create_workspace

_create_workspace(
    backend: str,
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
)

Create a flashinfer allreduce workspace, returning None on failure.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def _create_workspace(
    backend: str,
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """Build a FlashInfer allreduce-fusion workspace for ``backend``.

    Any exception raised while creating the workspace is logged once and
    turned into a ``None`` return value. On success the workspace is logged
    at debug level and returned. The global ``random`` state is reseeded
    from OS entropy during creation and always restored afterwards.
    """
    torch_comm = TorchDistBackend(group=group)
    # NOTE(review): presumably flashinfer draws from Python's ``random``
    # during creation, and processes seeded identically must not collide;
    # confirm. The caller's RNG state is restored in ``finally``.
    saved_rng = random.getstate()
    try:
        fresh_seed = int.from_bytes(os.urandom(16), byteorder="big")
        random.seed(fresh_seed)
        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            comm_backend=torch_comm,
        )
    except Exception as exc:
        # Multicast failures are expected on topologies without NVSwitch, so
        # the message says so explicitly.
        mentions_multicast = "multicast" in str(exc).lower()
        message = (
            "Failed to initialize FlashInfer All Reduce workspace: %s. "
            "This is expected on GPUs without NVSwitch (e.g., NVLink "
            "bridge-only or PCIe topologies)."
            if mentions_multicast
            else "Failed to initialize FlashInfer All Reduce workspace: %s."
        )
        logger.warning_once(message, exc)
        return None
    finally:
        random.setstate(saved_rng)

    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
    return workspace

get_fi_ar_quant_workspace

get_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
)

Return the allreduce workspace for quant patterns, initializing if needed.

Always uses the trtllm backend, because it is the only backend that supports quantization fusion (FP8/FP4).

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def get_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """
    Return the workspace used by quantization-fusion allreduce patterns.

    The workspace is cached in ``_fi_ar_quant_workspace``. When it is not yet
    cached, the non-quant workspace is reused if it was built with the
    ``trtllm`` backend. Otherwise a new ``trtllm`` workspace is created,
    because trtllm is the only backend supporting quantization fusion
    (FP8/FP4). A failed creation caches ``None``, so a later call will try
    again.
    """
    global _fi_ar_quant_workspace
    if _fi_ar_quant_workspace is None:
        shared = _fi_ar_workspace
        if shared is not None and shared.backend == "trtllm":
            _fi_ar_quant_workspace = shared
        else:
            _fi_ar_quant_workspace = _create_workspace(
                "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
            )
    return _fi_ar_quant_workspace

get_fi_ar_workspace

get_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
)

Return the allreduce workspace for non-quant patterns, initializing if needed.

Used by AllReduceFusionPass (for non-quant patterns) and by FlashInferAllReduce (for standalone allreduce). The backend is selected by the VLLM_FLASHINFER_ALLREDUCE_BACKEND environment variable.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def get_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """
    Return the workspace used by non-quant allreduce patterns.

    Shared by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce
    (standalone allreduce). The workspace is cached in ``_fi_ar_workspace``.
    When it is not yet cached, the backend named by
    ``VLLM_FLASHINFER_ALLREDUCE_BACKEND`` is used. An existing quant workspace
    with that same backend is reused; otherwise a new one is created. A
    failed creation caches ``None``, so a later call will try again.
    """
    global _fi_ar_workspace
    if _fi_ar_workspace is None:
        wanted_backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
        quant_ws = _fi_ar_quant_workspace
        if quant_ws is not None and quant_ws.backend == wanted_backend:
            _fi_ar_workspace = quant_ws
        else:
            _fi_ar_workspace = _create_workspace(
                wanted_backend, world_size, rank, max_token_num, hidden_dim, dtype, group
            )
    return _fi_ar_workspace