clipspace / app_fixed.py
borso271's picture
Add remove labels functionality with admin panel UI
2a97c1d
import gradio as gr
import base64
import json
import os
from PIL import Image
import io
from handler import EndpointHandler
# Initialize handler
# Build the model handler once at import time. If construction fails the app
# still starts: every callback below checks ``handler is None`` and reports the
# problem in the UI instead of crashing.
print("Initializing MobileCLIP handler...")
try:
    handler = EndpointHandler()
except Exception as e:
    print(f"Error initializing handler: {e}")
    handler = None
else:
    # Logged outside the ``try`` and read defensively: a successfully built
    # handler must not be discarded just because the log line cannot read
    # ``device`` (get_current_stats also treats it as optional).
    print(f"Handler initialized successfully! Device: {getattr(handler, 'device', 'unknown')}")
def classify_image(image, top_k=10):
    """
    Classify an uploaded image and format the top-k predictions for the UI.

    The image is PNG-encoded, base64-encoded and passed to the shared
    ``handler``. A list response is rendered as a numbered Markdown ranking
    plus table rows; any other response is read as an error object.

    Args:
        image: PIL image from the upload widget, or None when it is empty.
        top_k: Number of top results to request; coerced with ``int()``.

    Returns:
        ``(markdown_text, rows)`` where ``rows`` is a list of
        ``(label, score)`` tuples, or ``(message, None)`` on any error or
        when there is nothing to classify. Exceptions are never raised.
    """
    if handler is None:
        return "Error: Handler not initialized", None
    if image is None:
        return "Please upload an image", None

    try:
        # The handler expects the picture as base64 text of PNG bytes.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode()

        request = {"inputs": {"image": encoded_png, "top_k": int(top_k)}}
        response = handler(request)

        if not isinstance(response, list):
            # Anything other than a list of predictions is an error payload.
            return f"Error: {response.get('error', 'Unknown error')}", None

        pieces = [f"**Top {len(response)} Classifications:**\n\n"]
        rows = []
        for rank, prediction in enumerate(response, start=1):
            percent = prediction['score'] * 100
            pieces.append(
                f"{rank}. **{prediction['label']}** (ID: {prediction['id']}): {percent:.2f}%\n"
            )
            rows.append((prediction['label'], prediction['score']))
        return "".join(pieces), rows
    except Exception as exc:
        return f"Error: {str(exc)}", None
