Spaces:
Runtime error
Runtime error
import functools
import re

import gradio as gr
import numpy as np
from datasets import load_from_disk
# Choices offered by the UI. get_nearest compares them against the dataset's
# "adjective", "profession", "model" and "no" columns.
# Every name is a single whitespace-free token, so splitting each text block on
# whitespace yields the names as a list, in the order written.
adjectives = """
ambitious assertive committed compassionate confident considerate decisive determined
emotional gentle honest intellectual modest no_adjective outspoken pleasant
self-confident sensitive stubborn supportive unreasonable
""".split()

professions = """
CEO IT_specialist accountant aerospace_engineer aide air_conditioning_installer
architect artist author
baker bartender bus_driver butcher career_counselor carpenter carpet_installer
cashier childcare_worker
civil_engineer claims_appraiser cleaner clergy clerk coach community_manager
compliance_officer computer_programmer
computer_support_specialist computer_systems_analyst construction_worker cook
correctional_officer courier credit_counselor
customer_service_representative data_entry_keyer dental_assistant dental_hygienist
dentist designer detective director
dishwasher dispatcher doctor drywall_installer electrical_engineer electrician
engineer event_planner executive_assistant
facilities_manager farmer fast_food_worker file_clerk financial_advisor
financial_analyst financial_manager firefighter
fitness_instructor graphic_designer groundskeeper hairdresser head_cook
health_technician host hostess industrial_engineer
insurance_agent interior_designer interviewer inventory_clerk jailer janitor
laboratory_technician language_pathologist
lawyer librarian logistician machinery_mechanic machinist maid manager manicurist
market_research_analyst
marketing_manager massage_therapist mechanic mechanical_engineer
medical_records_specialist mental_health_counselor
metal_worker mover musician network_administrator nurse nursing_assistant
nutritionist occupational_therapist
office_clerk office_worker painter paralegal payroll_clerk pharmacist
pharmacy_technician photographer
physical_therapist pilot plane_mechanic plumber police_officer postal_worker
printing_press_operator producer
psychologist public_relations_specialist purchasing_agent radiologic_technician
real_estate_broker receptionist
repair_worker roofer sales_manager salesperson school_bus_driver scientist
security_guard sheet_metal_worker singer
social_assistant social_worker software_developer stocker supervisor taxi_driver
teacher teaching_assistant teller
therapist tractor_operator truck_driver tutor underwriter veterinarian waiter
waitress welder wholesale_buyer writer
""".split()

# Text-to-image models whose generations are in the dataset.
models = ["DallE", "SD_14", "SD_2"]

# Image numbers 1 through 10.
nos = list(range(1, 11))
# Dataset of generated images. Rows are addressed by the integer positions
# stored in the KNN index array below.
ds = load_from_disk("jobs")

# Precomputed nearest-neighbour index file. The "768" and "65" in the name pick
# which index variant is used; their meaning is not visible in this file.
# NOTE(review): document what these numbers denote.
_KNN_INDEX_PATH = "indexes/knn_768_65.npy"

# Number of neighbours shown in the gallery: entries 1..24 of a KNN row.
_NUM_NEIGHBORS = 24

# Prompt prefix stripped from neighbour file names to form captions
# ("Photo_portrait_of_a_" or "Photo_portrait_of_an_").
_CAPTION_PREFIX = re.compile(r"Photo_portrait_of_an?_")


@functools.lru_cache(maxsize=1)
def _load_lookup_tables():
    """Load the metadata frame and KNN array once and reuse them on every click.

    Previously both were rebuilt on each request. That converted the whole
    dataset to pandas and re-read the ``.npy`` file every time.

    Returns:
        ``(metadata, knn)``: ``metadata`` is a pandas DataFrame of ``ds`` without
        its ``image`` and ``image_path`` columns. ``knn`` is the array loaded from
        ``_KNN_INDEX_PATH``.
    """
    metadata = ds.remove_columns(["image", "image_path"]).to_pandas()
    knn = np.load(_KNN_INDEX_PATH)
    return metadata, knn


def _caption(image_path, model):
    """Return '<file name without the prompt prefix> <model>' for a gallery tile."""
    file_name = image_path.split("/")[-1]
    return f"{_CAPTION_PREFIX.split(file_name)[-1]} {model}"


def get_nearest(adjective, profession, model, no):
    """Return the selected identity image and its nearest neighbours.

    Args:
        adjective: value matched against the ``adjective`` column.
        profession: value matched against the ``profession`` column.
        model: value matched against the ``model`` column.
        no: value matched against the ``no`` column (image number).

    Returns:
        A tuple ``(image, gallery)``.
        - ``image`` is the image at entry 0 of the matched row's KNN list.
          That is presumably the query image itself, since a point is its own
          nearest neighbour (TODO confirm).
        - ``gallery`` is a list of ``(image, caption)`` pairs for KNN entries
          1-24. Each caption is the file name without its prompt prefix,
          followed by the model name.
        If several rows match, the first one is used.

    Raises:
        ValueError: if no row matches, e.g. when a field has not been
            selected yet.
    """
    metadata, knn = _load_lookup_tables()
    matches = metadata.loc[
        (metadata["adjective"] == adjective)
        & (metadata["profession"] == profession)
        & (metadata["no"] == no)
        & (metadata["model"] == model)
    ].index
    if matches.empty:
        # Previously surfaced as an opaque IndexError from `.index[0]`.
        raise ValueError(
            f"No generated image matches model={model!r}, adjective={adjective!r}, "
            f"profession={profession!r}, no={no!r}; please select a value for every field."
        )
    # The metadata frame comes from to_pandas(), so its index labels are the
    # row positions that the KNN array is indexed by.
    row = knn[matches[0]]
    image = ds.select([row[0]])["image"][0]
    neighbors = ds.select(row[1 : _NUM_NEIGHBORS + 1])
    captions = [
        _caption(path, neighbor_model)
        for path, neighbor_model in zip(neighbors["image_path"], neighbors["model"])
    ]
    return image, list(zip(neighbors["image"], captions))
def _make_gallery(columns=4):
    """Create a Gallery that shows ``columns`` images per row on any Gradio version.

    Gradio 3.x sets the grid via ``Component.style(grid=...)``. That method was
    removed in Gradio 4, where calling it raises AttributeError at start-up.
    Gradio 4 takes a ``columns=`` constructor argument instead. The original
    call is kept whenever ``.style`` still exists.
    """
    if hasattr(gr.Gallery, "style"):
        return gr.Gallery().style(grid=columns)
    return gr.Gallery(columns=columns)


with gr.Blocks() as demo:
    gr.Markdown("# BoVW Nearest Neighbors Explorer")
    gr.Markdown("### TF-IDF index of the _identities_ dataset of images generated by 3 models using a visual vocabulary of 10,752 words.")
    gr.Markdown("#### Choose one of the generated identity images to see its nearest neighbors according to a bag-of-visual-words model.")
    gr.HTML("""<span style="color:red">⚠️ <b>DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</b></span>""")
    with gr.Row():
        with gr.Column():
            model = gr.Radio(models, label="Model")
            adjective = gr.Radio(adjectives, label="Adjective")
            no = gr.Radio(nos, label="Image number")
        with gr.Column():
            profession = gr.Dropdown(professions, label="Profession")
            button = gr.Button(value="Get nearest neighbors")
    with gr.Row():
        image = gr.Image()
        # Created inside the Row context so it renders next to the query image.
        gallery = _make_gallery(columns=4)
    # The input order must match get_nearest(adjective, profession, model, no).
    button.click(get_nearest, inputs=[adjective, profession, model, no], outputs=[image, gallery])
demo.launch()