import gradio as gr
import torch
import json
import os
import tempfile
import subprocess
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Keywords counted in a generated sequence for the "Operation Breakdown" panel.
# Counting is plain, case-sensitive substring matching (e.g. "line" also matches
# "spline"), so the numbers are a rough indication, not a parse of the sequence.
OPERATION_KEYWORDS = (
    "sketch",
    "extrude",
    "revolve",
    "circle",
    "rectangle",
    "line",
    "fillet",
    "chamfer",
)

# Number of characters of the generated sequence shown inline in the results panel.
PREVIEW_CHARS = 500


class CADFusionInference:
    """Wrapper that downloads, holds and queries the CADFusion causal language model."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_loaded = False

    def load_model(self, model_path="microsoft/CADFusion", revision="v1_1"):
        """Download and load the CADFusion model and tokenizer.

        Args:
            model_path: Hugging Face Hub repository id to download.
            revision: Repository revision (branch, tag or commit) to download.

        Raises:
            Exception: any download or loading error is logged and re-raised.
        """
        try:
            logger.info(f"Loading CADFusion model from {model_path} (revision: {revision})")

            # Download model files
            model_dir = snapshot_download(
                repo_id=model_path,
                revision=revision,
                cache_dir="./model_cache",
            )

            # Imported inside the method, so transformers is only needed once a
            # model is actually loaded. NOTE(review): assumes the checkpoint is
            # loadable through the Auto* classes — confirm against the model card.
            from transformers import AutoTokenizer, AutoModelForCausalLM

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            if self.tokenizer.pad_token is None:
                # No pad token defined: reuse EOS so generate() has a pad id.
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model. NOTE(review): trust_remote_code=True executes Python code
            # shipped with the repository; only use it with a trusted repo/revision.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
            )

            self.model_loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def generate_cad_sequence(self, text_prompt, max_length=512, temperature=0.8, top_p=0.9):
        """Generate a CAD sequence from a text prompt.

        Args:
            text_prompt: Natural-language description of the part.
            max_length: Maximum total length in tokens (prompt + generated), as
                interpreted by ``generate(max_length=...)``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The generated text with the prompt removed and surrounding whitespace
            stripped (may be empty).

        Raises:
            ValueError: if the model has not been loaded.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        try:
            # Format the prompt for CAD generation
            formatted_prompt = f"Generate CAD sequence for: {text_prompt}\nCAD:"

            # With device_map="auto" the model decides its own placement, so send
            # inputs to wherever the model lives rather than assuming self.device.
            model_device = getattr(self.model, "device", self.device)
            encoded = self.tokenizer(formatted_prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(model_device)
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(model_device)
            prompt_length = input_ids.shape[-1]

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    # pad == eos here, so pass the mask explicitly instead of
                    # letting generate() infer it.
                    attention_mask=attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Slicing the decoded string by
            # len(prompt) is fragile because decode(encode(x)) need not equal x.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            logger.error(f"Error generating CAD sequence: {str(e)}")
            raise

    def render_cad_visualization(self, cad_sequence):
        """Summarize a CAD sequence (placeholder for real rendering).

        A full implementation would parse the sequence into geometric operations
        and render it with the CADFusion rendering utilities. For now this only
        counts keyword occurrences.

        Returns:
            On success, a dict with:
              - ``sequence``: the input sequence
              - ``operations``: count of "extrude" + "revolve" (kept for
                backward compatibility)
              - ``sketches``: count of "sketch"
              - ``operation_counts``: per-keyword counts for OPERATION_KEYWORDS
              - ``total_operations``: sum of ``operation_counts``
              - ``line_count``: number of lines in the sequence
              - ``status``: a placeholder status string
            On failure, ``{"error": <message>}``.
        """
        try:
            operation_counts = {
                keyword: cad_sequence.count(keyword) for keyword in OPERATION_KEYWORDS
            }
            return {
                "sequence": cad_sequence,
                "operations": operation_counts["extrude"] + operation_counts["revolve"],
                "sketches": operation_counts["sketch"],
                "operation_counts": operation_counts,
                "total_operations": sum(operation_counts.values()),
                "line_count": len(cad_sequence.splitlines()),
                "status": "Generated (visualization placeholder)",
            }
        except Exception as e:
            logger.error(f"Error rendering CAD: {str(e)}")
            return {"error": str(e)}


# Initialize the inference class (the model itself is loaded lazily on first use)
cad_fusion = CADFusionInference()


def _format_results(text_prompt, cad_sequence, viz_info):
    """Build the Markdown results panel from a sequence and its summary dict."""
    ops = viz_info.get("operation_counts", {})
    preview = cad_sequence[:PREVIEW_CHARS]
    ellipsis = "..." if len(cad_sequence) > PREVIEW_CHARS else ""
    return f"""
## 🎯 Generated CAD Model

**Input Description:** {text_prompt}

**Generated CAD Sequence:**
```
{preview}{ellipsis}
```

## 📊 Analysis:
- **Total Operations:** {viz_info.get('total_operations', 0)}
- **Complexity:** {viz_info.get('complexity', 'Unknown')}
- **Lines of Code:** {viz_info.get('line_count', 0)}

### Operation Breakdown:
- **Sketches:** {ops.get('sketch', 0)}
- **Extrusions:** {ops.get('extrude', 0)}
- **Revolutions:** {ops.get('revolve', 0)}
- **Circles:** {ops.get('circle', 0)}
- **Rectangles:** {ops.get('rectangle', 0)}
- **Lines:** {ops.get('line', 0)}
- **Fillets:** {ops.get('fillet', 0)}
- **Chamfers:** {ops.get('chamfer', 0)}

**Status:** {viz_info.get('status', 'Generated successfully')}

---
*Note: This is the parametric CAD sequence. For full 3D rendering, use CAD software that supports these operations.*
"""


def generate_cad_from_text(text_prompt, max_length=512, temperature=0.8, top_p=0.9):
    """Gradio callback: generate a CAD sequence plus a Markdown summary.

    Returns:
        A ``(markdown_summary, raw_sequence)`` tuple. On any failure the first
        element is a user-facing message and the second is an empty string.
    """
    try:
        # Validate input first, so an empty prompt does not trigger a model load.
        if not text_prompt or not text_prompt.strip():
            return "Please provide a description for the CAD model.", ""

        # Load model if not already loaded
        if not cad_fusion.model_loaded:
            try:
                cad_fusion.load_model()
            except Exception as e:
                error_msg = (
                    f"Failed to load CADFusion model: {str(e)}\n\n"
                    "This might be due to:\n"
                    "- Model access restrictions\n"
                    "- Insufficient resources\n"
                    "- Network connectivity issues"
                )
                return error_msg, ""

        # Generate CAD sequence
        cad_sequence = cad_fusion.generate_cad_sequence(
            text_prompt.strip(),
            max_length=int(max_length),
            temperature=temperature,
            top_p=top_p,
        )

        if not cad_sequence:
            return "No CAD sequence was generated. Please try with a different prompt.", ""

        viz_info = cad_fusion.render_cad_visualization(cad_sequence)
        return _format_results(text_prompt, cad_sequence, viz_info), cad_sequence

    except Exception as e:
        error_msg = f"❌ Error generating CAD: {str(e)}"
        logger.error(error_msg)
        return error_msg, ""


def create_gradio_interface():
    """Create and return the Gradio Blocks UI for the demo."""
    with gr.Blocks(
        title="CADFusion - Text-to-CAD Generation",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        .title {
            text-align: center;
            margin-bottom: 20px;
        }
        """,
    ) as demo:
        gr.Markdown(
            """
# 🔧 CADFusion - Text-to-CAD Generation

Convert natural language descriptions into CAD model sequences using Microsoft's CADFusion framework.

**Features:**
- Generate parametric CAD sequences from text descriptions
- Built on fine-tuned LLMs with visual feedback learning
- Supports complex 3D modeling operations

**Example prompts:**
- "Create a cylindrical cup with a handle"
- "Design a rectangular bracket with mounting holes"
- "Generate a gear wheel with 12 teeth"
""",
            elem_classes="title",
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("## 📝 Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD model you want to generate...",
                    lines=3,
                    value="Create a simple cylindrical cup with a handle on the side",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    max_length = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=32,
                        label="Max Sequence Length",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p",
                    )

                generate_btn = gr.Button(
                    "🚀 Generate CAD",
                    variant="primary",
                    size="lg",
                )

            with gr.Column(scale=3):
                # Output section
                gr.Markdown("## 🎯 Generated CAD")
                output_display = gr.Markdown(label="Results")

                with gr.Accordion("Raw CAD Sequence", open=False):
                    raw_sequence = gr.Textbox(
                        label="CAD Sequence",
                        lines=10,
                        max_lines=15,
                        show_copy_button=True,
                    )

        # Examples section (registers itself with the enclosing Blocks)
        gr.Markdown("## 📚 Example Prompts")
        gr.Examples(
            examples=[
                ["Create a simple cylindrical cup with a handle"],
                ["Design a rectangular bracket with four mounting holes"],
                ["Generate a gear wheel with 10 teeth and a central hole"],
                ["Make a L-shaped bracket for wall mounting"],
                ["Create a hexagonal nut with internal threading"],
                ["Design a simple phone stand with an angled surface"],
            ],
            inputs=[text_input],
            label="Click on any example to try it",
        )

        # Event handlers
        generate_btn.click(
            fn=generate_cad_from_text,
            inputs=[text_input, max_length, temperature, top_p],
            outputs=[output_display, raw_sequence],
            show_progress=True,
        )

        # Footer
        gr.Markdown(
            """
---
**About CADFusion:** This model is based on the paper ["Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models"](https://arxiv.org/abs/2501.19054) by Microsoft Research.

**Note:** This demo shows the text-to-sequence generation capability. Full 3D rendering would require additional computational resources and the complete CADFusion rendering pipeline.
"""
        )

    return demo


# Create and launch the interface
if __name__ == "__main__":
    try:
        # The model is not pre-loaded here: generate_cad_from_text loads it
        # lazily on the first request, so the UI starts without waiting for it.
        logger.info("Initializing CADFusion interface (model loads on first request)...")
        demo = create_gradio_interface()

        # Launch the app
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as e:
        logger.error(f"Failed to launch application: {str(e)}")
        sys.exit(1)