def upsert_labels_admin(admin_token, new_items_json):
    """
    Admin callback: add (upsert) labels supplied as a JSON array.

    The input is validated before anything is sent to the handler. It must
    decode to a JSON array whose elements are objects (the admin panel
    documents items like ``{"id": ..., "name": ..., "prompt": ...}``).
    Empty or whitespace-only input is treated as an empty array.

    Args:
        admin_token: Admin token, forwarded to the handler. The handler does
            the authorization check; an ``"unauthorized"`` error is reported.
        new_items_json: JSON text from the code-editor widget.

    Returns:
        A Markdown status message. Errors are returned as strings, never raised.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"
    try:
        raw = (new_items_json or "").strip()
        items = json.loads(raw) if raw else []
        # Reject shapes the handler should never receive (a bare object, a list
        # of scalars, ...) with a precise message instead of forwarding them.
        if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
            return "❌ Error: Expected a JSON array of label objects"
        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items,
            }
        })
        # Guard before using .get(): a non-dict reply would otherwise surface as
        # an opaque AttributeError message.
        if not isinstance(result, dict):
            return "❌ Error: Unexpected response from handler"
        if result.get("status") == "ok":
            return f"✅ Success! Added {result.get('added', 0)} new labels. Current version: {result.get('labels_version', 'unknown')}"
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {e}"
def reload_labels_admin(admin_token, version):
    """
    Admin callback: ask the handler to reload a specific label version.

    Args:
        admin_token: Admin token, forwarded to the handler, which does the
            authorization check.
        version: Requested version number (the UI passes a ``gr.Number``
            value). Only a missing value (None or "") defaults to 1. Any other
            value, including 0, is converted with ``int()`` and forwarded, so
            the handler can accept or reject it.

    Returns:
        A Markdown status message. Errors are returned as strings, never raised.
    """
    if handler is None:
        return "Error: Handler not initialized"
    if not admin_token:
        return "Error: Admin token required"
    try:
        # Previously ``int(version) if version else 1`` silently turned a
        # requested version 0 into version 1. Only substitute when absent.
        requested = int(version) if version not in (None, "") else 1
        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": requested,
            }
        })
        # Guard before using .get(): a non-dict reply would otherwise surface as
        # an opaque AttributeError message.
        if not isinstance(result, dict):
            return "❌ Error: Unexpected response from handler"
        status = result.get("status")
        error = result.get("error")
        current = result.get("labels_version", "unknown")
        if status == "ok":
            return f"✅ Labels reloaded successfully! Current version: {current}"
        if status == "nochange":
            return f"ℹ️ No change needed. Current version: {current}"
        if error == "unauthorized":
            return "❌ Error: Invalid admin token"
        if error == "invalid_version":
            return "❌ Error: Invalid version number"
        # Same detail -> error -> generic fallback as upsert_labels_admin.
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"
    except Exception as e:
        return f"❌ Error: {e}"
def get_current_stats():
    """
    Render a Markdown summary of the handler's currently loaded labels.

    Optional handler attributes fall back to defaults when absent:
    ``class_ids`` -> 0 labels, ``labels_version`` -> 1, ``device`` ->
    "unknown". When ``class_names`` is non-empty, up to five are listed, with
    "..." appended if more exist. Any exception is returned as a message
    rather than raised.
    """
    if handler is None:
        return "Handler not initialized"
    try:
        label_count = 0
        if hasattr(handler, 'class_ids'):
            label_count = len(handler.class_ids)
        labels_version = getattr(handler, 'labels_version', 1)
        device_name = getattr(handler, 'device', "unknown")

        # Leading and trailing empty entries reproduce the surrounding
        # newlines of the summary block.
        summary_lines = [
            "",
            "**Current Statistics:**",
            f"- Number of labels: {label_count}",
            f"- Labels version: {labels_version}",
            f"- Device: {device_name}",
            "- Model: MobileCLIP-B",
            "",
        ]
        report = "\n".join(summary_lines)

        if hasattr(handler, 'class_names'):
            names = handler.class_names
            if len(names) > 0:
                preview = ', '.join(names[:5])
                suffix = "..." if len(names) > 5 else ""
                report += f"\n- Sample labels: {preview}" + suffix
        return report
    except Exception as exc:
        return f"Error getting stats: {str(exc)}"
# Create Gradio interface
# The whole UI is declared inside nested ``with`` contexts: components appear
# on screen in the order they are created, and event listeners are attached
# inside the Blocks context. Statement order here therefore defines the layout.
# NOTE(review): the emoji inside the UI strings below look mis-encoded in this
# copy (UTF-8 bytes decoded with a legacy code page). Verify the file is saved
# and served as UTF-8.
print("Creating Gradio interface...")
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
# πŸ–ΌοΈ MobileCLIP-B Zero-Shot Image Classifier
Upload an image to classify it using MobileCLIP-B model with dynamic label management.
""")
    # --- Public tab: upload an image and view the top-k predictions. ---
    with gr.Tab("πŸ” Image Classification"):
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                input_image = gr.Image(
                    type="pil",  # classify_image receives a PIL image (it calls image.save)
                    label="Upload Image"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show"
                )
                classify_btn = gr.Button("πŸš€ Classify Image", variant="primary")
            # Right column: outputs filled by classify_image's (text, rows) pair.
            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                # Simplified bar chart using Dataframe
                # (in practice a read-only table of (label, score) rows, not a chart)
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False
                )
        # Event handler for classification
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart]
        )
        # Also trigger on image upload. This also fires when the image is
        # cleared; classify_image then returns its "Please upload an image" message.
        # Moving the slider alone does not re-run classification.
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart]
        )
    # --- Admin tab: label statistics, label upsert and version reload. ---
    with gr.Tab("πŸ”§ Admin Panel"):
        gr.Markdown("""
### Admin Functions
**Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
""")
        with gr.Row():
            # Shared by every admin action below. Checked by the handler, not by this UI.
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token"
            )
        with gr.Accordion("πŸ“Š Current Statistics", open=True):
            # Computed once, while the UI is being built; only the button below refreshes it.
            # NOTE(review): add/reload actions do not update this panel. Consider
            # chaining ``.then(get_current_stats, outputs=stats_display)`` on them.
            stats_display = gr.Markdown(value=get_current_stats())
            refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display
            )
        with gr.Accordion("βž• Add New Labels", open=False):
            gr.Markdown("""
Add new labels by providing JSON array:
```json
[
{"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
{"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
]
```
""")
            # JSON text passed verbatim to upsert_labels_admin, which parses it.
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
            )
            upsert_btn = gr.Button("βž• Add Labels", variant="primary")
            upsert_output = gr.Markdown()
            upsert_btn.click(
                fn=upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output
            )
        with gr.Accordion("πŸ”„ Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            # precision=0 makes the widget deliver whole numbers; reload_labels_admin applies int().
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0
            )
            reload_btn = gr.Button("πŸ”„ Reload Version", variant="primary")
            reload_output = gr.Markdown()
            reload_btn.click(
                fn=reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output
            )
    # --- Static documentation tab. ---
    # NOTE(review): the performance and model figures below are descriptive
    # text and are not verified by anything in this file.
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## About MobileCLIP-B Classifier
This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.
### Features:
- πŸš€ **Fast inference**: < 30ms on GPU
- 🏷️ **Dynamic labels**: Add/update labels without redeployment
- πŸ”„ **Version control**: Track and reload label versions
- πŸ“Š **Visual results**: Classification scores and confidence
### Environment Variables (set in Space Settings):
- `ADMIN_TOKEN`: Secret token for admin operations
- `HF_LABEL_REPO`: Hub repository for label storage
- `HF_WRITE_TOKEN`: Token with write permissions to label repo
- `HF_READ_TOKEN`: Token with read permissions (optional)
### Model Details:
- **Architecture**: MobileCLIP-B with MobileOne blocks
- **Text Encoder**: Transformer-based, 77 token context
- **Image Size**: 224x224
- **Embedding Dim**: 512
### License:
Model weights are licensed under Apple Sample Code License (ASCL).
""")
print("Gradio interface created successfully!")
def _launch():
    """Start the Gradio server for the ``demo`` interface defined above."""
    print("Launching Gradio app...")
    demo.launch()


# Only serve when executed as a script, not when imported.
if __name__ == "__main__":
    _launch()