cazzz307 committed
Commit e926d80 · verified · 1 parent: aa5e049

Create app.py

Files changed (1)
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
import html

import torch
import gradio as gr
import plotly.graph_objs as go
from transformers import BertTokenizerFast, BertForMaskedLM
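
# Assumed runtime dependencies for this Space (not pinned in this commit):
# torch, transformers, gradio, and plotly, e.g. via a requirements.txt.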

MODEL_NAME = "bert-base-uncased"

# Load model & tokenizer once at startup
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
model.eval()

NUM_LAYERS = model.config.num_hidden_layers  # 12 for bert-base-uncased

@torch.inference_mode()
def analyze(text: str, layer_idx: int):
    """
    text: user input (ideally containing [MASK])
    layer_idx: 1..NUM_LAYERS (which transformer block to visualise)
    """
    if not text.strip():
        return (
            "<span style='color:#888'>Type some text above…</span>",
            "Nothing to analyse yet. Type some text containing [MASK].",
            None,
            None,
            "Please type some text containing the [MASK] token."
        )

    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        add_special_tokens=True
    )

    input_ids = inputs["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Find the [MASK] position (if any)
    mask_token_id = tokenizer.mask_token_id
    mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
    mask_idx = int(mask_positions[0].item()) if len(mask_positions) > 0 else None
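
    # If the input contains several [MASK] tokens, only the first is analysed.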

    # Run the BERT encoder to get hidden states and attention maps
    outputs = model.bert(
        **inputs,
        output_hidden_states=True,
        output_attentions=True,
        return_dict=True,
    )

    hidden_states = outputs.hidden_states  # tuple: (embeddings, layer 1, ..., layer 12)
    attentions = outputs.attentions        # tuple: (layer 1 .. layer 12), each [1, heads, seq, seq]
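
    # Indexing note: hidden_states has NUM_LAYERS + 1 entries (index 0 is the
    # embedding output), so hidden_states[L] is the output of block L for
    # L in 1..NUM_LAYERS, while attentions has exactly NUM_LAYERS entries,
    # hence the attentions[L - 1] lookup further down.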

    # Compute the MLM prediction for [MASK] at EVERY layer (for the line chart)
    layer_probs = []        # probability of the best token at each layer
    layer_best_tokens = []  # best token string at each layer

    if mask_idx is not None:
        for layer in range(1, NUM_LAYERS + 1):
            hs = hidden_states[layer]  # [1, seq, hidden]
            logits = model.cls(hs)     # [1, seq, vocab]
            mask_logits = logits[0, mask_idx, :]
            probs = torch.softmax(mask_logits, dim=-1)

            # Store the single best token and its probability for this layer
            top1 = torch.topk(probs, k=1)
            layer_probs.append(float(top1.values[0]))
            layer_best_tokens.append(
                tokenizer.convert_ids_to_tokens(top1.indices.tolist())[0]
            )
    else:
        # No [MASK]: skip the MLM head for the curve; everything else still works
        layer_probs = [0.0] * NUM_LAYERS
        layer_best_tokens = ["(no [MASK])"] * NUM_LAYERS
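
    # Caveat: model.cls is BERT's MLM head, trained only on the final layer's
    # output. Reusing it on intermediate layers is a "logit lens"-style probe,
    # so early-layer probabilities are indicative rather than calibrated.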

    # ---- Data for the selected layer ----
    L = int(layer_idx)
    L_hidden = hidden_states[L][0]  # [seq, hidden]

    # Token "confidence" = norm of the hidden vector, normalised for display
    norms = torch.norm(L_hidden, dim=-1)
    norms_np = norms.cpu().numpy()
    conf = norms_np / norms_np.max() if norms_np.max() > 0 else norms_np

    # Attention for this layer, first head (index 0, labelled "head 1" in the UI)
    L_att = attentions[L - 1][0, 0].cpu().numpy()  # [seq, seq]
    # Rescale into [0, 1] for display
    L_att = (L_att - L_att.min()) / (L_att.max() - L_att.min() + 1e-9)
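
    # Note: raw attention rows already sum to 1 (softmax over keys); the min-max
    # rescale above only stretches contrast for the heatmap.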

    # ---- 1) Token visualisation (HTML chips with confidence-based background) ----
    token_spans = []
    for i, tok in enumerate(tokens):
        c = conf[i] if i < len(conf) else 0.0
        bg = f"rgba(34,197,94,{0.15 + 0.7 * c})"  # green-ish
        border = "#22c55e" if i == mask_idx else "rgba(148,163,184,0.4)"
        token_spans.append(
            f"<span style='padding:2px 4px; margin:1px; border-radius:4px; "
            f"border:1px solid {border}; background:{bg}; font-size:12px; "
            f"display:inline-block;'>{html.escape(tok)}</span>"  # escape <, >, & in tokens
        )
    tokens_html = "<div style='line-height:1.8;'>" + " ".join(token_spans) + "</div>"

    # ---- 2) Top-k predictions for [MASK] at the selected layer ----
    if mask_idx is not None:
        hs_L = hidden_states[L]  # [1, seq, hidden]
        logits_L = model.cls(hs_L)
        mask_logits_L = logits_L[0, mask_idx, :]
        probs_L = torch.softmax(mask_logits_L, dim=-1)
        topk_L = torch.topk(probs_L, k=10)
        top_tokens_L = tokenizer.convert_ids_to_tokens(topk_L.indices.tolist())
        top_probs_L = topk_L.values.tolist()

        # Build a markdown table
        lines = ["| Rank | Token | Prob |", "|------|-------|------|"]
        for rank, (tok, p) in enumerate(zip(top_tokens_L, top_probs_L), start=1):
            lines.append(f"| {rank} | `{tok}` | {p:.3f} |")
        pred_md = "\n".join(lines)
    else:
        pred_md = (
            "There is **no `[MASK]` token** in your input.\n\n"
            "To see layer-wise predictions, include `[MASK]` somewhere in the text.\n"
            "Example: `The capital of France is [MASK].`"
        )

    # ---- 3) Probability curve across layers ----
    if mask_idx is not None:
        fig_prob = go.Figure()
        fig_prob.add_trace(go.Scatter(
            x=list(range(1, NUM_LAYERS + 1)),
            y=layer_probs,
            mode="lines+markers",
            name="P(top token at [MASK])"
        ))
        fig_prob.update_layout(
            xaxis_title="Layer",
            yaxis_title="Probability of best prediction",
            template="plotly_dark",
            height=320,
            margin=dict(l=40, r=20, t=40, b=40),
        )
    else:
        fig_prob = None
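
    # gr.Plot accepts None as an output value, which simply leaves the plot empty.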

    # ---- 4) Attention heatmap for the selected layer ----
    att_fig = go.Figure(
        data=go.Heatmap(
            z=L_att,
            x=tokens,
            y=tokens,
            colorbar=dict(title="Attention"),
        )
    )
    att_fig.update_layout(
        xaxis_title="Key tokens",
        yaxis_title="Query tokens",
        template="plotly_dark",
        height=420,
        margin=dict(l=80, r=60, t=40, b=120),
    )

    # ---- 5) Info text ----
    info = (
        f"### Layer {L} summary\n"
        f"- Hidden-state norms are used as a proxy for **token confidence** (brighter = higher norm).\n"
        f"- The heatmap shows **self-attention weights** for layer {L}, head 1.\n"
    )
    if mask_idx is not None:
        best_current = layer_best_tokens[L - 1]
        info += (
            f"- At this layer, the top prediction for `[MASK]` is `{best_current}`.\n"
            f"- The line chart shows how the model's confidence in its *current* best prediction "
            f"evolves across layers.\n"
        )
    else:
        info += (
            "- No `[MASK]` token detected, so layer-wise predictions are disabled. "
            "Add `[MASK]` to explore how different layers refine the guess.\n"
        )

    return tokens_html, pred_md, fig_prob, att_fig, info


# ------------- Gradio UI ------------- #

DESCRIPTION = """
# 🔍 Transformer Layer Playground (BERT)

Explore how a real transformer (**bert-base-uncased**) processes text *layer by layer*.

- Type some text and choose a **layer** (1–12).
- If you include `[MASK]`, you'll see **layer-wise predictions** at that position.
- Visualisations:
  - Token chips, where brightness ≈ **hidden-state norm** (a rough proxy for confidence/activation).
  - A **line chart** of how the probability of the top prediction at `[MASK]` changes across layers.
  - A full **attention heatmap** for the selected layer, head 1.
"""

EXAMPLE_TEXTS = [
    "The capital of France is [MASK].",
    "Transformers are very [MASK] models.",
    "I love eating [MASK] with tomato sauce.",
    "The [MASK] barked loudly at the stranger.",
]

with gr.Blocks(theme=gr.themes.Monochrome(), css="""
#tokens-html { font-family: 'JetBrains Mono', monospace; }
""") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=3):
            text_in = gr.Textbox(
                label="Input text (use [MASK] to see predictions)",
                value="The capital of France is [MASK].",
                lines=3,
                placeholder="Type a sentence; include [MASK] somewhere."
            )
            layer_slider = gr.Slider(
                minimum=1,
                maximum=NUM_LAYERS,
                value=4,
                step=1,
                label=f"Layer to visualise (1–{NUM_LAYERS})"
            )

            gr.Examples(
                examples=EXAMPLE_TEXTS,
                inputs=text_in,
                label="Example prompts"
            )

            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=5):
            # Named tokens_out to avoid confusion with the analyze-local tokens_html
            tokens_out = gr.HTML(label="Token representations", elem_id="tokens-html")

            with gr.Row():
                pred_out = gr.Markdown(label="Layer-wise predictions at [MASK]")
                prob_plot = gr.Plot(label="Probability across layers")

            att_plot = gr.Plot(label="Self-attention heatmap (selected layer, head 1)")
            info_box = gr.Markdown(label="Explanation")

    run_btn.click(
        analyze,
        inputs=[text_in, layer_slider],
        outputs=[tokens_out, pred_out, prob_plot, att_plot, info_box],
    )

    # Also re-run on input changes for a smoother experience
    text_in.change(
        analyze,
        inputs=[text_in, layer_slider],
        outputs=[tokens_out, pred_out, prob_plot, att_plot, info_box],
    )
    layer_slider.change(
        analyze,
        inputs=[text_in, layer_slider],
        outputs=[tokens_out, pred_out, prob_plot, att_plot, info_box],
    )
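
    # Note: .change fires on every keystroke and slider tick, so each edit runs a
    # full BERT forward pass. That is cheap for bert-base-uncased, but with a
    # larger model you may want to rely on the Run button alone.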

if __name__ == "__main__":
    demo.launch()