Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from dual_regression_model import DualRegressionModel | |
| import transformers | |
| from transformers import pipeline | |
| from functools import partial | |
# Load the models once at start-up; every request handled by predict() reuses them.

# CLF: A-pt-bs16-dbmdz-bert-base-italian-cased
# Sequence-classification model and tokenizer read from the local clf_model/ directory,
# wrapped in a text-classification pipeline that returns the region label.
clf_model_tag = "clf_model/"
clf_tokenizer = transformers.AutoTokenizer.from_pretrained(clf_model_tag)
clf_model = transformers.AutoModelForSequenceClassification.from_pretrained(clf_model_tag)
clf_pipeline = pipeline("text-classification", model=clf_model, tokenizer=clf_tokenizer)

# REG: coordinate regressor constructed from the multilingual DistilBERT checkpoint name,
# then loaded with weights from a local file (note: despite its name,
# reg_model_folder points to a single .pt file).
reg_model_tag = "distilbert-base-multilingual-cased"
reg_model_folder = "reg_model/regression_model.pt"
reg_model = DualRegressionModel(model_name_or_path=reg_model_tag)
reg_model.load_model(reg_model_folder)
# Put the regressor in inference mode (disables dropout) so identical inputs
# yield identical coordinates. Guarded because DualRegressionModel's base class
# is not visible here; eval() on a module already in eval mode is a no-op.
if isinstance(reg_model, torch.nn.Module):
    reg_model.eval()
# Bounding box (decimal degrees) of the OpenStreetMap view embedded in the
# result; it frames Italy so the predicted marker is shown in context.
LAT_MIN, LAT_MAX = 38, 46
LONG_MIN, LONG_MAX = 8, 18


def _render_html(label, latitude, longitude):
    """Build the HTML shown in the Gradio output panel.

    Args:
        label: Region label produced by the classification pipeline.
        latitude: Predicted latitude in decimal degrees.
        longitude: Predicted longitude in decimal degrees.

    Returns:
        HTML with the region heading, the raw coordinates, an embedded
        OpenStreetMap iframe with a marker at the predicted point, and a link
        to the full-size map.
    """
    html_output = f"<h3>The identified region is: {label}</h3>"
    html_output += f'<h3>Predicted point on map:</h3><p>Latitude: {latitude}</p><p>Longitude: {longitude}</p>'
    # The OSM embed URL takes bbox as minlon,minlat,maxlon,maxlat and marker as lat,lon.
    # mlat/mlon on the "larger map" link make the full-size map show the same marker
    # instead of only centring on the point.
    html_output += (
        '<iframe width="425" height="350" frameborder="0" scrolling="no" '
        'marginheight="0" marginwidth="0" '
        'src="https://www.openstreetmap.org/export/embed.html'
        f'?bbox={LONG_MIN}%2C{LAT_MIN}%2C{LONG_MAX}%2C{LAT_MAX}'
        f'&layer=mapnik&marker={latitude}%2C{longitude}" '
        'style="border: 1px solid black"></iframe>'
        '<br/><small>'
        f'<a href="https://www.openstreetmap.org/?mlat={latitude}&mlon={longitude}'
        f'#map=13/{latitude}/{longitude}">Visualizza mappa ingrandita</a></small>'
    )
    return html_output


def predict(text):
    """Classify the dialect region of *text* and predict its coordinates.

    Args:
        text: Input string from the Gradio textbox.

    Returns:
        An HTML string with the predicted region label and a map marker at
        the predicted latitude/longitude.
    """
    # truncation=True: inputs longer than the model's maximum sequence length
    # would otherwise make the forward pass fail.
    clf_prediction = clf_pipeline(text, truncation=True)[0]
    # NOTE(review): assumes reg_model.tokenizer is a Hugging Face tokenizer
    # (it already accepts return_tensors="pt"), so truncation=True is supported.
    reg_input = reg_model.tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: no_grad avoids building an autograd graph on every request.
    with torch.no_grad():
        reg_prediction = reg_model(reg_input)
    latitude = reg_prediction["latitude"].item()
    longitude = reg_prediction["longitude"].item()
    return _render_html(clf_prediction["label"], latitude, longitude)
# --------------------------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------------------------
# One text box in, HTML (predicted region + map) out; the examples are sample
# sentences users can click to try the demo.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Insert the text here..."),
    outputs=gr.HTML(),
    title="DANTE: Dialect ANalysis TEam",
    description="This is a demo of a classification and regression model for locating the Italian dialect of a given text.",
    examples=[
        ["Bisognerebbe saperli materializzare .... !! Ma ovviamente .. belin .... NO SE PEU SCIUSCIA' E SCIORBI'"],
        ["Guaglio' Buongiorno! Azz! Vir te si scurdat puparuol e mulignane pero '!! E che se fa😑"],
        ["Il massimo...ghe ne minga par nisun"],
        ["Che poi a me la tuta piace na cifra da vede. Subisco un po' lo stigma sociale che noi con la fregna dovemo stà sempre apposto."],
    ],
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests) does not launch the app.
if __name__ == "__main__":
    iface.launch()