"""Gradio front-end for the AML group project's plant species classifiers.

The user picks a model from a dropdown, uploads a plant image, and the
selected model's output is shown in a ``gr.Label`` (top 5 classes).
"""

from pathlib import Path

import gradio as gr

from baseline.baseline_convnext import predict_convnext
from baseline.baseline_infer import predict_baseline

# --- Model names (single source of truth for dropdown, dispatch, examples) ---
CONVNEXT_MODEL = "Herbarium Species Classifier"
DINOV2_MODEL = "Baseline (DINOv2 + LogReg)"
PLACEHOLDER_MODEL_1 = "Future Model 1 (Placeholder)"
PLACEHOLDER_MODEL_2 = "Future Model 2 (Placeholder)"
DEFAULT_MODEL = CONVNEXT_MODEL

NO_IMAGE_MESSAGE = "Please upload an image."

# Resolve the stylesheet relative to this file, not the current working
# directory, so styling works no matter where the app is launched from.
CSS_PATH = Path(__file__).resolve().parent / "style.css"

# Image file paths offered as clickable examples; leave empty if you have none.
EXAMPLE_IMAGES = []


def _load_css(path):
    """Return the text of the stylesheet at *path*, or None if it is missing.

    Passing the CSS *contents* (rather than a file name) to ``gr.Blocks`` avoids
    a missing or CWD-relative path being injected as literal CSS text.
    """
    try:
        return path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None


# --- Placeholder models (for future extensions) ---
def predict_placeholder_1(image):
    """Stand-in for a future model: return a status message, not predictions."""
    if image is None:
        return NO_IMAGE_MESSAGE
    return "Model 2 is not available yet. Please check back later."


def predict_placeholder_2(image):
    """Stand-in for a future model: return a status message, not predictions."""
    if image is None:
        return NO_IMAGE_MESSAGE
    return "Model 3 is not available yet. Please check back later."


# Dropdown label -> prediction function. Insertion order is the order in
# which the choices appear in the dropdown.
MODEL_REGISTRY = {
    # ConvNeXtV2 mix-stream CNN, trained end to end.
    CONVNEXT_MODEL: predict_convnext,
    # Plant-pretrained DINOv2 features + logistic regression head.
    DINOV2_MODEL: predict_baseline,
    PLACEHOLDER_MODEL_1: predict_placeholder_1,
    PLACEHOLDER_MODEL_2: predict_placeholder_2,
}


# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Run the model selected in the dropdown on *image*.

    Args:
        model_choice: Dropdown label; expected to be a key of ``MODEL_REGISTRY``.
        image: PIL image from ``gr.Image(type="pil")``, or ``None`` when
            nothing has been uploaded.

    Returns:
        The selected model's output (shown in ``gr.Label``; presumably a
        label -> confidence mapping for the real models, TODO confirm), or
        a plain status string when the model is unknown or no image is given.
    """
    model_fn = MODEL_REGISTRY.get(model_choice)
    if model_fn is None:
        return "Invalid model selected."
    # Guard here so the real models are never called with None when the user
    # clicks "Classify" before uploading an image.
    if image is None:
        return NO_IMAGE_MESSAGE
    return model_fn(image)


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=_load_css(CSS_PATH)) as demo:
    with gr.Column(elem_id="app-wrapper"):
        # Header
        gr.Markdown(
            """

🌿 Plant Species Classification

AML Group Project – PsychicFireSong

""",
            elem_id="app-header",
        )

        # Badges row
        gr.Markdown(
            """
Herbarium + Field images ConvNeXtV2 mix-stream CNN DINOv2 + Logistic Regression
""",
            elem_id="badge-row",
        )

        # Main card
        with gr.Row(elem_id="main-card"):
            # Left side: model + image
            with gr.Column(scale=1, elem_id="left-panel"):
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=list(MODEL_REGISTRY),
                    value=DEFAULT_MODEL,
                )
                gr.Markdown(
                    """
Herbarium Species Classifier – end-to-end ConvNeXtV2 CNN.
Baseline – plant-pretrained DINOv2 features + logistic regression head.
""",
                    elem_id="model-help",
                )
                image_input = gr.Image(
                    type="pil",
                    label="Upload plant image",
                )
                submit_button = gr.Button("Classify 🌱", variant="primary")

            # Right side: predictions
            with gr.Column(scale=1, elem_id="right-panel"):
                output_label = gr.Label(
                    label="Top 5 predictions",
                    num_top_classes=5,
                )

        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=output_label,
        )

        # Only build an Examples panel when there is something to show; an
        # empty gr.Examples just renders an empty section on the page.
        if EXAMPLE_IMAGES:
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=image_input,
                outputs=output_label,
                fn=lambda img: predict(DEFAULT_MODEL, img),
                cache_examples=False,
            )

        gr.Markdown(
            "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines.",
            elem_id="footer",
        )


if __name__ == "__main__":
    demo.launch()