import torch
import gradio as gr
import plotly.graph_objs as go
from transformers import BertTokenizerFast, BertForMaskedLM
MODEL_NAME = "bert-base-uncased"
# Load model & tokenizer once
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
model.eval()
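# Optional (sketch, assuming single-device CPU/GPU inference): move the model
# to GPU; inputs inside `analyze` would then also need `.to(device)`.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)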
NUM_LAYERS = model.config.num_hidden_layers # 12 for bert-base-uncased
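# With output_hidden_states=True, the encoder returns NUM_LAYERS + 1 hidden
# states: the embedding output first, then one per transformer block.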
@torch.inference_mode()
def analyze(text: str, layer_idx: int):
"""
text: user input (ideally contains [MASK])
layer_idx: 1..NUM_LAYERS (which transformer block to visualise)
"""
if not text.strip():
return (
"Type some text above…",
"No [MASK] token, so I can’t show predictions.",
None,
None,
"Please type some text containing the [MASK] token."
)
# Tokenize
inputs = tokenizer(
text,
return_tensors="pt",
add_special_tokens=True
)
input_ids = inputs["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# Find [MASK] position (if any)
mask_token_id = tokenizer.mask_token_id
mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
mask_idx = int(mask_positions[0].item()) if len(mask_positions) > 0 else None
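    # If the text contains several [MASK] tokens, only the first is analysed.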
# Run BERT encoder to get hidden states and attention
outputs = model.bert(
**inputs,
output_hidden_states=True,
output_attentions=True,
return_dict=True,
)
hidden_states = outputs.hidden_states # tuple: (emb, layer1, ..., layer12)
attentions = outputs.attentions # tuple: (layer1..layer12), each [1, heads, seq, seq]
# We'll compute predictions for ALL layers for the [MASK], then slice for plots
    layer_probs = []        # probability of that layer's top token at [MASK]
    layer_best_tokens = []  # that layer's top token string
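    # This is a "logit-lens"-style probe: the pretrained MLM head is applied to
    # intermediate hidden states it was never trained on, so early-layer
    # probabilities are a rough diagnostic rather than calibrated predictions.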
if mask_idx is not None:
for L in range(1, NUM_LAYERS + 1):
hs = hidden_states[L] # [1, seq, hidden]
logits = model.cls(hs) # [1, seq, vocab]
mask_logits = logits[0, mask_idx, :]
probs = torch.softmax(mask_logits, dim=-1)
topk = torch.topk(probs, k=5)
top_tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
top_probs = topk.values.tolist()
# store best token per layer
layer_probs.append(float(top_probs[0]))
layer_best_tokens.append(top_tokens[0])
else:
# no [MASK]: we won't run MLM head for curve, but everything else still works
layer_probs = [0.0] * NUM_LAYERS
layer_best_tokens = ["(no [MASK])"] * NUM_LAYERS
# ---- Data for the selected layer ----
L = int(layer_idx)
L_hidden = hidden_states[L][0] # [seq, hidden]
# token "confidence" = norm of hidden vector, normalised for visualisation
norms = torch.norm(L_hidden, dim=-1)
norms_np = norms.cpu().numpy()
if norms_np.max() > 0:
conf = norms_np / norms_np.max()
else:
conf = norms_np
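    # Caveat: hidden-state norm is a crude activity proxy; it also reflects
    # token identity and position, not just model "confidence".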
    # attention for this layer, head 0 (shown in the UI as "head 1")
    L_att = attentions[L - 1][0, 0].cpu().numpy()  # [seq, seq]
    # stretch to the full [0, 1] range for display contrast
    # (raw attention weights already lie in [0, 1]; each row sums to 1)
    L_att = (L_att - L_att.min()) / (L_att.max() - L_att.min() + 1e-9)
# ---- 1) Token visualisation (HTML with confidence-based background) ----
token_spans = []
for i, tok in enumerate(tokens):
c = conf[i] if i < len(conf) else 0.0
bg = f"rgba(34,197,94,{0.15 + 0.7*c})" # green-ish
border = "#22c55e" if i == mask_idx else "rgba(148,163,184,0.4)"
        token_spans.append(
            f"<span style='background:{bg}; border:1px solid {border}; "
            f"border-radius:6px; padding:2px 6px; margin:2px; "
            f"display:inline-block; font-family:monospace;'>{tok}</span>"
        )
    tokens_html = "<div style='line-height:2.4;'>" + " ".join(token_spans) + "</div>"
# ---- 2) Top-k predictions for [MASK] at this layer ----
if mask_idx is not None:
hs_L = hidden_states[L] # [1, seq, hidden]
logits_L = model.cls(hs_L)
mask_logits_L = logits_L[0, mask_idx, :]
probs_L = torch.softmax(mask_logits_L, dim=-1)
topk_L = torch.topk(probs_L, k=10)
top_tokens_L = tokenizer.convert_ids_to_tokens(topk_L.indices.tolist())
top_probs_L = topk_L.values.tolist()
# Build a markdown table
lines = ["| Rank | Token | Prob |", "|------|-------|------|"]
for rank, (tok, p) in enumerate(zip(top_tokens_L, top_probs_L), start=1):
lines.append(f"| {rank} | `{tok}` | {p:.3f} |")
pred_md = "\n".join(lines)
else:
pred_md = (
"There is **no `[MASK]` token** in your input.\n\n"
"To see layer-wise predictions, include `[MASK]` somewhere in the text.\n"
"Example: `The capital of France is [MASK].`"
)
# ---- 3) Probability curve across layers ----
if mask_idx is not None:
x = list(range(1, NUM_LAYERS + 1))
y = layer_probs
fig_prob = go.Figure()
fig_prob.add_trace(go.Scatter(
x=x,
y=y,
mode="lines+markers",
name="P(top token at [MASK])"
))
fig_prob.update_layout(
xaxis_title="Layer",
yaxis_title="Probability of best prediction",
template="plotly_dark",
height=320,
margin=dict(l=40, r=20, t=40, b=40),
)
else:
fig_prob = None
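        # gr.Plot treats a None value as "no figure", leaving that panel blank.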
# ---- 4) Attention heatmap for selected layer ----
att_fig = go.Figure(
data=go.Heatmap(
z=L_att,
x=tokens,
y=tokens,
colorbar=dict(title="Attention"),
)
)
att_fig.update_layout(
xaxis_title="Key tokens",
yaxis_title="Query tokens",
template="plotly_dark",
height=420,
margin=dict(l=80, r=60, t=40, b=120),
)
# ---- 5) Info text ----
info = (
f"### Layer {L} summary\n"
f"- Hidden-state norms are used as a proxy for **token confidence** (bright = higher norm).\n"
f"- The heatmap shows **self-attention weights** for layer {L}, head 1.\n"
)
if mask_idx is not None:
best_current = layer_best_tokens[L - 1]
        info += (
            f"- At this layer, the top prediction for `[MASK]` is `{best_current}`.\n"
            f"- The line chart shows, for each layer, the probability the model assigns "
            f"to *that layer's own* top prediction.\n"
        )
else:
info += (
"- No `[MASK]` token detected, so layer-wise predictions are disabled. "
"Add `[MASK]` to explore how different layers refine the guess.\n"
)
return tokens_html, pred_md, fig_prob, att_fig, info
# ------------- Gradio UI ------------- #
DESCRIPTION = """
# 🔍 Transformer Layer Playground (BERT)
Explore how a real transformer (**bert-base-uncased**) processes text *layer by layer*.
- Type some text and choose a **layer** (1–12).
- If you include `[MASK]`, you’ll see **layer-wise predictions** at that position.
- Visualisations:
- Token chips, where brightness ≈ **hidden state norm** (a rough proxy for confidence/activation).
- A **line chart** of how the probability of the top prediction at `[MASK]` changes across layers.
- A full **attention heatmap** for the selected layer and head 1.
"""
EXAMPLE_TEXTS = [
"The capital of France is [MASK].",
"Transformers are very [MASK] models.",
"I love eating [MASK] with tomato sauce.",
"The [MASK] barked loudly at the stranger."
]
with gr.Blocks() as demo:
    # Optional styling hook (safe even on older Gradio versions):
    # custom CSS can be injected here via gr.HTML("<style>…</style>")
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=3):
text_in = gr.Textbox(
label="Input text (use [MASK] to see predictions)",
value="The capital of France is [MASK].",
lines=3,
placeholder="Type a sentence; include [MASK] somewhere."
)
layer_slider = gr.Slider(
minimum=1,
maximum=NUM_LAYERS,
value=4,
step=1,
label=f"Layer to visualise (1–{NUM_LAYERS})"
)
gr.Examples(
examples=EXAMPLE_TEXTS,
inputs=text_in,
label="Example prompts"
)
run_btn = gr.Button("Run", variant="primary")
with gr.Column(scale=5):
tokens_html = gr.HTML(label="Token representations", elem_id="tokens-html")
with gr.Row():
pred_out = gr.Markdown(label="Layer-wise predictions at [MASK]")
prob_plot = gr.Plot(label="Probability across layers")
att_plot = gr.Plot(label="Self-attention heatmap (selected layer, head 1)")
info_box = gr.Markdown(label="Explanation")
run_btn.click(
analyze,
inputs=[text_in, layer_slider],
outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box],
)
# Allows instant update without clicking Run
text_in.change(
analyze,
inputs=[text_in, layer_slider],
outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box],
)
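    # NOTE: `.change` fires on every edit, so the full model re-runs as you
    # type; for long inputs, `text_in.submit(...)` may be a cheaper trigger.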
layer_slider.change(
analyze,
inputs=[text_in, layer_slider],
outputs=[tokens_html, pred_out, prob_plot, att_plot, info_box],
)
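# Minimal sketch for sanity-checking `analyze` without launching the UI
# (illustrative call; it returns HTML, markdown, two figures, and info text):
#   tokens_html, pred_md, fig_prob, att_fig, info = analyze(
#       "The capital of France is [MASK].", 6
#   )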
if __name__ == "__main__":
demo.launch()