import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr


def run_inference(review_text: str) -> str:
    """
    Perform inference on the given wine review text and return the predicted wine variety.

    Args:
        review_text (str): Wine review text in the format "country [SEP] description".

    Returns:
        str: The predicted wine variety using the model's id2label mapping if available.
    """
    # Define model and tokenizer identifiers
    model_id = "spawn99/modernbert-wine-classification"
    tokenizer_id = "answerdotai/ModernBERT-base"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)

    # Tokenize the input text
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256
    )

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Determine prediction and map to label if available
    pred = torch.argmax(logits, dim=-1).item()
    variety = (
        model.config.id2label.get(pred, str(pred))
        if hasattr(model.config, "id2label") and model.config.id2label
        else str(pred)
    )
    return variety


def predict_wine_variety(country: str, description: str) -> dict:
    """
    Combine the provided country and description, then perform inference.
    Enforces a maximum character limit of 750 on the description.

    Args:
        country (str): The country of wine origin.
        description (str): The wine review description.

    Returns:
        dict: Dictionary containing the predicted wine variety, or an error message if the limit is exceeded.
    """
    # Validate description length
    if len(description) > 750:
        return {"error": "Description exceeds 750 character limit. Please shorten your input."}

    # Capitalize input values and format the review text accordingly.
    review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
    predicted_variety = run_inference(review_text)
    return {"Variety": predicted_variety}


if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict_wine_variety,
        inputs=[
            gr.Textbox(label="Country", placeholder="Enter country of origin..."),
            gr.Textbox(label="Description", placeholder="Enter wine review description...")
        ],
        outputs=gr.JSON(label="Prediction"),
        title="Wine Variety Predictor",
        description="Predict the wine variety based on country and description."
    )
    iface.launch()
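To exercise the prediction path without launching the Gradio UI, a minimal sketch like the one below can be used. It assumes the listing above is saved as app.py (the filename is an assumption) and that the Hugging Face Hub checkpoints can be downloaded in the current environment; note that run_inference reloads the tokenizer and model on every call, so the first invocation will be slow.

# smoke_test.py -- minimal sketch, not part of the Space itself.
# Assumes the code above is saved as app.py in the same directory
# and that network access to the Hugging Face Hub is available.
from app import predict_wine_variety

if __name__ == "__main__":
    # Any description under the 750-character limit is accepted;
    # the text here is just an illustrative example.
    result = predict_wine_variety(
        "italy",
        "Aromas of dried cherry and herbs, with firm tannins and bright acidity on the finish."
    )
    print(result)  # {"Variety": "<predicted label>"} -- the exact label depends on the model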