| import torch | |
| import gradio as gr | |
| from utils import create_vocab, setup_seed | |
| from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab | |
# Runs once at import time with a fixed seed (4), before CTXGen is ever
# called. Presumably this seeds the RNGs so that predictions are
# reproducible between runs -- confirm against utils.setup_seed.
setup_seed(4)
def CTXGen(X1, X2, X3, model_name):
    """Predict the masked subtype and/or potency label of a conotoxin.

    The inputs are joined into the single string ``"{X1}|{X2}|{X3}|||"`` and
    tokenised with ``get_paded_token_idx_gen``. Every resulting token equal
    to ``"X"`` is treated as a position to predict. Positions are filled from
    left to right, and each prediction is written back into the token list so
    that later positions are predicted with the earlier ones in place.

    Args:
        X1: Subtype label, or ``"X"`` to have it predicted.
        X2: Potency label, or ``"X"`` to have it predicted.
        X3: Conotoxin sequence.
        model_name: Path of a pickled model that ``torch.load`` can read.

    Returns:
        A ``(Subtype, Potency, Topk)`` tuple. A label supplied by the caller
        is returned unchanged; a predicted label is returned as a
        ``(token, probability)`` pair. ``Topk`` is the list of the five most
        probable tokens for the first predicted position when the subtype is
        predicted, and ``X1`` otherwise.

    Raises:
        ValueError: If the tokenised input contains no ``"X"`` token.
    """
    device = torch.device("cpu")
    vocab_mlm = add_tokens_to_vocab(create_vocab())
    # Give the "X" placeholder the id 4 before tokenising (presumably the id
    # of "[MASK]" in this vocabulary -- confirm against create_vocab).
    vocab_mlm.token_to_idx["X"] = 4

    # Tokenise and validate first so that a request with nothing to predict
    # fails before the (expensive) model load.
    padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(
        vocab_mlm, [f"{X1}|{X2}|{X3}|||"], None
    )
    mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
    if not mask_positions:
        raise ValueError("Nothing found in the sequence to predict.")

    # SECURITY: weights_only=False unpickles arbitrary Python objects, so
    # model_name must only ever point at trusted checkpoint files.
    model = torch.load(model_name, weights_only=False, map_location=device)
    model.eval()

    idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
    top5_ids = []  # top-5 token ids, one tensor per predicted position
    best_probabilities = []  # probability of the top-1 token per position
    with torch.no_grad():
        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probabilities = torch.softmax(logits[0, mask_position, :], dim=-1)
            top_probability, top_id = torch.topk(probabilities, k=5)
            top5_ids.append(top_id)
            best_probabilities.append(top_probability[0].item())
            # Commit the best guess so later positions are conditioned on it.
            padded_seq[mask_position] = vocab_mlm.idx_to_token[top_id[0].item()]

    # padded_seq[1] and padded_seq[2] are read as the subtype and potency
    # slots; mask positions are in ascending order, so whichever of the two
    # was masked comes first in best_probabilities.
    if X1 != "X":
        Topk = X1
        Subtype = X1
        if X2 != "X":
            # Both labels were supplied, so the only masks are inside the
            # sequence. Return the caller's potency as-is instead of pairing
            # it with the probability of an unrelated position.
            Potency = X2
        else:
            Potency = padded_seq[2], best_probabilities[0]
    else:
        Topk = vocab_mlm.to_tokens(list(top5_ids[0]))
        Subtype = padded_seq[1], best_probabilities[0]
        if X2 != "X":
            Potency = X2
        else:
            Potency = padded_seq[2], best_probabilities[1]
    return Subtype, Potency, Topk
# Options of the "Subtype" dropdown; "X" asks CTXGen to predict the subtype.
_SUBTYPE_CHOICES = [
    'X', '<AChBP>', '<Ca12>', '<Ca13>', '<Ca22>', '<Ca23>', '<GABA>',
    '<GluN2A>', '<GluN2B>', '<GluN2C>', '<GluN2D>', '<GluN3A>',
    '<K11>', '<K12>', '<K13>', '<K16>', '<K17>', '<Kshaker>',
    '<Na11>', '<Na12>', '<Na13>', '<Na14>', '<Na15>', '<Na16>', '<Na17>',
    '<Na18>', '<NaTTXR>', '<NaTTXS>', '<NavBh>', '<NET>',
    '<α1AAR>', '<α1BAR>', '<α1β1γ>', '<α1β1γδ>', '<α1β1δ>', '<α1β1δε>',
    '<α1β1ε>', '<α2β2>', '<α2β4>', '<α3β2>', '<α3β4>',
    '<α4β2>', '<α4β4>', '<α6α3β2>', '<α6α3β2β3>', '<α6α3β4>', '<α6α3β4β3>',
    '<α6β3β4>', '<α6β4>', '<α7>', '<α7α6β2>',
    '<α75HT3>', '<α9>', '<α9α10>',
]

# Options of the "Potency" dropdown; "X" asks CTXGen to predict the potency.
_POTENCY_CHOICES = ['X', '<high>', '<low>']

# Checkpoint files selectable in the "Model" dropdown; the chosen name is
# passed to CTXGen as model_name.
_MODEL_CHOICES = [
    'model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
    'model_C4.pt', 'model_C5.pt', 'model_mlm.pt',
]

# Markdown shown under the page title: links to the sibling apps plus a short
# explanation of each input field.
_DESCRIPTION = """
🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**
🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**
🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**
🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**
✅ **Subtype**: X if needs to be predicted.
✅ **Potency**: X if needs to be predicted.
✅ **Conotoxin**: conotoxin needs to be predicted.
✅ **Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.
"""

# Input widgets, in the positional order of CTXGen(X1, X2, X3, model_name).
_input_widgets = [
    gr.Dropdown(choices=_SUBTYPE_CHOICES, label="Subtype"),
    gr.Dropdown(choices=_POTENCY_CHOICES, label="Potency"),
    gr.Textbox(label="Conotoxin"),
    gr.Dropdown(choices=_MODEL_CHOICES, label="Model"),
]

# Output widgets, in the order of CTXGen's (Subtype, Potency, Topk) result.
_output_widgets = [
    gr.Textbox(label="Subtype"),
    gr.Textbox(label="Potency"),
    gr.Textbox(label="Top5"),
]

iface = gr.Interface(
    fn=CTXGen,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Conotoxin Label Prediction",
    description=_DESCRIPTION,
)
# Start the web app as soon as the script is executed.
iface.launch()