
API Reference

flashspec

FlashSpec — Adaptive speculative-decoding inference engine.

Adaptive speculative-decoding inference engine with Triton-optimised verification and online bandit draft selection.

Public API surface (AGENTS.md §13.2 — do not modify without explicit approval):

flashspec.SpeculativeEngine
flashspec.GenerationResult
flashspec.FlashSpecConfig
flashspec.BanditConfig
flashspec.SamplingConfig
flashspec.MetricsConfig
flashspec.register          (draft model decorator)
flashspec.get_drafter
flashspec.list_drafters
References

[1] Leviathan et al. (2023), "Fast Inference from Transformers via Speculative Decoding", arXiv:2211.17192.
[2] Myet (2025), "FlashSpec: Adaptive Speculative Decoding with Online Bandit Draft Selection and Triton-Optimised Verification".

SpeculativeEngine

Adaptive speculative-decoding inference engine.

Orchestrates the full speculative decoding loop:

1. Bandit selects a draft model arm.
2. Draft model proposes gamma tokens autoregressively.
3. Target model scores all gamma positions in one forward pass.
4. Sampling algorithm accepts/rejects draft tokens and samples a residual token.
5. Bandit is updated with the acceptance outcome.
6. Metrics are updated.
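The steps above (plus appending the accepted tokens to the context) can be sketched as a plain-Python loop. This is a toy under stated assumptions, not FlashSpec's implementation: each "model" is a hypothetical function that returns its greedy next token for a context, `RoundRobinSelector` is a hypothetical stand-in for the bandit, and acceptance uses greedy prefix matching (the zero-temperature case) instead of Algorithm 1.

```python
from dataclasses import dataclass, field
from typing import Callable, Sequence

# A hypothetical "model": maps a context of token IDs to its greedy next token.
NextToken = Callable[[Sequence[int]], int]


@dataclass
class RoundRobinSelector:
    """Hypothetical bandit stand-in: cycles through arms and logs rewards."""
    n_arms: int
    t: int = 0
    rewards: dict[int, list[int]] = field(default_factory=dict)

    def select(self) -> int:
        arm = self.t % self.n_arms
        self.t += 1
        return arm

    def update(self, arm: int, accepted: int) -> None:
        self.rewards.setdefault(arm, []).append(accepted)


def speculative_loop(prompt: list[int], drafters: list[NextToken], target: NextToken,
                     selector: RoundRobinSelector, gamma: int,
                     max_new_tokens: int) -> tuple[list[int], list[int]]:
    context = list(prompt)
    generated: list[int] = []
    accepted_counts: list[int] = []
    while len(generated) < max_new_tokens:
        arm = selector.select()                        # 1. choose a draft arm
        draft: list[int] = []                          # 2. draft gamma tokens
        for _ in range(gamma):
            draft.append(drafters[arm](context + draft))
        # 3. The target's choice at every draft position plus one extra
        #    (a real target computes these in a single forward pass).
        choices = [target(context + draft[:i]) for i in range(gamma + 1)]
        # 4. Greedy acceptance: keep the longest agreeing prefix, then append
        #    the target's own token (a correction, or a bonus if all matched).
        n_accepted = 0
        while n_accepted < gamma and draft[n_accepted] == choices[n_accepted]:
            n_accepted += 1
        new_tokens = draft[:n_accepted] + [choices[n_accepted]]
        selector.update(arm, accepted=n_accepted)      # 5. bandit feedback
        accepted_counts.append(n_accepted)             # 6. metrics
        generated.extend(new_tokens)                   # 7. extend the context
        context.extend(new_tokens)
    return generated[:max_new_tokens], accepted_counts


sel = RoundRobinSelector(n_arms=2)
tokens, counts = speculative_loop(
    [0], [lambda c: c[-1] + 1, lambda c: 0], lambda c: c[-1] + 1, sel,
    gamma=4, max_new_tokens=10,
)
```

With a target that counts upward, the drafter that agrees with it has all 4 drafts accepted per step, while the drafter that always proposes 0 has none accepted yet still advances the sequence by one target token per step.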

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| config | FlashSpecConfig | Full engine configuration. | required |
| drafters | list[DraftModel] | List of draft model instances, one per bandit arm. | required |
| target | TargetModel | The target (verifier) model instance. | required |
| bandit | DraftSelector | Online bandit selector for adaptive arm selection. | required |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If len(drafters) does not match config.bandit.n_arms. |

Notes

Output distribution correctness (Algorithm 1, Leviathan et al. 2023) is verified by tests/integration/test_e2e_sampling.py (KS test, α=0.01, N=10,000). This is a CI hard gate.
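The property that gate protects can be illustrated without any model. The sketch below uses hypothetical helpers and only the standard library: it runs one gamma = 1 step of Algorithm 1 over a three-token vocabulary many times and reports the largest gap between the empirical token frequencies and the target distribution p. It compares frequencies directly rather than computing the KS statistic used by the CI test.

```python
import random
from collections import Counter


def speculative_token(p: list[float], q: list[float], rng: random.Random) -> int:
    """One gamma = 1 step of Algorithm 1: draft from q, verify against p."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]            # draft proposal x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):        # accept with prob min(1, p/q)
        return x
    # Rejected: resample from the residual distribution norm(max(0, p - q)).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]


def max_frequency_gap(p: list[float], q: list[float], n: int = 100_000,
                      seed: int = 0) -> float:
    """Largest |empirical frequency - p_i| over n independent steps."""
    rng = random.Random(seed)
    counts = Counter(speculative_token(p, q, rng) for _ in range(n))
    return max(abs(counts[i] / n - p[i]) for i in range(len(p)))


print(max_frequency_gap([0.5, 0.3, 0.2], [0.2, 0.2, 0.6]))
```

Even with a poorly matched draft q, the gap should stay within Monte Carlo noise. Skipping the residual resample (for example, returning the rejected draft token anyway) would make every output a plain draw from q and skew the frequencies toward q.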

Examples:

>>> result = engine.generate(input_ids, max_new_tokens=200)
>>> result.acceptance_rate
0.73
Source code in flashspec/engine/speculative.py
class SpeculativeEngine:
    """Adaptive speculative-decoding inference engine.

    Orchestrates the full speculative decoding loop:
    1. Bandit selects a draft model arm.
    2. Draft model proposes ``gamma`` tokens autoregressively.
    3. Target model scores all ``gamma`` positions in one forward pass.
    4. Sampling algorithm accepts/rejects draft tokens and samples residual.
    5. Bandit is updated with the acceptance outcome.
    6. Metrics are updated.

    Parameters
    ----------
    config : FlashSpecConfig
        Full engine configuration.
    drafters : list[DraftModel]
        List of draft model instances, one per bandit arm.
    target : TargetModel
        The target (verifier) model instance.
    bandit : DraftSelector
        Online bandit selector for adaptive arm selection.

    Raises
    ------
    ValueError
        If ``len(drafters)`` does not match ``config.bandit.n_arms``.

    Notes
    -----
    Output distribution correctness (Algorithm 1, Leviathan et al. 2023) is
    verified by ``tests/integration/test_e2e_sampling.py`` (KS test, α=0.01,
    N=10,000).  This is a CI hard gate.

    Examples
    --------
    >>> result = engine.generate(input_ids, max_new_tokens=200)
    >>> result.acceptance_rate
    0.73
    """

    def __init__(
        self,
        config: FlashSpecConfig,
        drafters: list[DraftModel],
        target: TargetModel,
        bandit: DraftSelector,
    ) -> None:
        if len(drafters) != config.bandit.n_arms:
            raise ValueError(
                f"len(drafters)={len(drafters)} must match "
                f"config.bandit.n_arms={config.bandit.n_arms}."
            )
        self._config = config
        self._drafters = drafters
        self._target = target
        self._bandit = bandit
        self._device = get_device(config.device)

        self._acceptance_tracker = AcceptanceTracker(gamma=config.sampling.gamma)
        self._throughput_tracker = ThroughputTracker()
        self._latency_tracker = LatencyTracker(
            window=config.metrics.latency_window,
        )

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int | None = None,
    ) -> GenerationResult:
        """Run the speculative decoding generation loop.

        Parameters
        ----------
        input_ids : torch.Tensor
            Prompt token IDs.  Shape: ``(batch_size, seq_len)``, dtype int64.
        max_new_tokens : int or None
            Maximum tokens to generate.  Defaults to ``config.max_new_tokens``.

        Returns
        -------
        GenerationResult
            Contains output token IDs and generation metrics (acceptance rate,
            tokens per second).

        Raises
        ------
        ValueError
            If ``input_ids`` is not 2-D or has batch size 0.

        Notes
        -----
        With ``config.sampling.variant == "rejection"``, the output distribution
        of this method is provably identical to autoregressive sampling from the
        target model (Leviathan et al., 2023, Theorem 1).  This invariant is
        enforced by the KS-test gate in ``tests/integration/test_e2e_sampling.py``
        at α=0.01, N=10,000.  The guarantee is specific to Algorithm 1; the
        ``"typical"`` variant uses a different acceptance rule.

        The generation loop follows these steps on each iteration:

        1. Bandit selects a draft model arm.
        2. Draft model proposes ``gamma`` tokens autoregressively.
        3. Target model scores all ``gamma`` positions in one forward pass
           with temperature applied before log-softmax (§7).
        4. Draft tokens are accepted or rejected: rejection sampling
           (Algorithm 1) for the ``"rejection"`` variant, ``typical_sample``
           for the ``"typical"`` variant.
        5. Bandit is updated with the acceptance outcome.
        6. Metrics are updated.
        7. Accepted tokens are appended to the context.

        Examples
        --------
        >>> result = engine.generate(input_ids, max_new_tokens=128)
        >>> result.n_tokens_generated
        128
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must be 2-D (batch_size, seq_len); got shape {input_ids.shape}."
            )
        if input_ids.shape[0] == 0:
            raise ValueError("input_ids batch size must be >= 1.")

        max_tokens = max_new_tokens if max_new_tokens is not None else self._config.max_new_tokens
        gamma = self._config.sampling.gamma
        config_sampling = self._config.sampling

        context = input_ids.to(self._device)
        generated_ids: list[torch.Tensor] = []

        self._throughput_tracker.start()

        while sum(t.shape[1] for t in generated_ids) < max_tokens:
            self._latency_tracker.start()

            # 1. Bandit selects a draft arm.
            arm = self._bandit.select()
            drafter = self._drafters[arm]

            # 2. Draft model proposes gamma tokens.
            draft_token_ids, draft_logprobs = drafter.generate_draft(context, gamma)

            # 3. Target model scores draft positions (temperature applied inside).
            target_logprobs = self._target.score_draft(
                context, draft_token_ids, gamma,
                temperature=config_sampling.temperature,
            )

            # 4. Accept/reject sampling.
            if config_sampling.variant == "rejection":
                accepted_ids, first_rejection, alpha = rejection_sample(
                    input_ids=context,
                    draft_logprobs=draft_logprobs,
                    target_logprobs=target_logprobs,
                    draft_token_ids=draft_token_ids,
                    gamma=gamma,
                )
            else:
                accepted_ids, first_rejection, alpha = typical_sample(
                    draft_logprobs=draft_logprobs,
                    target_logprobs=target_logprobs,
                    draft_token_ids=draft_token_ids,
                    gamma=gamma,
                    typical_p=config_sampling.top_p,
                )

            # 5. Update bandit with acceptance count.
            n_accepted = int(first_rejection.float().mean().item())
            self._bandit.update(arm, accepted=n_accepted)

            # 6. Update metrics.
            self._acceptance_tracker.record(n_accepted)

            # Collect non-padding accepted tokens.
            new_tokens = accepted_ids[accepted_ids != -1].view(input_ids.shape[0], -1)
            if new_tokens.shape[1] > 0:
                generated_ids.append(new_tokens)
                context = torch.cat([context, new_tokens], dim=1)

            self._latency_tracker.stop()

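        # The final step can overshoot max_tokens, so truncate first and count
        # from the truncated tensor; this keeps n_tokens_generated equal to
        # output_ids.shape[1].  The empty fallback is sliced from ``context``
        # so it lives on self._device like any generated tokens would.
        output_ids = (
            torch.cat(generated_ids, dim=1)[:, :max_tokens]
            if generated_ids
            else context[:, :0]
        )
        total_new = output_ids.shape[1]
        tokens_per_second = self._throughput_tracker.stop(n_tokens=total_new)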

        logger.debug(
            "Generation complete",
            extra={
                "n_tokens": total_new,
                "tokens_per_second": tokens_per_second,
                "alpha": self._acceptance_tracker.mean_acceptance_rate,
            },
        )

        return GenerationResult(
            output_ids=output_ids,
            n_tokens_generated=total_new,
            acceptance_rate=self._acceptance_tracker.mean_acceptance_rate,
            tokens_per_second=tokens_per_second,
        )
generate(input_ids, max_new_tokens=None)

Run the speculative decoding generation loop.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| input_ids | Tensor | Prompt token IDs. Shape: (batch_size, seq_len), dtype int64. | required |
| max_new_tokens | int or None | Maximum tokens to generate. Defaults to config.max_new_tokens. | None |

Returns:

| Type | Description |
|------|-------------|
| GenerationResult | Contains output token IDs and generation metrics (acceptance rate, tokens per second). |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If input_ids is not 2-D or has batch size 0. |

Notes

When config.sampling.variant is "rejection", the output distribution of this method is provably identical to autoregressive sampling from the target model (Leviathan et al., 2023, Theorem 1). This invariant is enforced by the KS-test gate in tests/integration/test_e2e_sampling.py at α=0.01, N=10,000. The guarantee is specific to Algorithm 1; the "typical" variant uses a different acceptance rule.

The generation loop follows these steps on each iteration:

  1. Bandit selects a draft model arm.
  2. Draft model proposes gamma tokens autoregressively.
  3. Target model scores all gamma positions in one forward pass with temperature applied before log-softmax (§7).
  4. Draft tokens are accepted or rejected: rejection sampling (Algorithm 1) for the "rejection" variant, typical_sample for the "typical" variant.
  5. Bandit is updated with the acceptance outcome.
  6. Metrics are updated.
  7. Accepted tokens are appended to the context.
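Step 4 can be sketched for a single sequence in probability space. This is a minimal illustration of Algorithm 1 (Leviathan et al., 2023), not FlashSpec's `rejection_sample`, which works on batched log-probability tensors. `rejection_step` is a hypothetical name, and the sketch assumes the target supplies gamma + 1 distributions so that a bonus token can be drawn when every draft is accepted.

```python
import random


def rejection_step(draft_tokens: list[int], draft_probs: list[list[float]],
                   target_probs: list[list[float]],
                   rng: random.Random) -> tuple[list[int], int]:
    """Verify gamma draft tokens; return (emitted tokens, n_accepted).

    draft_probs[i]  is the draft distribution q_i that draft_tokens[i] was drawn
                    from (so q_i[x_i] > 0).
    target_probs[i] is the target distribution p_i at the same position; the
                    extra entry target_probs[gamma] is used for the bonus token.
    """
    gamma = len(draft_tokens)
    vocab = range(len(target_probs[0]))
    out: list[int] = []
    for i, x in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        if rng.random() < min(1.0, p[x] / q[x]):    # accept with prob min(1, p/q)
            out.append(x)
            continue
        # First rejection: sample from norm(max(0, p - q)) and stop verifying.
        residual = [max(0.0, pj - qj) for pj, qj in zip(p, q)]
        out.append(rng.choices(vocab, weights=residual)[0])
        return out, i
    # All gamma drafts accepted: draw one bonus token from the last target row.
    out.append(rng.choices(vocab, weights=target_probs[gamma])[0])
    return out, gamma


rng = random.Random(0)
q = [[0.6, 0.4], [0.5, 0.5]]
p = [[0.3, 0.7], [0.5, 0.5], [0.9, 0.1]]
tokens, n_accepted = rejection_step([0, 1], q, p, rng)
```

The returned `n_accepted` plays the role of the per-sequence first-rejection index that `generate` feeds to the bandit, and the output always holds `n_accepted + 1` tokens, so each step advances the sequence by at least one token.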

Examples:

>>> result = engine.generate(input_ids, max_new_tokens=128)
>>> result.n_tokens_generated
128

GenerationResult dataclass

Result of a single SpeculativeEngine.generate call.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| output_ids | Tensor | Generated token IDs. Shape: (batch_size, n_new_tokens). | required |
| n_tokens_generated | int | Total number of new tokens in output_ids. | required |
| acceptance_rate | float | Mean draft-token acceptance rate (alpha) for this call. | required |
| tokens_per_second | float | Wall-clock tokens per second for this call. | required |
Source code in flashspec/engine/speculative.py
@dataclass(slots=True, frozen=True)
class GenerationResult:
    """Result of a single :meth:`SpeculativeEngine.generate` call.

    Parameters
    ----------
    output_ids : torch.Tensor
        Generated token IDs.  Shape: ``(batch_size, n_new_tokens)``.
    n_tokens_generated : int
        Total number of new tokens in ``output_ids``.
    acceptance_rate : float
        Mean draft-token acceptance rate (alpha) for this call.
    tokens_per_second : float
        Wall-clock tokens per second for this call.
    """

    output_ids: torch.Tensor
    n_tokens_generated: int
    acceptance_rate: float
    tokens_per_second: float

FlashSpecConfig

Bases: BaseModel

Top-level FlashSpec runtime configuration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| drafter_name | str | Registry key of the draft model to use (e.g. "llama3-1b"). | "llama3-1b" |
| target_name | str | HuggingFace model identifier (or local path) for the target model. | "meta-llama/Llama-3-8B-Instruct" |
| device | str | PyTorch device string, e.g. "cuda:0" or "cpu". | "cuda:0" |
| dtype | Literal["float32", "bfloat16", "float16"] | Compute dtype for model forward passes. | "bfloat16" |
| max_new_tokens | int | Maximum tokens to generate per call. Must be >= 1. | 512 |
| bandit | BanditConfig | Bandit algorithm settings. | BanditConfig() |
| sampling | SamplingConfig | Speculative sampling settings. | SamplingConfig() |
| metrics | MetricsConfig | Metrics collection settings. | MetricsConfig() |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If max_new_tokens < 1 or device is empty or whitespace-only. |

Source code in flashspec/utils/config.py
class FlashSpecConfig(BaseModel):
    """Top-level FlashSpec runtime configuration.

    Parameters
    ----------
    drafter_name : str
        Registry key of the draft model to use (e.g. ``"llama3-1b"``).
    target_name : str
        HuggingFace model identifier (or local path) for the target model.
    device : str
        PyTorch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
    dtype : {"float32", "bfloat16", "float16"}
        Compute dtype for model forward passes.  Defaults to ``"bfloat16"``.
    max_new_tokens : int
        Maximum tokens to generate per call.
    bandit : BanditConfig
        Bandit algorithm settings.
    sampling : SamplingConfig
        Speculative sampling settings.
    metrics : MetricsConfig
        Metrics collection settings.

    Raises
    ------
    ValueError
        If ``max_new_tokens`` < 1 or ``device`` is empty or whitespace-only.
    """

    drafter_name: str = "llama3-1b"
    target_name: str = "meta-llama/Llama-3-8B-Instruct"
    device: str = "cuda:0"
    dtype: Literal["float32", "bfloat16", "float16"] = "bfloat16"
    max_new_tokens: int = Field(default=512, ge=1)
    bandit: BanditConfig = Field(default_factory=BanditConfig)
    sampling: SamplingConfig = Field(default_factory=SamplingConfig)
    metrics: MetricsConfig = Field(default_factory=MetricsConfig)

    model_config = {"frozen": True}

    @model_validator(mode="after")
    def _validate_device_format(self) -> "FlashSpecConfig":
        """Reject an empty or whitespace-only device string."""
        if not self.device.strip():
            raise ValueError(
                f"'device' must be a non-empty string, got: {self.device!r}"
            )
        return self

BanditConfig

Bases: BaseModel

Configuration for the online bandit draft selector.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| strategy | Literal["ucb1", "thompson", "oracle"] | Which bandit algorithm to use. | "ucb1" |
| window_size | int | Number of recent rounds to include in windowed statistics. Set to 0 to disable windowing (use all history). | 500 |
| n_arms | int | Number of draft-model arms available to the bandit. | 2 |
| ucb_exploration_constant | float | Exploration constant for UCB1 (the factor multiplying √(2 log t / n_k)). The default of 1.0 matches the theoretical UCB1 derivation. | 1.0 |
| thompson_prior_alpha | float | Initial α of the Beta prior for Thompson sampling. | 1.0 |
| thompson_prior_beta | float | Initial β of the Beta prior for Thompson sampling. | 1.0 |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If n_arms < 1, window_size < 0, or any of ucb_exploration_constant, thompson_prior_alpha, or thompson_prior_beta is not positive. |

Source code in flashspec/utils/config.py
class BanditConfig(BaseModel):
    """Configuration for the online bandit draft selector.

    Parameters
    ----------
    strategy : {"ucb1", "thompson", "oracle"}
        Which bandit algorithm to use.
    window_size : int
        Number of recent rounds to include in windowed statistics.
        Set to 0 to disable windowing (use all history).
    n_arms : int
        Number of draft-model arms available to the bandit.
    ucb_exploration_constant : float
        Exploration constant for UCB1 (the factor multiplying √(2 log t / n_k)).
        Default 1.0 matches the theoretical UCB1 derivation.
    thompson_prior_alpha : float
        Initial α of the Beta prior for Thompson sampling.
    thompson_prior_beta : float
        Initial β of the Beta prior for Thompson sampling.

    Raises
    ------
    ValueError
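        If ``n_arms`` < 1, ``window_size`` < 0, or any of
        ``ucb_exploration_constant``, ``thompson_prior_alpha``, or
        ``thompson_prior_beta`` is not positive.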
    """

    strategy: Literal["ucb1", "thompson", "oracle"] = "ucb1"
    window_size: int = Field(default=500, ge=0)
    n_arms: int = Field(default=2, ge=1)
    ucb_exploration_constant: float = Field(default=1.0, gt=0.0)
    thompson_prior_alpha: float = Field(default=1.0, gt=0.0)
    thompson_prior_beta: float = Field(default=1.0, gt=0.0)

    model_config = {"frozen": True}

SamplingConfig

Bases: BaseModel

Configuration for speculative sampling.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| gamma | int | Speculation length: number of draft tokens proposed per step. | 4 |
| temperature | float | Sampling temperature applied to target logits before log-softmax. Must be positive; use 1.0 for unscaled probabilities. | 1.0 |
| top_p | float | Nucleus sampling threshold. Set to 1.0 to disable. SpeculativeEngine.generate passes this value as typical_p to the "typical" variant; the "rejection" variant does not read it. | 1.0 |
| seed | int or None | Random seed for reproducibility. None means non-deterministic. | None |
| variant | Literal["rejection", "typical"] | Sampling algorithm variant. | "rejection" |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If gamma < 1, temperature ≤ 0, or top_p not in (0, 1]. |

Source code in flashspec/utils/config.py
class SamplingConfig(BaseModel):
    """Configuration for speculative sampling.

    Parameters
    ----------
    gamma : int
        Speculation length — number of draft tokens proposed per step.
    temperature : float
        Sampling temperature applied to target logits before log-softmax.
        Must be positive; use 1.0 for unscaled probabilities.
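    top_p : float
        Nucleus sampling threshold.  Set to 1.0 to disable.
        ``SpeculativeEngine.generate`` passes this value as ``typical_p`` to
        the ``"typical"`` variant; the ``"rejection"`` variant does not read it.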
    seed : int or None
        Random seed for reproducibility.  ``None`` means non-deterministic.
    variant : {"rejection", "typical"}
        Sampling algorithm variant.

    Raises
    ------
    ValueError
        If ``gamma`` < 1, ``temperature`` ≤ 0, or ``top_p`` not in (0, 1].
    """

    gamma: int = Field(default=4, ge=1)
    temperature: float = Field(default=1.0, gt=0.0)
    top_p: float = Field(default=1.0, gt=0.0, le=1.0)
    seed: int | None = None
    variant: Literal["rejection", "typical"] = "rejection"

    model_config = {"frozen": True}
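The temperature rule can be sketched in pure Python: logits are divided by the temperature first and only then normalised with a max-shifted (numerically stable) log-softmax, so the result is always a valid log-distribution. The helper name is hypothetical; in FlashSpec this scaling happens inside the target's `score_draft`.

```python
import math


def log_softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Scale logits by 1/temperature, then apply a numerically stable log-softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")    # same rule as SamplingConfig
    scaled = [z / temperature for z in logits]
    m = max(scaled)                                     # shift by the max for stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]


for t in (0.5, 1.0, 2.0):
    probs = [math.exp(lp) for lp in log_softmax_with_temperature([2.0, 1.0, 0.0], t)]
    print(t, [round(pr, 3) for pr in probs])
```

Temperatures below 1 sharpen the distribution toward the largest logit and temperatures above 1 flatten it; at 1.0 the result is the plain log-softmax of the logits.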

MetricsConfig

Bases: BaseModel

Configuration for metrics collection.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| track_acceptance | bool | Whether to track per-step token acceptance rates. | True |
| track_throughput | bool | Whether to track tokens/s and MFU. | True |
| track_latency | bool | Whether to track p50/p95/p99 step latency. | True |
| latency_window | int | Rolling window size for latency percentile computation. Must be >= 10. | 1000 |
Source code in flashspec/utils/config.py
class MetricsConfig(BaseModel):
    """Configuration for metrics collection.

    Parameters
    ----------
    track_acceptance : bool
        Whether to track per-step token acceptance rates.
    track_throughput : bool
        Whether to track tokens/s and MFU.
    track_latency : bool
        Whether to track p50/p95/p99 step latency.
    latency_window : int
        Rolling window size for latency percentile computation.  Must be >= 10.
    """

    track_acceptance: bool = True
    track_throughput: bool = True
    track_latency: bool = True
    latency_window: int = Field(default=1000, ge=10)

    model_config = {"frozen": True}
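A rolling-window percentile tracker along these lines is one way `latency_window` could be used. The class is hypothetical (FlashSpec's `LatencyTracker` is only shown here through `start()` and `stop()`): a `deque(maxlen=window)` discards the oldest step once the window is full, and percentiles use the nearest-rank method.

```python
import math
import time
from collections import deque


class RollingLatency:
    """Hypothetical step-latency tracker over the last ``window`` steps."""

    def __init__(self, window: int = 1000) -> None:
        if window < 10:                                   # mirrors Field(ge=10)
            raise ValueError("window must be >= 10")
        self._samples: deque[float] = deque(maxlen=window)   # oldest dropped when full
        self._t0: float | None = None

    def start(self) -> None:
        self._t0 = time.perf_counter()

    def stop(self) -> None:
        if self._t0 is None:
            raise RuntimeError("stop() called before start()")
        self.record(time.perf_counter() - self._t0)
        self._t0 = None

    def record(self, seconds: float) -> None:
        self._samples.append(seconds)

    def percentile(self, q: float) -> float:
        """Nearest rank: smallest sample with at least q% of the window <= it."""
        data = sorted(self._samples)
        if not data:
            raise ValueError("no latency samples recorded")
        rank = max(1, math.ceil(q / 100 * len(data)))
        return data[rank - 1]


lat = RollingLatency(window=10)
for ms in range(1, 21):          # 20 steps of 1..20 ms; only the last 10 are kept
    lat.record(ms / 1000)
print(lat.percentile(50), lat.percentile(95), lat.percentile(99))   # 0.015 0.02 0.02
```

With a small window, p95 and p99 collapse onto the same sample, which is one reason a larger window (the default is 1000) gives more informative tail percentiles.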

register(name)

Class decorator that registers a draft model under name.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str | Registry key (e.g. "llama3-1b"). Must be unique across all registered drafters. | required |

Returns:

| Type | Description |
|------|-------------|
| Callable | The decorator function; it returns the decorated class unchanged, so the class can still be used normally after decoration. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If name is already registered. |

Notes

The registry is a module-level dict _REGISTRY. Names are case-sensitive. External packages can also register drafters via Python entry points under the flashspec.drafters group without modifying this module (see §4.4 of AGENTS.md).

Examples:

>>> @register("llama3-1b")
... class Llama3_1B_Drafter:
...     ...
Source code in flashspec/engine/drafter.py
def register(name: str) -> Any:
    """Class decorator that registers a draft model under ``name``.

    Parameters
    ----------
    name : str
        Registry key (e.g. ``"llama3-1b"``).  Must be unique across all
        registered drafters.

    Returns
    -------
    Callable
        The decorator function; returns the decorated class unchanged so
        the class can still be used normally after decoration.

    Raises
    ------
    ValueError
        If ``name`` is already registered.

    Notes
    -----
    The registry is a module-level dict ``_REGISTRY``.  Names are
    case-sensitive.  External packages can also register drafters via Python
    entry points under the ``flashspec.drafters`` group without modifying
    this module (see §4.4 of AGENTS.md).

    Examples
    --------
    >>> @register("llama3-1b")
    ... class Llama3_1B_Drafter:
    ...     ...
    """
    def _decorator(cls: type) -> type:
        if name in _REGISTRY:
            raise ValueError(
                f"Drafter '{name}' is already registered. "
                "Use a unique name or unregister the existing entry first."
            )
        _REGISTRY[name] = cls
        logger.debug("Drafter registered", extra={"name": name, "class": cls.__name__})
        return cls
    return _decorator
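The registry pattern is small enough to reproduce in isolation. The sketch below re-creates `register`, `get_drafter`, and `list_drafters` over a local dict, leaving out entry-point loading, to show three documented behaviours: the decorated class is returned unchanged, a duplicate name raises `ValueError`, and `list_drafters` returns a sorted snapshot. `ToyDrafter` and its `device` argument are hypothetical.

```python
from typing import Callable

# Hypothetical local registry for this sketch (FlashSpec keeps its own _REGISTRY).
_REGISTRY: dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Return a class decorator that stores the class under ``name``."""
    def _decorator(cls: type) -> type:
        if name in _REGISTRY:
            raise ValueError(f"Drafter '{name}' is already registered.")
        _REGISTRY[name] = cls
        return cls                      # the class itself is returned unchanged
    return _decorator


def get_drafter(name: str) -> type:
    """Return the registered class (not an instance) or raise KeyError."""
    if name not in _REGISTRY:
        raise KeyError(f"No drafter registered under '{name}'. Available: {sorted(_REGISTRY)}")
    return _REGISTRY[name]


def list_drafters() -> list[str]:
    """Return a sorted snapshot of the registered names."""
    return sorted(_REGISTRY)


@register("toy-drafter")
class ToyDrafter:
    """Hypothetical drafter; the real constructor signature may differ."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device


drafter = get_drafter("toy-drafter")(device="cpu")   # look up the class, then instantiate
```

Because `list_drafters` builds a new sorted list on every call, a list captured earlier does not change when more drafters are registered later.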

get_drafter(name)

Look up a registered draft model class by name.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str | Registry key, as used in register. | required |

Returns:

| Type | Description |
|------|-------------|
| type[DraftModel] | The registered class (not an instance; the caller must instantiate it). |

Raises:

| Type | Description |
|------|-------------|
| KeyError | If name is not found in the registry after loading all entry points. |

Notes

Triggers _load_entry_points on every call so that external packages installed after the interpreter started are discovered automatically. The lookup is O(1) dict access after entry-point loading.

Examples:

>>> DrClass = get_drafter("llama3-1b")
>>> drafter = DrClass(device="cuda:0")
Source code in flashspec/engine/drafter.py
def get_drafter(name: str) -> type[DraftModel]:
    """Look up a registered draft model class by name.

    Parameters
    ----------
    name : str
        Registry key, as used in :func:`register`.

    Returns
    -------
    type[DraftModel]
        The registered class (not an instance; caller must instantiate it).

    Raises
    ------
    KeyError
        If ``name`` is not found in the registry after loading all entry points.

    Notes
    -----
    Triggers :func:`_load_entry_points` on every call so that external packages
    installed after the interpreter started are discovered automatically.
    The lookup is O(1) dict access after entry-point loading.

    Examples
    --------
    >>> DrClass = get_drafter("llama3-1b")
    >>> drafter = DrClass(device="cuda:0")
    """
    _load_entry_points()
    if name not in _REGISTRY:
        available = list(_REGISTRY.keys())
        raise KeyError(
            f"No drafter registered under '{name}'. Available: {available}"
        )
    return _REGISTRY[name]

list_drafters()

Return all registered drafter names in alphabetical order.

Returns:

| Type | Description |
|------|-------------|
| list[str] | Sorted list of registry keys. |

Notes

Triggers _load_entry_points to discover any externally registered drafters before returning the list. The list reflects the state of the registry at call time; it is not a live view.

Examples:

>>> list_drafters()
['llama3-1b', 'llama3-68m']
Source code in flashspec/engine/drafter.py
def list_drafters() -> list[str]:
    """Return all registered drafter names in alphabetical order.

    Returns
    -------
    list[str]
        Sorted list of registry keys.

    Notes
    -----
    Triggers :func:`_load_entry_points` to discover any externally registered
    drafters before returning the list.  The list reflects the state of the
    registry at call time; it is not a live view.

    Examples
    --------
    >>> list_drafters()
    ['llama3-1b', 'llama3-68m']
    """
    _load_entry_points()
    return sorted(_REGISTRY.keys())