Skip to content

Instantly share code, notes, and snippets.

@dcolley
Created May 20, 2026 23:20
Show Gist options
  • Select an option

  • Save dcolley/9383eccefbf6a661df0823ddbb2ac4d1 to your computer and use it in GitHub Desktop.

Select an option

Save dcolley/9383eccefbf6a661df0823ddbb2ac4d1 to your computer and use it in GitHub Desktop.
nvidia/Nemotron-Labs-Diffusion-14B running on an NVIDIA DGX Spark

venv

uv venv my_venv
source my_venv/bin/activate
uv pip install -r requirements.txt

requirements.txt

# Nemotron-Labs-Diffusion-14B API Server
#
# Core dependencies
torch==2.12.0
transformers==5.9.0
accelerate==1.13.0
safetensors==0.7.0
tokenizers==0.22.2
huggingface-hub==1.15.0
hf-xet==1.5.0

# API server
fastapi==0.136.1
uvicorn==0.47.0
pydantic==2.13.4
starlette==1.0.0

# NVIDIA CUDA (cu13 for spark2 GB10/ARM64)
nvidia-cublas==13.1.1.3
nvidia-cuda-cupti==13.0.85
nvidia-cuda-nvrtc==13.0.88
nvidia-cuda-runtime==13.0.96
nvidia-cudnn-cu13==9.20.0.48
nvidia-cufft==12.0.0.61
nvidia-cufile==1.15.1.6
nvidia-curand==10.4.0.35
nvidia-cusolver==12.0.4.66
nvidia-cusparse==12.6.3.3
nvidia-cusparselt-cu13==0.8.1
nvidia-nccl-cu13==2.29.7
nvidia-nvjitlink==13.0.88
nvidia-nvshmem-cu13==3.4.5
nvidia-nvtx==13.0.85
cuda-bindings==13.2.0
cuda-pathfinder==1.5.4
cuda-toolkit==13.0.2
triton==3.7.0

# Supporting libraries
numpy==2.4.6
filelock==3.29.0
fsspec==2026.4.0
packaging==26.2
psutil==7.2.2
pyyaml==6.0.3
regex==2026.5.9
tqdm==4.67.3
typing-extensions==4.15.0
setuptools==81.0.0

# HTTP/CLI deps (pulled in by huggingface-hub, fastapi)
httpx==0.28.1
httpcore==1.0.9
h11==0.16.0
certifi==2026.5.20
idna==3.15
anyio==4.13.0
click==8.4.0
jinja2==3.1.6
markupsafe==3.0.3
networkx==3.6.1
pydantic-core==2.46.4
pygments==2.20.0
rich==15.0.0
shellingham==1.5.4
typer==0.25.1
typing-inspection==0.4.2
annotated-types==0.7.0
annotated-doc==0.0.4
markdown-it-py==4.2.0
mdurl==0.1.2
mpmath==1.3.0
sympy==1.14.0

Python server (nemotron.py — the script the systemd unit below runs)

"""
Nemotron-Labs-Diffusion-14B OpenAI-compatible API server.

Supports both:
  - ar_generate: autoregressive mode (fast, token-by-token)
  - generate: block-diffusion mode (bidirectional denoising)

Endpoints:
  GET  /health
  GET  /v1/models
  GET  /v1/models/{model_id}
  POST /v1/chat/completions  (stream + non-stream)
  POST /v1/completions       (stream + non-stream; prompt is sent as one user chat turn)
"""
import os
import sys
import time
import uuid
import json
import traceback
from typing import Any, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import AutoTokenizer, AutoModel

# ── Config ──────────────────────────────────────────────────────────
MODEL_NAME = os.getenv("MODEL_NAME", "nvidia/Nemotron-Labs-Diffusion-14B")
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
# "ar" = autoregressive (fast), "diffusion" = block-diffusion (quality).
# _run_chat treats every value other than "ar" as diffusion, so normalise the
# value and reject anything unknown here; otherwise a typo such as "AR" or
# "difusion" silently selects diffusion mode.
DEFAULT_MODE = os.getenv("NEMOTRON_MODE", "ar").strip().lower()
if DEFAULT_MODE not in ("ar", "diffusion"):
    raise ValueError(
        f"NEMOTRON_MODE must be 'ar' or 'diffusion', got {DEFAULT_MODE!r}"
    )
# block_length for diffusion mode. The budget is rounded up to a multiple of
# it, so it must be positive. Failing here is clearer than a
# ZeroDivisionError on the first request.
BLOCK_LENGTH = int(os.getenv("NEMOTRON_BLOCK_LENGTH", "32"))
if BLOCK_LENGTH <= 0:
    raise ValueError(
        f"NEMOTRON_BLOCK_LENGTH must be a positive integer, got {BLOCK_LENGTH}"
    )

