# encryptd
# Optimize speculative decoding performance by increasing max_num_batched_tokens to 4096
# a0613f9
import os
import subprocess
import sys
import httpx
import json
import base64
from io import BytesIO
# Fix for Python 3.13 audioop removal
try:
import audioop
except ImportError:
import audioop_lts as audioop
sys.modules["audioop"] = audioop
from fastapi import Request,FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
import gradio as gr
from openai import OpenAI
# --- CONFIGURATION ---
# Model identifier: passed to vLLM via --model and sent as the `model` field
# of the chat-completion request issued by the UI handler.
MODEL_ID = "numind/NuExtract3"
# Value passed to vLLM's --gpu-memory-utilization flag.
GPU_UTILIZATION = 0.95
# NOTE(review): not referenced anywhere in the visible code; start_vllm passes
# a hard-coded "131072" for --max-model-len instead. Confirm which value is
# intended before wiring this constant in.
MAX_MODEL_LEN = 16384
# Loopback port the vLLM server is started on and the proxy/UI connect to.
VLLM_PORT = 8000
# Port uvicorn binds for the combined FastAPI + Gradio application.
HF_PORT = 7860
# --- STEP 1: START vLLM ---
def start_vllm():
    """Spawn the vLLM OpenAI-compatible API server as a background process.

    The ``VLLM_PID`` environment variable serves as an "already started"
    marker: if it is present nothing happens; otherwise the server is
    launched and the marker is set to ``"running"``. The child process is not
    waited on, and its stdout/stderr are wired to this process's
    ``sys.stdout``/``sys.stderr`` so its logs show up in the same console.
    """
    if "VLLM_PID" not in os.environ:
        print("🚀 Starting vLLM engine...")

        # The argument list is assembled in segments; the resulting order is
        # exactly the order the flags are handed to the server.
        entrypoint = ["python3", "-m", "vllm.entrypoints.openai.api_server"]
        serving = [
            "--model", MODEL_ID,
            "--host", "127.0.0.1",
            "--port", f"{VLLM_PORT}",
            "--trust-remote-code",
        ]
        capacity = [
            "--gpu-memory-utilization", f"{GPU_UTILIZATION}",
            "--max-model-len", "131072",
            "--max-num-batched-tokens", "4096",
            "--dtype", "bfloat16",
        ]
        behaviour = [
            "--limit-mm-per-prompt", '{"image": 99, "video": 0}',
            "--chat-template-content-format", "openai",
            "--generation-config", "vllm",
            "--speculative-config", '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}',
            # Optional but helpful
            "--enforce-eager",
        ]
        launch_args = entrypoint + serving + capacity + behaviour

        # Fire and forget: the handle is intentionally not kept or waited on.
        subprocess.Popen(launch_args, stdout=sys.stdout, stderr=sys.stderr)
        os.environ["VLLM_PID"] = "running"
# Runs at import time: launches the vLLM server unless the VLLM_PID
# environment variable is already set (start_vllm returns early in that case).
start_vllm()
# --- STEP 2: FASTAPI PROXY (API) ---
# FastAPI application that carries the /v1 proxy route registered below.
app = FastAPI()
# We add the external API proxy directly to this app
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def gatekeeper_proxy(path: str, request: Request):
    """Forward ``/v1/*`` requests to the local vLLM server.

    Non-streaming ``POST /v1/chat/completions`` calls are sent upstream as
    JSON; on a 200 reply, everything up to and including the last
    ``</think>`` tag is removed from the first choice's message content.
    Every other request (including streaming chat completions) is relayed
    with the upstream status code, headers and raw body bytes.

    NOTE(review): the incoming query string is not forwarded upstream.

    Args:
        path: Portion of the request URL after ``/v1/``.
        request: The incoming request.

    Returns:
        A ``JSONResponse`` or ``StreamingResponse`` mirroring the upstream
        reply, or a 503 ``JSONResponse`` with an ``error`` field when
        anything raises while the request is being proxied.
    """
    target_url = f"http://127.0.0.1:{VLLM_PORT}/v1/{path}"
    # Strip Host and Content-Length to prevent routing loops on HF
    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in ("host", "content-length")
    }

    # The client is deliberately NOT used as an ``async with`` context here:
    # a streamed body is consumed only after this function has returned, so
    # the client must outlive the function on that path. Ownership is handed
    # to the relay generator below; every other path closes it in ``finally``.
    client = httpx.AsyncClient(timeout=300.0)
    handed_off = False
    try:
        if path == "chat/completions" and request.method == "POST":
            body = await request.json()
            if not body.get("stream", False):
                resp = await client.post(target_url, headers=headers, json=body)
                if resp.status_code == 200:
                    data = resp.json()
                    message = data["choices"][0]["message"]
                    # ``content`` can be present but null (e.g. tool-call
                    # replies); treat that like an empty string so a valid
                    # reply is passed through instead of becoming a 503.
                    content = message.get("content") or ""
                    # STRIP THINKING FROM EXTERNAL DOCLING API
                    if "</think>" in content:
                        message["content"] = content.split("</think>")[-1].strip()
                    return JSONResponse(content=data)
                return JSONResponse(status_code=resp.status_code, content=resp.json())

        # Fallback for models list, streaming completions, etc.
        proxy_req = client.build_request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
        )
        upstream = await client.send(proxy_req, stream=True)

        async def relay():
            # Pass raw upstream bytes through, then release the upstream
            # response and the client once the body is done or aborted.
            try:
                async for chunk in upstream.aiter_raw():
                    yield chunk
            finally:
                try:
                    await upstream.aclose()
                finally:
                    await client.aclose()

        response = StreamingResponse(
            relay(),
            status_code=upstream.status_code,
            headers=dict(upstream.headers),
        )
        handed_off = True
        return response
    except Exception as e:
        return JSONResponse(status_code=503, content={"error": f"API Proxy Error: {str(e)}"})
    finally:
        if not handed_off:
            await client.aclose()
