cadspc / app.py
kshdes37's picture
Update app.py
a3ac0b3 verified
import gradio as gr
import torch
import json
import os
import tempfile
import subprocess
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
import logging
# Logging: a module-scoped logger, with the root logger configured to
# emit records at INFO level and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class CADFusionInference:
    """Wrapper around the CADFusion checkpoint on the Hugging Face Hub.

    Holds the tokenizer/model pair, loads them on demand via ``load_model``,
    turns a text prompt into a raw CAD sequence string, and produces a
    keyword-count summary of that sequence (a stand-in for real rendering).
    """

    # Keywords tallied by ``render_cad_visualization``. Counting is plain
    # substring matching (e.g. "line" also matches inside "spline"), so the
    # numbers are a rough summary of the sequence, not a parse of it.
    OPERATION_KEYWORDS = (
        "sketch",
        "extrude",
        "revolve",
        "circle",
        "rectangle",
        "line",
        "fillet",
        "chamfer",
    )

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Fallback device for inputs; generate_cad_sequence prefers the
        # device reported by the loaded model when it has one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_loaded = False

    def load_model(self, model_path="microsoft/CADFusion", revision="v1_1"):
        """Download (or reuse cached) weights and load the tokenizer and model.

        Args:
            model_path: Hub repo id passed to ``snapshot_download``.
            revision: Branch, tag, or commit of that repo.

        Raises:
            Any exception from downloading or loading is logged and re-raised;
            ``model_loaded`` stays False in that case.
        """
        try:
            logger.info(f"Loading CADFusion model from {model_path} (revision: {revision})")
            model_dir = snapshot_download(
                repo_id=model_path,
                revision=revision,
                cache_dir="./model_cache",
            )
            # NOTE(review): assumes the repo holds a checkpoint loadable by
            # AutoModelForCausalLM — confirm against the model card.
            from transformers import AutoModelForCausalLM, AutoTokenizer

            use_cuda = torch.cuda.is_available()
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            if self.tokenizer.pad_token is None:
                # generate() needs a pad token id; reuse EOS when none exists.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # NOTE(review): trust_remote_code=True executes Python shipped in
            # the model repo; keep `revision` pinned to a trusted ref.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )
            self.model_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.exception(f"Error loading model: {str(e)}")
            raise

    def generate_cad_sequence(self, text_prompt, max_length=512, temperature=0.8, top_p=0.9):
        """Sample a CAD sequence for ``text_prompt``.

        Args:
            text_prompt: Natural-language description of the part.
            max_length: Upper bound on prompt + generated tokens (forwarded to
                ``generate`` as ``max_length``).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The decoded continuation only (prompt tokens excluded), stripped.

        Raises:
            ValueError: if ``load_model`` has not completed successfully.
            Any error from tokenization/generation is logged and re-raised.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Please load the model first.")
        try:
            formatted_prompt = f"Generate CAD sequence for: {text_prompt}\nCAD:"
            # With device_map="auto" the weights may not sit on self.device;
            # put the inputs wherever the model says it lives.
            device = getattr(self.model, "device", self.device)
            encoded = self.tokenizer(formatted_prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(device)
            extra_kwargs = {}
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                # pad_token may equal eos_token (see load_model), so pass an
                # explicit mask rather than letting generate() infer one.
                extra_kwargs["attention_mask"] = attention_mask.to(device)
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    **extra_kwargs,
                )
            # A causal LM's output starts with the prompt tokens. Drop them by
            # token count: slicing the decoded text by len(formatted_prompt)
            # is unreliable because decode() need not reproduce the prompt
            # character-for-character.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            logger.exception(f"Error generating CAD sequence: {str(e)}")
            raise

    def render_cad_visualization(self, cad_sequence):
        """Summarize ``cad_sequence`` by keyword counts (placeholder for rendering).

        Real rendering would parse the sequence into geometric operations and
        use the CADFusion rendering utilities; this only counts keywords.

        Returns:
            A dict with:
              - "sequence": the input sequence
              - "operations": {keyword: count} for every OPERATION_KEYWORDS entry
              - "total_operations": sum of those counts
              - "sketches": the "sketch" count (kept for existing readers)
              - "line_count": number of text lines in the sequence
              - "status": placeholder status text
            or {"error": message} if the sequence cannot be analyzed
            (e.g. it is not a string).
        """
        try:
            counts = {kw: cad_sequence.count(kw) for kw in self.OPERATION_KEYWORDS}
            return {
                "sequence": cad_sequence,
                "operations": counts,
                "total_operations": sum(counts.values()),
                "sketches": counts["sketch"],
                "line_count": len(cad_sequence.splitlines()),
                "status": "Generated (visualization placeholder)",
            }
        except Exception as e:
            logger.error(f"Error rendering CAD: {str(e)}")
            return {"error": str(e)}
# One shared inference wrapper for every request. Constructing it does not
# load any weights; the model is loaded on demand (see generate_cad_from_text).
cad_fusion: CADFusionInference = CADFusionInference()
# (label, keyword) rows of the "Operation Breakdown" section, in display order.
_BREAKDOWN_ROWS = (
    ("Sketches", "sketch"),
    ("Extrusions", "extrude"),
    ("Revolutions", "revolve"),
    ("Circles", "circle"),
    ("Rectangles", "rectangle"),
    ("Lines", "line"),
    ("Fillets", "fillet"),
    ("Chamfers", "chamfer"),
)


def _format_cad_report(text_prompt, cad_sequence, viz_info):
    """Build the Markdown results panel for one generated CAD sequence.

    ``viz_info["operations"]`` may be either a {keyword: count} dict or a
    single int total; both are handled so the report never crashes on the
    summary's shape. Missing keys fall back to values derived here.
    """
    raw_ops = viz_info.get("operations", {})
    ops = dict(raw_ops) if isinstance(raw_ops, dict) else {}
    # Summaries without a per-keyword dict may still carry a "sketches" count.
    ops.setdefault("sketch", viz_info.get("sketches", 0))

    if "total_operations" in viz_info:
        total_ops = viz_info["total_operations"]
    elif isinstance(raw_ops, int):
        total_ops = raw_ops
    else:
        total_ops = sum(ops.values())

    line_count = viz_info.get("line_count", len(cad_sequence.splitlines()))
    status = viz_info.get("status", "Generated successfully")
    if "error" in viz_info:
        status = f"Visualization failed: {viz_info['error']}"

    preview = cad_sequence[:500] + ("..." if len(cad_sequence) > 500 else "")
    # Blank lines separate Markdown blocks; without them consecutive lines
    # collapse into a single paragraph.
    lines = [
        "## 🎯 Generated CAD Model",
        "",
        f"**Input Description:** {text_prompt}",
        "",
        "**Generated CAD Sequence:**",
        "```",
        preview,
        "```",
        "",
        "## 📊 Analysis:",
        "",
        f"- **Total Operations:** {total_ops}",
        f"- **Complexity:** {viz_info.get('complexity', 'Unknown')}",
        f"- **Lines of Code:** {line_count}",
        "",
        "### Operation Breakdown:",
        "",
        *(f"- **{label}:** {ops.get(key, 0)}" for label, key in _BREAKDOWN_ROWS),
        "",
        f"**Status:** {status}",
        "",
        "---",
        "",
        "*Note: This is the parametric CAD sequence. For full 3D rendering, "
        "use CAD software that supports these operations.*",
    ]
    return "\n".join(lines)


def generate_cad_from_text(text_prompt, max_length=512, temperature=0.8, top_p=0.9):
    """Gradio callback: turn a text prompt into (Markdown report, raw sequence).

    Loads the shared model on first use, samples a CAD sequence, and formats
    a summary of it.

    Returns:
        A ``(markdown, raw_sequence)`` tuple. On any failure the first element
        is a user-facing message and the second is ``""``.
    """
    try:
        # Validate first so an empty request never triggers the model
        # download/load.
        if not text_prompt or not text_prompt.strip():
            return "Please provide a description for the CAD model.", ""

        if not cad_fusion.model_loaded:
            try:
                cad_fusion.load_model()
            except Exception as e:
                error_msg = (
                    f"Failed to load CADFusion model: {str(e)}\n\n"
                    "This might be due to:\n"
                    "- Model access restrictions\n"
                    "- Insufficient resources\n"
                    "- Network connectivity issues"
                )
                return error_msg, ""

        cad_sequence = cad_fusion.generate_cad_sequence(
            text_prompt.strip(),
            max_length=int(max_length),
            temperature=temperature,
            top_p=top_p,
        )
        if not cad_sequence:
            return "No CAD sequence was generated. Please try with a different prompt.", ""

        viz_info = cad_fusion.render_cad_visualization(cad_sequence)
        return _format_cad_report(text_prompt, cad_sequence, viz_info), cad_sequence
    except Exception as e:
        error_msg = f"❌ Error generating CAD: {str(e)}"
        logger.exception(error_msg)
        return error_msg, ""
def create_gradio_interface():
    """Build and return the CADFusion Gradio Blocks app (not launched).

    Layout:
      - an input column with the prompt box, sampling sliders, and a
        generate button;
      - a results column with a Markdown report and the raw sequence;
      - clickable example prompts;
      - a footer.

    The button calls ``generate_cad_from_text``, whose two return values fill
    the Markdown report and the raw-sequence textbox respectively.
    """
    with gr.Blocks(
        title="CADFusion - Text-to-CAD Generation",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        .title {
            text-align: center;
            margin-bottom: 20px;
        }
        """,
    ) as demo:
        # Markdown text is kept at column 0 so indentation is never read as a
        # code block; blank lines keep the paragraphs separate.
        gr.Markdown(
            """
# 🔧 CADFusion - Text-to-CAD Generation

Convert natural language descriptions into CAD model sequences using Microsoft's CADFusion framework.

**Features:**
- Generate parametric CAD sequences from text descriptions
- Built on fine-tuned LLMs with visual feedback learning
- Supports complex 3D modeling operations

**Example prompts:**
- "Create a cylindrical cup with a handle"
- "Design a rectangular bracket with mounting holes"
- "Generate a gear wheel with 12 teeth"
""",
            elem_classes="title",
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("## 📝 Input")
                text_input = gr.Textbox(
                    label="CAD Description",
                    placeholder="Describe the CAD model you want to generate...",
                    lines=3,
                    value="Create a simple cylindrical cup with a handle on the side",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    max_length = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=32,
                        label="Max Sequence Length",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p",
                    )
                generate_btn = gr.Button(
                    "🚀 Generate CAD",
                    variant="primary",
                    size="lg",
                )

            with gr.Column(scale=3):
                # Output section
                gr.Markdown("## 🎯 Generated CAD")
                output_display = gr.Markdown(label="Results")
                with gr.Accordion("Raw CAD Sequence", open=False):
                    raw_sequence = gr.Textbox(
                        label="CAD Sequence",
                        lines=10,
                        max_lines=15,
                        show_copy_button=True,
                    )

        # Examples section: clicking an example fills the prompt box.
        gr.Markdown("## 📚 Example Prompts")
        gr.Examples(
            examples=[
                ["Create a simple cylindrical cup with a handle"],
                ["Design a rectangular bracket with four mounting holes"],
                ["Generate a gear wheel with 10 teeth and a central hole"],
                ["Make a L-shaped bracket for wall mounting"],
                ["Create a hexagonal nut with internal threading"],
                ["Design a simple phone stand with an angled surface"],
            ],
            inputs=[text_input],
            label="Click on any example to try it",
        )

        # Event handlers
        generate_btn.click(
            fn=generate_cad_from_text,
            inputs=[text_input, max_length, temperature, top_p],
            outputs=[output_display, raw_sequence],
            show_progress=True,
        )

        # Footer
        gr.Markdown(
            """
---

**About CADFusion:**

This model is based on the paper ["Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models"](https://arxiv.org/abs/2501.19054) by Microsoft Research.

**Note:** This demo shows the text-to-sequence generation capability. Full 3D rendering would require additional computational resources and the complete CADFusion rendering pipeline.
"""
        )

    return demo
# Build and launch the interface. No model is loaded here: the UI comes up
# immediately, and generate_cad_from_text loads the model on the first request.
if __name__ == "__main__":
    try:
        logger.info("Building CADFusion interface (model loads on first request)...")
        demo = create_gradio_interface()
        # Bind on all interfaces so the app is reachable from outside a container.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error line loses.
        logger.exception(f"Failed to launch application: {str(e)}")
        sys.exit(1)