app = FastAPI(title="Nemotron OpenAI-Compatible API")

# ── Load model & tokenizer ──────────────────────────────────────────
# NOTE: trust_remote_code=True executes Python code shipped inside the model
# repository; only point MODEL_NAME at repositories you trust.
print(f"Loading tokenizer for {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

print(f"Loading model {MODEL_NAME} ...")
model = AutoModel.from_pretrained(
    MODEL_NAME,
    device_map="auto",      # let Transformers/Accelerate place the weights
    dtype=torch.bfloat16,   # load weights in bf16
    trust_remote_code=True,
)
# Serving only: switch the module to evaluation mode (e.g. disables dropout).
model.eval()
print("Model loaded successfully.")

# ── Pydantic models ─────────────────────────────────────────────────
class Message(BaseModel):
    # One chat turn. Keys beyond role/content are kept rather than rejected
    # (extra="allow").
    model_config = ConfigDict(extra="allow")

    role: Literal[
        "system",
        "user",
        "assistant",
        "tool",
        "developer",
    ]
    # Deliberately untyped; presumably a string, a list of content parts, or
    # null depending on the client — TODO confirm.
    content: Any


class ChatCompletionRequest(BaseModel):
    # Body of POST /v1/chat/completions. Unknown OpenAI parameters are accepted
    # (extra="allow") so clients that send them are not rejected.
    model_config = ConfigDict(extra="allow")

    model: Union[str, None] = None
    messages: List[Message]

    # Sampling/output controls. The handler reads temperature, stream and stop;
    # the others are accepted for client compatibility only.
    temperature: Union[float, None] = 0.2
    top_p: Union[float, None] = None
    n: Union[int, None] = 1
    stream: Union[bool, None] = False
    stop: Union[str, List[str], None] = None
    presence_penalty: Union[float, None] = None
    frequency_penalty: Union[float, None] = None
    user: Union[str, None] = None
    response_format: Union[dict, None] = None

    # Token budget; the handler prefers max_completion_tokens when it is set.
    max_tokens: Union[int, None] = None
    max_completion_tokens: Union[int, None] = None


class CompletionRequest(BaseModel):
    # Body of POST /v1/completions. Unknown OpenAI parameters are accepted
    # (extra="allow") so clients that send them are not rejected.
    model_config = ConfigDict(extra="allow")

    model: Union[str, None] = None
    # A single prompt string or a list of prompt strings.
    prompt: Union[str, List[str]]

    # Sampling/output controls. The handler reads temperature, stream and stop;
    # the others are accepted for client compatibility only.
    temperature: Union[float, None] = 0.2
    top_p: Union[float, None] = None
    n: Union[int, None] = 1
    stream: Union[bool, None] = False
    stop: Union[str, List[str], None] = None
    presence_penalty: Union[float, None] = None
    frequency_penalty: Union[float, None] = None
    user: Union[str, None] = None

    # Token budget; the handler prefers max_completion_tokens when it is set.
    max_tokens: Union[int, None] = None
    max_completion_tokens: Union[int, None] = None


# ── Helpers ──────────────────────────────────────────────────────────

def _content_to_text(content: Any) -> str:
    """Flatten one message's ``content`` into plain text for the chat template."""
    if content is None:
        # e.g. an assistant turn that only carried tool calls; str() would
        # otherwise inject the literal word "None" into the prompt.
        return ""
    if isinstance(content, list):
        # OpenAI-style content parts: keep the text parts, joined by newlines.
        # Non-text parts (images, audio, ...) are dropped because only text is
        # passed to the tokenizer here.
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def build_prompt(messages: List[Message]) -> str:
    """Render *messages* into a prompt string via the tokenizer's chat template.

    Each message's content is first flattened to plain text:

    * ``None`` becomes ``""``;
    * a list of content parts keeps only its text parts, joined by newlines;
    * anything else is converted with ``str()``.

    Args:
        messages: Chat turns; only ``role`` and ``content`` are used.

    Returns:
        The templated prompt text, ending with the assistant generation prefix.
    """
    msg_dicts = [
        {"role": m.role, "content": _content_to_text(m.content)} for m in messages
    ]
    # apply_chat_template with add_generation_prompt=True appends the
    # assistant prefix so the model continues from there.
    return tokenizer.apply_chat_template(
        msg_dicts,
        tokenize=False,
        add_generation_prompt=True,
    )


def get_stop_ids(stop: Optional[Union[str, List[str]]] = None) -> List[int]:
    """Return token IDs usable as single-token stop criteria.

    The tokenizer's EOS id (when defined) always comes first. A stop string is
    converted only if it encodes to exactly one token. Multi-token sequences
    cannot be expressed as a set of stop IDs: treating each of their sub-tokens
    as a stop would halt generation on any of those tokens on its own. Such
    sequences need string-level truncation instead (as ``_run_chat`` does).

    Args:
        stop: A stop string, a list of them, or ``None``.

    Returns:
        De-duplicated token IDs, in first-seen order.
    """
    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    if isinstance(stop, str):
        stop = [stop]
    for s in stop or []:
        ids = tokenizer.encode(s, add_special_tokens=False)
        if len(ids) == 1:
            stop_ids.append(ids[0])
    # dict.fromkeys de-duplicates while keeping order (set() does not).
    return list(dict.fromkeys(stop_ids))


def generate_ar(prompt_ids: torch.Tensor, max_new_tokens: int,
                temperature: float, eos_token_id: Optional[int]):
    """Decode autoregressively with the model's ``ar_generate``.

    A thin pass-through: the result is whatever ``ar_generate`` returns, which
    ``_run_chat`` unpacks as ``(output_ids, nfe)``.
    """
    ar_kwargs = {
        "prompt_ids": prompt_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "eos_token_id": eos_token_id,
    }
    return model.ar_generate(**ar_kwargs)


def generate_diffusion(prompt_ids: torch.Tensor, max_new_tokens: int,
                       temperature: float, eos_token_id: Optional[int],
                       block_length: int = 32):
    """Decode with the model's block-diffusion ``generate``.

    The token budget is rounded up to a multiple of *block_length* before the
    call. Up to ``block_length - 1`` tokens more than requested may therefore
    be generated.
    """
    # Ceiling division via negated floor division: for a positive block_length
    # this is the smallest multiple >= max_new_tokens (unchanged if already one).
    padded_budget = -(-max_new_tokens // block_length) * block_length
    return model.generate(
        prompt_ids=prompt_ids,
        max_new_tokens=padded_budget,
        block_length=block_length,
        temperature=temperature,
        eos_token_id=eos_token_id,
    )


# ── Endpoints ───────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness probe: report status, the served model and the generation mode."""
    return dict(status="ok", model=MODEL_NAME, mode=DEFAULT_MODE)


@app.get("/v1/models")
def list_models():
    """List the single served model in OpenAI ``/v1/models`` format."""
    card = {
        "id": MODEL_NAME,
        "object": "model",
        # Stamped at request time, not at load time.
        "created": int(time.time()),
        "owned_by": "local",
    }
    return {"object": "list", "data": [card]}


@app.get("/v1/models/{model_id:path}")
def get_model(model_id: str):
    """Return the model card for *model_id*; 404 unless it is the served model."""
    # The ":path" converter lets the id contain "/" (e.g. "org/name").
    if model_id == MODEL_NAME:
        return {
            "id": MODEL_NAME,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "local",
        }
    raise HTTPException(status_code=404, detail="Model not found")


@app.post("/v1/completions")
def completions(req: CompletionRequest):
    """Handle a legacy OpenAI text-completions request.

    The prompt is wrapped as a single user message and run through the same
    pipeline as ``/v1/chat/completions``, so the chat template is applied.
    A list prompt is joined with newlines into one prompt rather than being
    treated as a batch.

    Legacy completions clients read ``choices[i].text``, which the chat-shaped
    result lacks. For non-streaming success responses, each choice therefore
    also gets a ``text`` field mirroring ``message.content``. ``message`` is
    kept so existing callers continue to work. Streaming and error responses
    (non-dict results) are returned unchanged.
    """
    prompt_text = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt)
    messages = [Message(role="user", content=prompt_text)]
    result = _run_chat(messages, req.model, req.max_completion_tokens or req.max_tokens,
                       req.temperature, req.stream, req.stop)
    if isinstance(result, dict):
        for choice in result.get("choices", []):
            choice.setdefault("text", choice.get("message", {}).get("content", ""))
    return result


@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    """Handle an OpenAI chat-completions request (streaming or not)."""
    # The newer max_completion_tokens takes precedence when set (non-zero).
    token_budget = req.max_completion_tokens or req.max_tokens
    return _run_chat(
        req.messages,
        req.model,
        token_budget,
        req.temperature,
        req.stream,
        req.stop,
    )


def _run_chat(messages, model_id, max_tokens, temperature, stream, stop):
    """Generate an assistant reply for *messages* and wrap it OpenAI-style.

    This is the shared backend of ``/v1/chat/completions`` and
    ``/v1/completions``. The whole reply is generated up front; the streaming
    path only replays it afterwards.

    Args:
        messages: Conversation turns, rendered with ``build_prompt``.
        model_id: Name echoed in the response; ``None`` falls back to
            ``MODEL_NAME``.
        max_tokens: Token budget; ``None`` or ``0`` means 256. Negative values
            are rejected with a 400.
        temperature: Sampling temperature; ``None`` is treated as ``0.0``.
        stream: When truthy, return an SSE ``StreamingResponse``.
        stop: Stop string or list of stop strings. The decoded reply is cut at
            the earliest occurrence of any of them. This is applied after the
            fact; generation itself is not stopped early.

    Returns:
        A ``chat.completion`` dict, a ``StreamingResponse`` when *stream* is
        set, or a 400 ``JSONResponse`` wrapping any exception raised.
    """
    try:
        max_new_tokens = max_tokens or 256
        if max_new_tokens < 0:
            raise ValueError(
                f"max_tokens must be a positive integer, got {max_tokens}"
            )
        temperature = 0.0 if temperature is None else temperature
        eos_token_id = tokenizer.eos_token_id

        prompt_text = build_prompt(messages)
        enc = tokenizer(prompt_text, return_tensors="pt")
        prompt_ids = enc["input_ids"].to(model.device)
        prompt_len = prompt_ids.shape[1]

        with torch.inference_mode():
            # The second return value (nfe) is not reported to clients.
            if DEFAULT_MODE == "ar":
                output_ids, _nfe = generate_ar(
                    prompt_ids, max_new_tokens, temperature, eos_token_id,
                )
            else:
                output_ids, _nfe = generate_diffusion(
                    prompt_ids, max_new_tokens, temperature, eos_token_id,
                    block_length=BLOCK_LENGTH,
                )

        # output_ids includes the prompt; keep only the generated part.
        # Diffusion mode rounds the budget up to a multiple of BLOCK_LENGTH, so
        # also cap the slice at the budget the client asked for. Otherwise both
        # the content and usage.completion_tokens could exceed max_tokens.
        generated_ids = output_ids[0, prompt_len:prompt_len + max_new_tokens]
        content = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Truncate at the earliest stop sequence, if any were provided.
        if stop:
            stop_strs = [stop] if isinstance(stop, str) else stop
            hits = [i for i in (content.find(s) for s in stop_strs) if i >= 0]
            if hits:
                content = content[:min(hits)]

        completion_tokens = len(generated_ids)
        chat_id = f"chatcmpl-{uuid.uuid4().hex}"

        if stream:
            return _stream_response(chat_id, content, model_id or MODEL_NAME,
                                     prompt_len, completion_tokens)

        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_id or MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_len + completion_tokens,
            },
        }

    except Exception as e:
        # Request boundary: log the traceback and report the error to the
        # client instead of letting it propagate.
        traceback.print_exc()
        return JSONResponse(
            status_code=400,
            content={
                "error": {
                    "message": str(e),
                    "type": "bad_request_error",
                    "code": 400,
                }
            },
        )


