import gc
import logging

import gradio as gr
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel

logger = logging.getLogger(__name__)

# Release memory held by the process before the model is loaded.
gc.collect()
torch.cuda.empty_cache()


class MultiTaskRoberta(nn.Module):
    """Encoder with a classification head and a regression head on the first token.

    Both heads read ``last_hidden_state[:, 0]`` (the first-token embedding) of
    the wrapped encoder.

    Args:
        base_model: Hugging Face encoder whose output has ``last_hidden_state``.
        num_classes: Output size of the classification head (default 3).
        num_regression: Output size of the regression head (default 5).
    """

    def __init__(self, base_model, num_classes=3, num_regression=5):
        super().__init__()
        self.roberta = base_model
        # Size the heads from the encoder config instead of hard-coding 768;
        # fall back to 768 (the original value) if no config is exposed.
        hidden_size = getattr(getattr(base_model, "config", None), "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.regressor = nn.Linear(hidden_size, num_regression)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return ``{"logits": ..., "regression_outputs": ...}`` for a batch.

        Extra keyword arguments (e.g. ``token_type_ids`` produced by the
        tokenizer) are accepted and ignored so ``model(**inputs)`` works.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled)
        regs = self.regressor(pooled)
        return {"logits": logits, "regression_outputs": regs}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

# Load base model and wrap it with the task heads
base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
model = MultiTaskRoberta(base_model)

# Load fine-tuned weights on CPU, then move the whole model to the device.
model_path = "model1.safetensors"
state_dict = load_file(model_path, device="cpu")
model.load_state_dict(state_dict)
# load_state_dict copies the tensors into the model; drop the CPU copy.
del state_dict
model.to(device)
model.eval()

# NOTE: instead of converting the weights with model.half(), predict() uses
# autocast on CUDA so matmuls run in reduced precision with fp32 weights.

# Classifier argmax index -> sentiment label shown in the UI.
SENTIMENT_MAP = {0: "正面", 1: "負面", 2: "中立"}


def predict(text: str):
    """Analyse one piece of text with the multi-task model.

    Args:
        text: Input text from the Gradio textbox.

    Returns:
        A dict with the sentiment label under "情感" and the five regression
        outputs (each rounded to 2 decimals, as plain Python floats) under
        "評分", "喜悅", "憤怒", "悲傷", "快樂". On any failure, returns
        ``{"错误": "处理失败: <message>"}`` instead of raising, so the UI can
        display the error.
    """
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Mixed precision only on CUDA; on CPU the context is a no-op.
        with torch.no_grad(), torch.autocast(
            device_type=device.type, enabled=device.type == "cuda"
        ):
            out = model(**inputs)

        pred_class = torch.argmax(out["logits"], dim=-1).item()

        # Under autocast the head output may be float16; cast to float32 and
        # convert to Python floats so round() yields clean 2-decimal values
        # that gr.JSON can serialise (numpy scalars are not guaranteed to be).
        reg_results = out["regression_outputs"][0].float().cpu().tolist()
        rating, delight, anger, sorrow, happiness = reg_results

        return {
            "情感": SENTIMENT_MAP[pred_class],
            "評分": round(rating, 2),
            "喜悅": round(delight, 2),
            "憤怒": round(anger, 2),
            "悲傷": round(sorrow, 2),
            "快樂": round(happiness, 2),
        }
    except Exception as e:
        # UI boundary: keep the full traceback in the server log, but still
        # return a readable message to the user instead of crashing.
        logger.exception("predict() failed")
        return {"错误": f"处理失败: {str(e)}"}


# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
    outputs=gr.JSON(label="分析結果"),
    title="粵語情感與情緒分析",
    description="輸入粵語文本,分析情感(正面/負面/中立)和五種情緒評分",
    examples=[
        ["呢個plan聽落唔錯,我哋試下先啦。"],
        ["份proposal 你send 咗俾client未?Deadline 係EOD呀。"],
        ["返工返到好攰,但係見到同事就feel better啲。"],
        ["你今次嘅presentation做得唔錯,我好 impressed!"],
        ["夜晚聽到嗰啲聲,我唔敢出房門。"],
        ["個client 真係好 difficult 囉,改咗n 次 requirements,仲要urgent,chur 到痴線!"],
        ["我尋日冇乜特別事做,就係喺屋企睇電視。"],
        ["Weekend 去staycation,間酒店個view 正到爆!"],
        ["做乜嘢都冇意義。"],
        ["今朝遲到咗,差啲miss咗個重要meeting"],
    ],
)

if __name__ == "__main__":
    iface.launch(share=True, show_error=True)