# --- STEP 2b: UI LOGIC ---
def run_ui_test(image, prompt):
    """Send an uploaded image to the local vLLM server and return its reply.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing was
            uploaded.
        prompt: Instruction text; ``"Convert to markdown."`` is used when it
            is empty or ``None``.

    Returns:
        The model's reply with everything up to the last ``</think>`` tag
        removed (an empty string if the model returned no content), or a
        human-readable status/error message.
    """
    if image is None:
        return "⚠️ Please upload an image."

    # Internal check for vLLM: any failure to reach the models endpoint is
    # reported as "still loading". ``Exception`` (rather than a bare except)
    # keeps KeyboardInterrupt/SystemExit from being swallowed.
    try:
        with httpx.Client() as check:
            check.get(f"http://127.0.0.1:{VLLM_PORT}/v1/models", timeout=2.0)
    except Exception:
        return "⏳ Model is still loading... please wait 3-5 minutes."

    client = OpenAI(base_url=f"http://127.0.0.1:{VLLM_PORT}/v1", api_key="EMPTY")
    try:
        # JPEG has no alpha channel, so normalise to RGB before encoding.
        image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        completion = client.chat.completions.create(
            model=MODEL_ID,
            messages=[{"role": "user", "content": [
                {"type": "text", "text": prompt or "Convert to markdown."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
            ]}],
            timeout=300.0,
        )
        # The message content may be None; fall back to an empty string so
        # the membership test below cannot raise TypeError.
        content = completion.choices[0].message.content or ""
        # Suppress reasoning for UI
        return content.split("</think>")[-1].strip() if "</think>" in content else content
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Gradio test page: image + prompt on the left, model output on the right.
with gr.Blocks(title="NuExtract3 API") as demo:
    gr.Markdown("# NuExtract3 A100 API Server")
    gr.Markdown("The API is live at `/v1/chat/completions` (Reasoning stripped automatically).")
    with gr.Row():
        # Input column: document image, prompt and the trigger button.
        with gr.Column():
            img_input = gr.Image(label="Input Document", type="pil")
            txt_input = gr.Textbox(label="Prompt", value="Convert to markdown.")
            btn = gr.Button("Extract Markdown", variant="primary")
        # Output column: text returned by run_ui_test.
        with gr.Column():
            out = gr.Textbox(lines=20, label="Output")
    # Event listeners have to be registered while the Blocks context is open.
    btn.click(run_ui_test, inputs=[img_input, txt_input], outputs=[out])

# --- STEP 3: MOUNT GRADIO ON THE FASTAPI APP ---
# Workaround for an AttributeError: these attributes are set by hand on the
# Blocks object before it is mounted.
# NOTE(review): presumably tied to the installed Gradio version; confirm
# whether the patching is still required.
demo.root_path = ""
demo.proxy_url = None
demo.max_file_size = 100 * 1024**2  # 100MB

# Enable the queue for long-running tasks.
demo.queue()

# Mount the Gradio UI at the root path of the FastAPI app that already carries
# the /v1 proxy route, and rebind `app` to the returned application.
app = gr.mount_gradio_app(app, demo, path="/")

# --- STEP 4: RUN ---
if __name__ == "__main__":
    uvicorn.run(app, port=HF_PORT, host="0.0.0.0", workers=1)