def _stream_response(chat_id: str, content: str, model_name: str,
                      prompt_tokens: int, completion_tokens: int):
    """Replay an already-generated reply as an OpenAI-style SSE stream.

    Generation is complete before this is called, so the stream is simulated:

    * a first chunk announces ``role: assistant``;
    * the content follows, split on whitespace runs (the whitespace is kept as
      its own pieces), with one chunk per non-empty piece;
    * a final chunk carries ``finish_reason="stop"`` and the usage counts;
    * the stream ends with the ``[DONE]`` sentinel.

    Every chunk shares one ``created`` timestamp, as in OpenAI streams.

    Args:
        chat_id: Completion id repeated in every chunk.
        content: Full reply text to replay.
        model_name: Model name echoed in every chunk.
        prompt_tokens: Prompt token count reported in the final chunk.
        completion_tokens: Completion token count reported in the final chunk.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    import re

    created = int(time.time())
    # Split on whitespace boundaries but keep the delimiter; drop empty pieces.
    pieces = [p for p in re.split(r'(\s+)', content) if p]

    def _sse(delta: dict, finish_reason, **extra) -> str:
        """Serialise one ``chat.completion.chunk`` as an SSE ``data:`` line."""
        payload = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }
            ],
            **extra,
        }
        return f"data: {json.dumps(payload)}\n\n"

    def event_stream():
        yield _sse({"role": "assistant", "content": ""}, None)
        for piece in pieces:
            yield _sse({"content": piece}, None)
        yield _sse(
            {},
            "stop",
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        )
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask reverse proxies (e.g. nginx) not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    # Serve the FastAPI app directly with uvicorn when run as a script.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=PORT, host=HOST)

systemd service file

[Unit]
Description=Nemotron-Labs-Diffusion-14B OpenAI-Compatible API Server
After=network.target

[Service]
Type=simple
User=derek
WorkingDirectory=/home/derek/models
ExecStart=/home/derek/models/venv_nemotron_api/bin/python /home/derek/models/nemotron.py
Restart=on-failure
RestartSec=30
Environment=MODEL_NAME=nvidia/Nemotron-Labs-Diffusion-14B
Environment=HOST=0.0.0.0
Environment=PORT=8000
Environment=NEMOTRON_MODE=ar

[Install]
WantedBy=multi-user.target
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment