import torch import numpy as np import gradio as gr from transformers import BertTokenizerFast, BertForMaskedLM MODEL_NAME = "bert-base-uncased" # Load model & tokenizer once tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME) model = BertForMaskedLM.from_pretrained(MODEL_NAME) model.eval() NUM_LAYERS = model.config.num_hidden_layers # 12 for bert-base-uncased @torch.inference_mode() def analyze(text: str, layer_idx: int): """ text: user input (ideally contains [MASK]) layer_idx: 1..NUM_LAYERS (which transformer block to visualise) """ if not text.strip(): return ( "Type some text above…", "No [MASK] token, so I can’t show predictions.", None, None, "Please type some text containing the [MASK] token." ) # Tokenize inputs = tokenizer( text, return_tensors="pt", add_special_tokens=True ) input_ids = inputs["input_ids"] tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) # Find [MASK] position (if any) mask_token_id = tokenizer.mask_token_id mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0] mask_idx = int(mask_positions[0].item()) if len(mask_positions) > 0 else None # Run BERT encoder to get hidden states and attention outputs = model.bert( **inputs, output_hidden_states=True, output_attentions=True, return_dict=True, ) hidden_states = outputs.hidden_states # tuple: (emb, layer1, ..., layer12) attentions = outputs.attentions # tuple: (layer1..layer12), each [1, heads, seq, seq] # We'll compute predictions for ALL layers for the [MASK], then slice for plots layer_probs = [] # probability of best token per layer (or mask prob mass) layer_best_tokens = [] # best token name per layer if mask_idx is not None: for L in range(1, NUM_LAYERS + 1): hs = hidden_states[L] # [1, seq, hidden] logits = model.cls(hs) # [1, seq, vocab] mask_logits = logits[0, mask_idx, :] probs = torch.softmax(mask_logits, dim=-1) topk = torch.topk(probs, k=5) top_tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist()) top_probs = topk.values.tolist() # store best token per layer layer_probs.append(float(top_probs[0])) layer_best_tokens.append(top_tokens[0]) else: # no [MASK]: we won't run MLM head for curve, but everything else still works layer_probs = [0.0] * NUM_LAYERS layer_best_tokens = ["(no [MASK])"] * NUM_LAYERS # ---- Data for the selected layer ---- L = int(layer_idx) L_hidden = hidden_states[L][0] # [seq, hidden] # token "confidence" = norm of hidden vector, normalised for visualisation norms = torch.norm(L_hidden, dim=-1) norms_np = norms.cpu().numpy() if norms_np.max() > 0: conf = norms_np / norms_np.max() else: conf = norms_np # attention for this layer, head 0 L_att = attentions[L - 1][0, 0].cpu().numpy() # [seq, seq] # ensure it's [0,1] L_att = (L_att - L_att.min()) / (L_att.max() - L_att.min() + 1e-9) # ---- 1) Token visualisation (HTML with confidence-based background) ---- token_spans = [] for i, tok in enumerate(tokens): c = conf[i] if i < len(conf) else 0.0 bg = f"rgba(34,197,94,{0.15 + 0.7*c})" # green-ish border = "#22c55e" if i == mask_idx else "rgba(148,163,184,0.4)" token_spans.append( f"{tok}" ) tokens_html = "
" + " ".join(token_spans) + "
" # ---- 2) Top-k predictions for [MASK] at this layer ---- if mask_idx is not None: hs_L = hidden_states[L] # [1, seq, hidden] logits_L = model.cls(hs_L) mask_logits_L = logits_L[0, mask_idx, :] probs_L = torch.softmax(mask_logits_L, dim=-1) topk_L = torch.topk(probs_L, k=10) top_tokens_L = tokenizer.convert_ids_to_tokens(topk_L.indices.tolist()) top_probs_L = topk_L.values.tolist() # Build a markdown table lines = ["| Rank | Token | Prob |", "|------|-------|------|"] for rank, (tok, p) in enumerate(zip(top_tokens_L, top_probs_L), start=1): lines.append(f"| {rank} | `{tok}` | {p:.3f} |") pred_md = "\n".join(lines) else: pred_md = ( "There is **no `[MASK]` token** in your input.\n\n" "To see layer-wise predictions, include `[MASK]` somewhere in the text.\n" "Example: `The capital of France is [MASK].`" ) # ---- 3) Probability curve across layers ---- if mask_idx is not None: import plotly.graph_objs as go x = list(range(1, NUM_LAYERS + 1)) y = layer_probs fig_prob = go.Figure() fig_prob.add_trace(go.Scatter( x=x, y=y, mode="lines+markers", name="P(top token at [MASK])" )) fig_prob.update_layout( xaxis_title="Layer", yaxis_title="Probability of best prediction", template="plotly_dark", height=320, margin=dict(l=40, r=20, t=40, b=40), ) else: fig_prob = None # ---- 4) Attention heatmap for selected layer ---- import plotly.graph_objs as go att_fig = go.Figure( data=go.Heatmap( z=L_att, x=tokens, y=tokens, colorbar=dict(title="Attention"), ) ) att_fig.update_layout( xaxis_title="Key tokens", yaxis_title="Query tokens", template="plotly_dark", height=420, margin=dict(l=80, r=60, t=40, b=120), ) # ---- 5) Info text ---- info = ( f"### Layer {L} summary\n" f"- Hidden-state norms are used as a proxy for **token confidence** (bright = higher norm).\n" f"- The heatmap shows **self-attention weights** for layer {L}, head 1.\n" ) if mask_idx is not None: best_current = layer_best_tokens[L - 1] info += ( f"- At this layer, the top prediction for `[MASK]` is `{best_current}`.\n" f"- The line chart shows how the model’s confidence in its *current* best prediction " f"evolves across layers.\n" ) else: info += ( "- No `[MASK]` token detected, so layer-wise predictions are disabled. " "Add `[MASK]` to explore how different layers refine the guess.\n" ) return tokens_html, pred_md, fig_prob, att_fig, info # ------------- Gradio UI ------------- # DESCRIPTION = """ # 🔍 Transformer Layer Playground (BERT) Explore how a real transformer (**bert-base-uncased**) processes text *layer by layer*. - Type some text and choose a **layer** (1–12). - If you include `[MASK]`, you’ll see **layer-wise predictions** at that position. - Visualisations: - Token chips, where brightness ≈ **hidden state norm** (a rough proxy for confidence/activation). - A **line chart** of how the probability of the top prediction at `[MASK]` changes across layers. - A full **attention heatmap** for the selected layer and head 1. """ EXAMPLE_TEXTS = [ "The capital of France is [MASK].", "Transformers are very [MASK] models.", "I love eating [MASK] with tomato sauce.", "The [MASK] barked loudly at the stranger." ] with gr.Blocks() as demo: # Optional styling (safe even on older Gradio versions) gr.HTML(""" """) gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=3): text_in = gr.Textbox( label="Input text (use [MASK] to see predictions)", value="The capital of France is [MASK].", lines=3, placeholder="Type a sentence; include [MASK] somewhere." ) layer_slider = gr.Slider( minimum=1, maximum=NUM_LAYERS, value=4, step=1, label=f"Layer to visualise (1–{NUM_LAYERS})" ) gr.Examples( examples=EXAMPLE_TEXTS, inputs=text_in, label="Example prompts" ) run_btn = gr.Button("Run", variant="primary") with gr.Column(scale=5): tokens_html = gr.HTML(label="Token representations", elem_id="tokens-html") with gr.Row(): pred_out = gr.Markdown(label="Layer-wise predictions at [MASK]") prob_plot = gr.Plot(label="Probability across layers") att_plot = gr.Plot(label="Self-attention heatmap (selected layer, head 1)") info_box = gr.Markdown(label="Explanation") run_btn.click( analyze, inputs=[text_in, layer_slider], outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box], ) # Allows instant update without clicking Run text_in.change( analyze, inputs=[text_in, layer_slider], outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box], ) layer_slider.change( analyze, inputs=[text_in, layer_slider], outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box], ) if __name__ == "__main__": demo.launch()