citeformer.backends.hf

HuggingFace transformers backend with grammar-level citation enforcement.

The flagship backend: loads a transformers causal LM, builds the §10.1 citation grammar for the in-scope sources, compiles it with XGrammar’s tokenizer-aware compiler, and masks logits at every decode step so the model cannot emit a citation marker that refers to a non-existent source.

Requires the hf extra: pip install citeformer[hf].

The integration surface is deliberately narrow: xgrammar does the masking, transformers does the sampling, and this module only wires them together via xgrammar.contrib.hf.LogitsProcessor. The LogitsProcessor is stateful per generation, so a new one is constructed for every generate() call (an xgrammar constraint, not ours).
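The wiring described above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not citeformer's actual implementation: `generate_constrained` and its parameters are hypothetical names, the grammar is taken as an already-built EBNF string, and the xgrammar calls follow that library's documented transformers integration. The `xgrammar` import is deferred so the module loads even without the hf extra.

```python
from typing import Any


def generate_constrained(
    model: Any,
    tokenizer: Any,
    prompt: str,
    grammar_ebnf: str,
    max_new_tokens: int = 256,
) -> str:
    """Hypothetical helper: run one grammar-constrained generate() call."""
    import xgrammar as xgr  # deferred: only needed when the hf extra is installed

    # Tokenizer-aware compilation: the grammar is compiled against this
    # model's vocabulary so that masking operates on token ids.
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(
        tokenizer, vocab_size=model.config.vocab_size
    )
    compiler = xgr.GrammarCompiler(tokenizer_info)
    compiled = compiler.compile_grammar(grammar_ebnf)

    # The processor tracks matcher state for a single generation,
    # so a fresh one is built on every call instead of being cached.
    processor = xgr.contrib.hf.LogitsProcessor(compiled)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        logits_processor=[processor],
    )

    # Drop the prompt tokens and decode only the newly generated ones.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
```

In this sketch the compiled grammar could be reused across calls; only the processor has to be rebuilt each time.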

Module Contents

Classes

HFBackend

Transformers + XGrammar backend with logit-level citation enforcement.

API

class citeformer.backends.hf.HFBackend(model: str, *, device: str | None = None, dtype: str = 'auto')

Bases: citeformer.backends.base.Backend

Transformers + XGrammar backend with logit-level citation enforcement.

Citation markers are structurally unforgeable on this backend: at every decode step XGrammar masks tokens that would produce an [N] where N is not a valid source index, so sampling simply cannot traverse a path that emits a fabricated citation.
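The guarantee is enforced during decoding, but the invariant it produces can also be stated as a post-hoc check, which may be handy in tests. The sketch below assumes markers have the literal form `[N]` with decimal digits; `invalid_markers` is a hypothetical helper, not part of the citeformer API.

```python
import re

# Matches citation markers such as "[1]" or "[12]".
_MARKER = re.compile(r"\[(\d+)\]")


def invalid_markers(text: str, n_sources: int) -> list[int]:
    """Return every cited index that falls outside 1..n_sources."""
    indices = (int(match) for match in _MARKER.findall(text))
    return [n for n in indices if not 1 <= n <= n_sources]
```

For output produced by this backend the returned list should always be empty; a non-empty result on text from another backend points to a fabricated citation.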

Example:

    >>> from citeformer import Citeformer, Source
    >>> from citeformer.backends.hf import HFBackend
    >>> sources = [Source(metadata={"id": "a", "type": "book"}, content="…")]
    >>> backend = HFBackend(model="gpt2")
    >>> cf = Citeformer(backend=backend)
    >>> result = cf.generate(prompt="Describe the book.", sources=sources)

Attributes:
    model_name: HF model identifier passed to from_pretrained.
    device: Torch device the model lives on (cuda / mps / cpu).
    tokenizer: The loaded tokenizer.
    model: The loaded causal LM.

Initialization

Load the model + tokenizer and prepare the XGrammar compiler.

Args:
    model: HuggingFace model identifier (e.g. "gpt2", "microsoft/Phi-3.5-mini-instruct", or a local path).
    device: Torch device. If None, auto-detect CUDA > MPS > CPU.
    dtype: "auto" picks bfloat16 on CUDA, float16 on MPS, float32 on CPU. Explicit options: "bf16", "fp16", "fp32".

Raises:
    ImportError: If citeformer[hf] extras aren’t installed.
    ValueError: If dtype is unrecognized.
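The device and dtype resolution rules above can be sketched as two pure functions. This is an illustrative sketch, not the backend's code: `pick_device` and `resolve_dtype` are hypothetical names, and dtypes are returned as strings so that the example runs without torch installed.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Auto-detect order from the docs: CUDA > MPS > CPU.

    In real code the two flags would typically come from
    torch.cuda.is_available() and torch.backends.mps.is_available().
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def resolve_dtype(dtype: str, device: str) -> str:
    """Map the documented dtype option to a torch dtype name."""
    explicit = {"bf16": "bfloat16", "fp16": "float16", "fp32": "float32"}
    auto = {"cuda": "bfloat16", "mps": "float16", "cpu": "float32"}
    if dtype == "auto":
        # Assumption: a device string like "cuda:0" resolves like "cuda".
        return auto[device.split(":")[0]]
    if dtype in explicit:
        return explicit[dtype]
    raise ValueError(f"unrecognized dtype: {dtype!r}")
```

For example, `resolve_dtype("auto", pick_device(False, True))` yields `"float16"`, matching the MPS rule.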

model_name: str

None

device: str

None

tokenizer: Any

None

model: Any

None

generate(prompt: str, sources: list[citeformer.core.Source], policy: citeformer.core.Policy, **options: Any) → str

Generate text with grammar-masked decoding.

Args:
    prompt: User prompt. The caller is responsible for assembling any RAG context from sources into the prompt string (no implicit stitching; that’s a later-phase helper concern).
    sources: Sources in scope. Position (1-indexed) determines the citation id the model can emit. Must be non-empty.
    policy: Citation enforcement policy.
    **options: Sampling and grammar overrides. Unknown options are silently ignored.
        max_new_tokens: default 256.
        temperature: default 0.7.
        max_content_chars: default DEFAULT_MAX_CONTENT_CHARS; REQUIRED-policy soft progression bound (see ADR-009); pass None to disable.

Returns: The generated text as a string, containing only [N] markers where 1 <= N <= len(sources). Fabrication is structurally impossible.

Raises: ValueError: If sources is empty.
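The option handling described above (documented defaults, unknown keys ignored, and an explicit None that differs from "not passed") can be sketched as below. `resolve_options` is a hypothetical helper, and the real value of DEFAULT_MAX_CONTENT_CHARS is not stated here, so the sketch takes it as a parameter instead of guessing it.

```python
from typing import Any

# Sentinel so that an explicit None ("disable the bound") can be told
# apart from the option being absent ("use the default").
_UNSET = object()


def resolve_options(
    options: dict[str, Any], default_max_content_chars: int
) -> dict[str, Any]:
    """Apply the documented defaults; keys not listed here are dropped."""
    max_chars = options.get("max_content_chars", _UNSET)
    return {
        "max_new_tokens": options.get("max_new_tokens", 256),
        "temperature": options.get("temperature", 0.7),
        "max_content_chars": (
            default_max_content_chars if max_chars is _UNSET else max_chars
        ),
    }
```

Because unknown options are silently dropped, a misspelled key such as `max_new_token` would fall back to the default without raising.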

stream(prompt: str, sources: list[citeformer.core.Source], policy: citeformer.core.Policy, **options: Any) → collections.abc.Iterator[str]

Stream text chunks as the model decodes them.

Runs model.generate on a background thread and yields decoded chunks via transformers’ TextIteratorStreamer. Grammar enforcement is identical to generate(): the XGrammar LogitsProcessor is wired in whether or not the consumer is streaming.

Args:
    prompt: See generate().
    sources: See generate().
    policy: See generate().
    **options: Same options as generate() (max_new_tokens, temperature, max_content_chars), plus timeout (seconds, default 60), which caps how long a single chunk read blocks. This is useful for failing loudly if the background thread hangs.

Yields: Decoded text chunks in order. Joining them reproduces what generate() would have returned.

Raises: ValueError: If sources is empty.
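The background-thread pattern and the per-chunk timeout can be illustrated with the standard library alone. This is a stand-in sketch, not the backend's code: a plain `queue.Queue` plays the role of TextIteratorStreamer, `produce` stands in for the model.generate call, and `stream_chunks` is a hypothetical name.

```python
import queue
import threading
from collections.abc import Callable, Iterable, Iterator

# Marks the end of the stream on the queue.
_DONE = object()


def stream_chunks(
    produce: Callable[[], Iterable[str]], timeout: float = 60.0
) -> Iterator[str]:
    """Run produce() on a background thread and yield its chunks in order."""
    chunks: queue.Queue = queue.Queue()

    def worker() -> None:
        try:
            for chunk in produce():
                chunks.put(chunk)
        finally:
            # Always signal completion, even if produce() raised.
            chunks.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        # Raises queue.Empty if no chunk arrives within `timeout` seconds,
        # so a hung producer fails loudly instead of blocking forever.
        item = chunks.get(timeout=timeout)
        if item is _DONE:
            return
        yield item
```

Joining the yielded chunks reproduces the full text, which mirrors the relationship between stream() and generate() described above.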