import os
import time
from typing import Dict, List

import gradio as gr
import spaces
from transformers import pipeline

# === Config (override via Space secrets/env vars) ===
MODEL_ID = os.environ.get("MODEL_ID", "tlhv/osb-minier")
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 512))
DEFAULT_TEMPERATURE = float(os.environ.get("TEMPERATURE", 0.7))
DEFAULT_TOP_P = float(os.environ.get("TOP_P", 0.95))
DEFAULT_REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.0))
ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", 120))  # seconds

# Cached pipeline (created lazily once a GPU has been granted)
_pipe = None


def _to_messages(user_prompt: str) -> List[Dict[str, str]]:
    # The provided model expects chat-style messages
    return [{"role": "user", "content": user_prompt}]


@spaces.GPU(duration=ZGPU_DURATION)
def generate_long_prompt(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
):
    """Runs on a ZeroGPU-allocated GPU thanks to the decorator above."""
    global _pipe
    start = time.time()

    # Create the pipeline lazily, only once the GPU is available
    if _pipe is None:
        _pipe = pipeline(
            "text-generation",
            model=MODEL_ID,
            torch_dtype="auto",
            device_map="auto",  # let HF accelerate map to the GPU we just got
        )

    messages = _to_messages(prompt)

    # Build generation kwargs. transformers rejects temperature=0.0 when
    # do_sample=True, so fall back to greedy decoding in that case.
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    outputs = _pipe(messages, **gen_kwargs)

    # Robust extraction for different pipeline return shapes: chat-style
    # pipelines return the message list under "generated_text", plain
    # text-generation pipelines return a string.
    text = None
    if isinstance(outputs, list) and outputs:
        res = outputs[0]
        if isinstance(res, dict):
            gt = res.get("generated_text")
            if isinstance(gt, list) and gt and isinstance(gt[-1], dict):
                text = gt[-1].get("content") or gt[-1].get("text")
            elif isinstance(gt, str):
                text = gt
        if text is None:
            text = str(res)
    else:
        text = str(outputs)

    elapsed = time.time() - start
    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
    return text, meta


CSS = ".wrap textarea {font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;}"

with gr.Blocks(css=CSS) as demo:
    gr.Markdown(
        "# ZeroGPU: Long-Prompt Text Generation\n"
        "Paste a long prompt and generate text with a Transformers model. "
        "Set `MODEL_ID` in Space secrets to switch models."
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=20,
                placeholder="Paste a long prompt here…",
                elem_classes=["wrap"],  # matches the .wrap class selector in CSS
            )
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(16, 4096, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="temperature")
                top_p = gr.Slider(0.01, 1.0, value=DEFAULT_TOP_P, step=0.01, label="top_p")
                repetition_penalty = gr.Slider(0.8, 2.0, value=DEFAULT_REPETITION_PENALTY, step=0.05, label="repetition_penalty")
            generate = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Output", lines=20)
            meta = gr.Markdown()

    generate.click(
        fn=generate_long_prompt,
        inputs=[prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[output, meta],
        concurrency_limit=1,
        api_name="generate",
    )

    gr.Examples(
        examples=[
            ["Summarize the following 3 pages of notes into a crisp plan of action…"],
            ["Write a 1200-word blog post about the history of transformers and attention…"],
        ],
        inputs=[prompt],
    )

# Important for ZeroGPU: use a queue so calls are serialized & resumable.
# Note: Gradio 4 removed queue(concurrency_count=...); per-event concurrency
# is set via concurrency_limit on the .click() handler above.
if __name__ == "__main__":
    demo.queue(max_size=32).launch()
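
# --- Usage sketch (not part of the app) ---------------------------------------
# A minimal example of calling the "generate" endpoint from Python with
# gradio_client. This is a sketch under assumptions: "<user>/<space>" is a
# placeholder for the actual Space id, and the positional arguments mirror
# the click() inputs registered above.
#
#   from gradio_client import Client
#
#   client = Client("<user>/<space>")  # hypothetical Space id
#   text, meta = client.predict(
#       "Paste a long prompt here…",   # prompt
#       512,                           # max_new_tokens
#       0.7,                           # temperature
#       0.95,                          # top_p
#       1.0,                           # repetition_penalty
#       api_name="/generate",
#   )
#   print(meta)
#   print(text)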