|
|
import gradio as gr |
|
|
import torch |
|
|
import json |
|
|
import os |
|
|
import tempfile |
|
|
import subprocess |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from huggingface_hub import snapshot_download |
|
|
import logging |
|
|
|
|
|
|
|
|
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
class CADFusionInference:
    """Lazily loaded text-to-CAD generator built on a Hugging Face causal LM.

    The instance starts empty. ``load_model`` downloads a snapshot and fills in
    ``model`` / ``tokenizer``; ``generate_cad_sequence`` samples a completion
    for a text prompt; ``render_cad_visualization`` returns simple keyword
    counts as a placeholder for real rendering.
    """

    def __init__(self):
        """Pick the target device; no weights are loaded here."""
        self.model = None
        self.tokenizer = None
        # CUDA when torch reports it as available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Set to True only after load_model() completes without raising.
        self.model_loaded = False

    def load_model(self, model_path="microsoft/CADFusion", revision="v1_1"):
        """Download the model snapshot and load its tokenizer and weights.

        Args:
            model_path: Hub repo id handed to ``snapshot_download``.
            revision: Repo revision to fetch.

        Raises:
            Exception: Anything raised by the download or the
                ``from_pretrained`` calls; it is logged and re-raised
                unchanged. On failure the instance state is left untouched.
        """
        try:
            logger.info(f"Loading CADFusion model from {model_path} (revision: {revision})")

            model_dir = snapshot_download(
                repo_id=model_path,
                revision=revision,
                cache_dir="./model_cache",
            )

            # Function-scope import: transformers is only needed once a load
            # is actually requested.
            from transformers import AutoTokenizer, AutoModelForCausalLM

            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            if tokenizer.pad_token is None:
                # generate() is given a pad_token_id, so fall back to EOS.
                tokenizer.pad_token = tokenizer.eos_token

            use_cuda = torch.cuda.is_available()
            # NOTE(security): trust_remote_code=True lets the downloaded repo
            # execute its own Python code; only use with repos you trust.
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise

        # Commit only after everything succeeded, so a failed load never
        # leaves a tokenizer without a matching model.
        self.tokenizer = tokenizer
        self.model = model
        self.model_loaded = True
        logger.info("Model loaded successfully!")

    def generate_cad_sequence(self, text_prompt, max_length=512, temperature=0.8, top_p=0.9):
        """Sample a CAD sequence for ``text_prompt`` and return it as text.

        Args:
            text_prompt: Natural-language description of the part.
            max_length: Forwarded to ``generate`` as ``max_length``.
            temperature: Sampling temperature forwarded to ``generate``.
            top_p: Nucleus-sampling threshold forwarded to ``generate``.

        Returns:
            The decoded continuation (prompt excluded), whitespace-stripped.

        Raises:
            ValueError: If ``load_model`` has not completed successfully.
            Exception: Any tokenizer/model error, logged and re-raised.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        try:
            formatted_prompt = f"Generate CAD sequence for: {text_prompt}\nCAD:"

            # Calling the tokenizer (instead of .encode) also yields the
            # attention mask, which generate() needs to tell padding from
            # content when pad_token is aliased to eos_token.
            encoded = self.tokenizer(formatted_prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(self.device)
            generate_kwargs = {
                "max_length": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
            }
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                generate_kwargs["attention_mask"] = attention_mask.to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(input_ids, **generate_kwargs)

            # Drop the prompt at token level. Slicing the decoded string by
            # len(formatted_prompt) is unreliable because decoding does not
            # always reproduce the prompt text byte-for-byte.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            logger.error(f"Error generating CAD sequence: {str(e)}")
            raise

    def render_cad_visualization(self, cad_sequence):
        """Summarise a CAD sequence (placeholder; no actual rendering is done).

        Args:
            cad_sequence: Generated sequence text.

        Returns:
            A dict with ``sequence`` (the input), ``operations`` (number of
            "extrude" plus "revolve" substring occurrences), ``sketches``
            (number of "sketch" occurrences) and ``status``. If counting
            fails, ``{"error": <message>}`` is returned instead of raising.
        """
        try:
            extrude_count = cad_sequence.count("extrude")
            revolve_count = cad_sequence.count("revolve")
            return {
                "sequence": cad_sequence,
                "operations": extrude_count + revolve_count,
                "sketches": cad_sequence.count("sketch"),
                "status": "Generated (visualization placeholder)",
            }
        except Exception as e:
            logger.error(f"Error rendering CAD: {str(e)}")
            return {"error": str(e)}
|
|
|
|
|
|
|
|
# Single shared inference wrapper for the whole app. Constructing it does not
# load any weights; generate_cad_from_text() calls load_model() on first use.
cad_fusion = CADFusionInference()
|
|
|
|
|
def _operation_breakdown(viz_info, cad_sequence):
    """Return a keyword -> count mapping for the operation breakdown list."""
    reported = viz_info.get("operations")
    if isinstance(reported, dict):
        return reported
    # The renderer may report a single integer total (or nothing at all).
    # Fall back to raw substring counts over the generated sequence.
    keywords = (
        "sketch", "extrude", "revolve", "circle",
        "rectangle", "line", "fillet", "chamfer",
    )
    return {keyword: cad_sequence.count(keyword) for keyword in keywords}


def _format_cad_report(text_prompt, cad_sequence, viz_info):
    """Build the Markdown report shown in the results pane."""
    ops = _operation_breakdown(viz_info, cad_sequence)

    total = viz_info.get("total_operations")
    if total is None:
        reported = viz_info.get("operations")
        total = reported if isinstance(reported, int) else sum(ops.values())

    line_count = viz_info.get("line_count")
    if line_count is None:
        line_count = len(cad_sequence.splitlines())

    # Only the first 500 characters are shown inline; the full text goes to
    # the raw-sequence textbox.
    preview = f"{cad_sequence[:500]}{'...' if len(cad_sequence) > 500 else ''}"

    lines = [
        "",
        "## π― Generated CAD Model",
        "",
        f"**Input Description:** {text_prompt}",
        "",
        "**Generated CAD Sequence:**",
        "```",
        preview,
        "```",
        "",
        "## π Analysis:",
        f"- **Total Operations:** {total}",
        f"- **Complexity:** {viz_info.get('complexity', 'Unknown')}",
        f"- **Lines of Code:** {line_count}",
        "",
        "### Operation Breakdown:",
        f"- **Sketches:** {ops.get('sketch', 0)}",
        f"- **Extrusions:** {ops.get('extrude', 0)}",
        f"- **Revolutions:** {ops.get('revolve', 0)}",
        f"- **Circles:** {ops.get('circle', 0)}",
        f"- **Rectangles:** {ops.get('rectangle', 0)}",
        f"- **Lines:** {ops.get('line', 0)}",
        f"- **Fillets:** {ops.get('fillet', 0)}",
        f"- **Chamfers:** {ops.get('chamfer', 0)}",
        "",
        f"**Status:** {viz_info.get('status', 'Generated successfully')}",
        "",
        "---",
        "*Note: This is the parametric CAD sequence. For full 3D rendering, use CAD software that supports these operations.*",
        "",
    ]
    return "\n".join(lines)


def generate_cad_from_text(text_prompt, max_length=512, temperature=0.8, top_p=0.9):
    """Gradio callback: turn a text description into a CAD sequence report.

    Args:
        text_prompt: User-supplied description of the part.
        max_length: Maximum sequence length; coerced to ``int`` before use.
        temperature: Sampling temperature passed to the generator.
        top_p: Nucleus-sampling threshold passed to the generator.

    Returns:
        A ``(markdown_report, raw_cad_sequence)`` tuple. On any failure the
        first element is a human-readable message and the second is ``""``;
        this function never raises for ordinary errors.
    """
    try:
        # Validate the prompt before the expensive model load so an empty
        # submission gets immediate feedback.
        if not text_prompt or not text_prompt.strip():
            return "Please provide a description for the CAD model.", ""

        if not cad_fusion.model_loaded:
            try:
                cad_fusion.load_model()
            except Exception as e:
                error_msg = f"Failed to load CADFusion model: {str(e)}\n\nThis might be due to:\n- Model access restrictions\n- Insufficient resources\n- Network connectivity issues"
                return error_msg, ""

        cad_sequence = cad_fusion.generate_cad_sequence(
            text_prompt.strip(),
            max_length=int(max_length),
            temperature=temperature,
            top_p=top_p,
        )
        if not cad_sequence:
            return "No CAD sequence was generated. Please try with a different prompt.", ""

        viz_info = cad_fusion.render_cad_visualization(cad_sequence)
        return _format_cad_report(text_prompt, cad_sequence, viz_info), cad_sequence

    except Exception as e:
        error_msg = f"β Error generating CAD: {str(e)}"
        logger.error(error_msg)
        return error_msg, ""
|
|
|
|
|
def create_gradio_interface():
    """Assemble the Blocks UI and return it without launching.

    Layout: an intro header, then a two-column row (prompt and sampling
    sliders on the left, rendered report and raw sequence on the right),
    a list of clickable example prompts, and a footer. The generate button
    is wired to ``generate_cad_from_text``.
    """
    # Each entry is a one-element row because Examples feeds a single input.
    example_prompts = [
        ["Create a simple cylindrical cup with a handle"],
        ["Design a rectangular bracket with four mounting holes"],
        ["Generate a gear wheel with 10 teeth and a central hole"],
        ["Make a L-shaped bracket for wall mounting"],
        ["Create a hexagonal nut with internal threading"],
        ["Design a simple phone stand with an angled surface"],
    ]

    with gr.Blocks(
        title="CADFusion - Text-to-CAD Generation",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        .title {
            text-align: center;
            margin-bottom: 20px;
        }
        """
    ) as demo:

        # Page header; styled through the ".title" CSS class above.
        gr.Markdown("""
        # π§ CADFusion - Text-to-CAD Generation

        Convert natural language descriptions into CAD model sequences using Microsoft's CADFusion framework.

        **Features:**
        - Generate parametric CAD sequences from text descriptions
        - Built on fine-tuned LLMs with visual feedback learning
        - Supports complex 3D modeling operations

        **Example prompts:**
        - "Create a cylindrical cup with a handle"
        - "Design a rectangular bracket with mounting holes"
        - "Generate a gear wheel with 12 teeth"
        """, elem_classes="title")

        with gr.Row():
            # Left column: prompt, sampling controls, trigger button.
            with gr.Column(scale=2):
                gr.Markdown("## π Input")
                prompt_box = gr.Textbox(
                    value="Create a simple cylindrical cup with a handle on the side",
                    label="CAD Description",
                    placeholder="Describe the CAD model you want to generate...",
                    lines=3,
                )

                with gr.Accordion("Advanced Settings", open=False):
                    length_slider = gr.Slider(
                        label="Max Sequence Length",
                        minimum=128, maximum=1024, step=32, value=512,
                    )
                    temperature_slider = gr.Slider(
                        label="Temperature",
                        minimum=0.1, maximum=2.0, step=0.1, value=0.8,
                    )
                    top_p_slider = gr.Slider(
                        label="Top-p",
                        minimum=0.1, maximum=1.0, step=0.05, value=0.9,
                    )

                run_button = gr.Button("π Generate CAD", variant="primary", size="lg")

            # Right column: formatted report plus the collapsible raw text.
            with gr.Column(scale=3):
                gr.Markdown("## π― Generated CAD")
                report_view = gr.Markdown(label="Results")

                with gr.Accordion("Raw CAD Sequence", open=False):
                    sequence_box = gr.Textbox(
                        label="CAD Sequence",
                        lines=10,
                        max_lines=15,
                        show_copy_button=True,
                    )

        gr.Markdown("## π Example Prompts")
        gr.Examples(
            examples=example_prompts,
            inputs=[prompt_box],
            label="Click on any example to try it",
        )

        # Input order must match generate_cad_from_text's positional parameters.
        run_button.click(
            fn=generate_cad_from_text,
            inputs=[prompt_box, length_slider, temperature_slider, top_p_slider],
            outputs=[report_view, sequence_box],
            show_progress=True,
        )

        gr.Markdown("""
        ---
        **About CADFusion:**
        This model is based on the paper ["Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models"](https://arxiv.org/abs/2501.19054) by Microsoft Research.

        **Note:** This demo shows the text-to-sequence generation capability. Full 3D rendering would require additional computational resources and the complete CADFusion rendering pipeline.
        """)

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it. Any failure is logged
    # and turned into a non-zero exit status.
    try:
        logger.info("Initializing CADFusion model...")

        demo = create_gradio_interface()

        # Listen on all interfaces, fixed port 7860, no public share link.
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False,
        }
        demo.launch(**launch_options)

    except Exception as exc:
        logger.error(f"Failed to launch application: {str(exc)}")
        sys.exit(1)