Ludvig committed on
Commit
c38363b
·
1 Parent(s): e29363a

Adds sub col option

Browse files
Files changed (6) hide show
  1. README.md +1 -3
  2. app.py +33 -3
  3. data.py +7 -1
  4. plot.R +14 -0
  5. text_sections.py +13 -3
  6. utils.py +29 -0
README.md CHANGED
@@ -20,6 +20,4 @@ Streamlit application for plotting a confusion matrix.
20
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
21
  - Add option to format total-count tile in sum tiles
22
  - Selectable templates (for 2,3,4,5 classes - one selects num classes and pick a color scheme and other common defaults)
23
- - Add extra column in `Upload counts` that replaces whichever value is the bottom value (normally counts). Requires changes to cvms.
24
- - Allow handling tick text - e.g. for long class names or many classes.
25
- - Enable class order reversal after cvms arrow bug is fixed
 
20
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
21
  - Add option to format total-count tile in sum tiles
22
  - Selectable templates (for 2,3,4,5 classes - one selects num classes and pick a color scheme and other common defaults)
23
+ - Allow handling tick text - e.g. for long class names or many classes.
 
 
app.py CHANGED
@@ -151,6 +151,11 @@ elif input_choice == "Upload counts":
151
  n_col = st.selectbox(
152
  "Counts column", options=list(st.session_state["count_data"].columns)
153
  )
 
 
 
 
 
154
 
155
  if st.form_submit_button(label="Set columns"):
156
  st.session_state["step"] = 2
@@ -240,6 +245,7 @@ elif input_choice == "Enter counts":
240
  st.session_state["count_data"] = pd.DataFrame(
241
  all_pairs, columns=["Target", "Prediction"]
242
  )
 
243
  st.session_state["count_data"]["N"] = 0
244
 
245
  st.session_state["step"] = 1
@@ -247,7 +253,17 @@ elif input_choice == "Enter counts":
247
  if st.session_state["step"] >= 1:
248
  with st.form(key="enter_counts_form"):
249
  st.write(
250
- "Fill in the counts by pressing each cell in the `N` column and inputting the counts."
 
 
 
 
 
 
 
 
 
 
251
  )
252
 
253
  new_counts = st.data_editor(
@@ -256,6 +272,12 @@ elif input_choice == "Enter counts":
256
  column_config={
257
  "Target": st.column_config.TextColumn(disabled=True),
258
  "Prediction": st.column_config.TextColumn(disabled=True),
 
 
 
 
 
 
259
  "N": st.column_config.NumberColumn(
260
  disabled=False, min_value=0, step=1
261
  ),
@@ -280,6 +302,8 @@ elif input_choice == "Enter counts":
280
  target_col = "Target"
281
  prediction_col = "Prediction"
282
  n_col = "N"
 
 
283
 
284
  if st.session_state["step"] >= 2:
285
  data_is_ready = False
@@ -302,7 +326,7 @@ if st.session_state["step"] >= 2:
302
  df[prediction_col] = clean_str_column(df[prediction_col])
303
 
304
  # Save to tmp directory to allow reading in R script
305
- df.to_csv(data_store_path)
306
 
307
  # Extract unique classes
308
  st.session_state["classes"] = sorted(
@@ -316,7 +340,10 @@ if st.session_state["step"] >= 2:
316
  st.write(f"{df.shape} (Showing first 5 rows)")
317
 
318
  else:
319
- st.session_state["count_data"].to_csv(data_store_path)
 
 
 
320
  data_is_ready = True
321
 
322
  if data_is_ready:
@@ -365,6 +392,9 @@ if st.session_state["step"] >= 2:
365
  f"{selected_classes_string}",
366
  ]
367
 
 
 
 
368
  if st.session_state["input_type"] == "counts":
369
  # The input data are counts
370
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
 
151
  n_col = st.selectbox(
152
  "Counts column", options=list(st.session_state["count_data"].columns)
153
  )
154
+ sub_col = st.selectbox(
155
+ "Sub column",
156
+ options=["--"] + list(st.session_state["count_data"].columns),
157
+ help="Optional! This column will replace the bottom text in the middle of the tiles.",
158
+ )
159
 
160
  if st.form_submit_button(label="Set columns"):
161
  st.session_state["step"] = 2
 
245
  st.session_state["count_data"] = pd.DataFrame(
246
  all_pairs, columns=["Target", "Prediction"]
247
  )
248
+ st.session_state["count_data"]["Sub"] = ""
249
  st.session_state["count_data"]["N"] = 0
250
 
251
  st.session_state["step"] = 1
 
253
  if st.session_state["step"] >= 1:
254
  with st.form(key="enter_counts_form"):
255
  st.write(
256
+ "Fill in the counts by pressing each cell in the `N` column and inputting the counts. "
257
+ )
258
+ st.markdown(
259
+ "(**Optional**) If you wish to specify the bottom text in the middle of the tiles, "
260
+ "you can fill in the `Sub` column.",
261
+ help="The `sub` column text replaces the bottom text (counts by default). "
262
+ "The design settings for the replaced element (e.g. counts) are used for this text instead.",
263
+ )
264
+ st.info(
265
+ "Note: Please click outside the cell before "
266
+ "pressing `Generate data` to register your change."
267
  )
268
 
269
  new_counts = st.data_editor(
 
272
  column_config={
273
  "Target": st.column_config.TextColumn(disabled=True),
274
  "Prediction": st.column_config.TextColumn(disabled=True),
275
+ "Sub": st.column_config.TextColumn(
276
+ help="This text replaces the bottom text (in the middle of the tiles). "
277
+ "By default, the counts are replaced. "
278
+ "Note that the settings for this text are named "
279
+ "by the text element it replaces (e.g. **Fonts**>>*Counts*)."
280
+ ),
281
  "N": st.column_config.NumberColumn(
282
  disabled=False, min_value=0, step=1
283
  ),
 
302
  target_col = "Target"
303
  prediction_col = "Prediction"
304
  n_col = "N"
305
+ sub_col = "Sub" if any(st.session_state["count_data"]["Sub"]) else None
306
+
307
 
308
  if st.session_state["step"] >= 2:
309
  data_is_ready = False
 
326
  df[prediction_col] = clean_str_column(df[prediction_col])
327
 
328
  # Save to tmp directory to allow reading in R script
329
+ df.to_csv(data_store_path, index=False)
330
 
331
  # Extract unique classes
332
  st.session_state["classes"] = sorted(
 
340
  st.write(f"{df.shape} (Showing first 5 rows)")
341
 
342
  else:
343
+ count_data_clean = st.session_state["count_data"].copy()
344
+ if not any(count_data_clean["Sub"]):
345
+ del count_data_clean["Sub"]
346
+ count_data_clean.to_csv(data_store_path, index=False)
347
  data_is_ready = True
348
 
349
  if data_is_ready:
 
392
  f"{selected_classes_string}",
393
  ]
394
 
395
+ if "sub_col" in locals() and sub_col is not None and sub_col != "--":
396
+ plotting_args += ["--sub_col", f"{sub_col}"]
397
+
398
  if st.session_state["input_type"] == "counts":
399
  # The input data are counts
400
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
data.py CHANGED
@@ -57,7 +57,13 @@ class DownloadHeader:
57
 
58
  @staticmethod
59
  def header_and_data_download(
60
- header, data, file_name, col_sizes=[9, 2], key=None, label="Download", help="Download data"
 
 
 
 
 
 
61
  ):
62
  col1, col2 = st.columns(col_sizes)
63
  with col1:
 
57
 
58
  @staticmethod
59
  def header_and_data_download(
60
+ header,
61
+ data,
62
+ file_name,
63
+ col_sizes=[9, 2],
64
+ key=None,
65
+ label="Download",
66
+ help="Download data",
67
  ):
68
  col1, col2 = st.columns(col_sizes)
69
  with col1:
plot.R CHANGED
@@ -36,6 +36,10 @@ option_list <- list(
36
  type = "character",
37
  help = "Count column (when `--data_are_counts`)."
38
  ),
 
 
 
 
39
  make_option(c("--classes"),
40
  type = "character",
41
  help = paste0(
@@ -82,6 +86,15 @@ if (!is.null(opt$n_col)) {
82
  n_col <- stringr::str_replace_all(n_col, " ", ".")
83
  }
84
 
 
 
 
 
 
 
 
 
 
85
  # Read and prepare data frame
86
  df <- tryCatch(
87
  {
@@ -282,6 +295,7 @@ confusion_matrix_plot <- tryCatch(
282
  {
283
  cvms::plot_confusion_matrix(
284
  confusion_matrix,
 
285
  class_order = classes,
286
  add_sums = design_settings$show_sums,
287
  add_counts = design_settings$show_counts,
 
36
  type = "character",
37
  help = "Count column (when `--data_are_counts`)."
38
  ),
39
+ make_option(c("--sub_col"),
40
+ type = "character",
41
+ help = "Sub column (when `--data_are_counts`)."
42
+ ),
43
  make_option(c("--classes"),
44
  type = "character",
45
  help = paste0(
 
86
  n_col <- stringr::str_replace_all(n_col, " ", ".")
87
  }
88
 
89
+ sub_col <- NULL
90
+ if (!is.null(opt$sub_col)) {
91
+ if (!data_are_counts) {
92
+ stop("`sub_col` can only be specified when data are counts.")
93
+ }
94
+ sub_col <- stringr::str_squish(opt$sub_col)
95
+ sub_col <- stringr::str_replace_all(sub_col, " ", ".")
96
+ }
97
+
98
  # Read and prepare data frame
99
  df <- tryCatch(
100
  {
 
295
  {
296
  cvms::plot_confusion_matrix(
297
  confusion_matrix,
298
+ sub_col = sub_col,
299
  class_order = classes,
300
  add_sums = design_settings$show_sums,
301
  add_counts = design_settings$show_counts,
text_sections.py CHANGED
@@ -45,6 +45,12 @@ def get_example_counts():
45
  {
46
  "Target": ["cl1", "cl2", "cl1", "cl2"],
47
  "Prediction": ["cl1", "cl2", "cl2", "cl1"],
 
 
 
 
 
 
48
  "N": [12, 10, 3, 5],
49
  }
50
  )
@@ -149,12 +155,14 @@ def upload_counts_text():
149
  "2) A `predicted classes` column. \n\n"
150
  "3) A `combination count` column for the "
151
  "combination frequency of 1 and 2. \n\n"
 
 
152
  "Other columns are currently ignored. "
153
  "In the next step, you will be asked to select the names of these two columns. "
154
  )
155
  with col2:
156
  st.write("Example of such a file:")
157
- st.write(get_example_counts())
158
 
159
 
160
  def upload_predictions_text():
@@ -171,7 +179,7 @@ def upload_predictions_text():
171
  )
172
  with col2:
173
  st.write("Example of such a file:")
174
- st.write(get_example_data())
175
 
176
 
177
  def columns_text():
@@ -184,7 +192,9 @@ def columns_text():
184
  def design_text():
185
  st.subheader("Design your plot")
186
  st.write("This is where you customize the design of your confusion matrix plot.")
187
- st.markdown("We suggest you go directly to `Generate plot` to see the starting point. Then go back and tweak to your liking!")
 
 
188
  st.markdown(
189
  "The *width* and *height* settings are usually necessary to adjust as they "
190
  "change the relative size of the elements. Try adjusting 100px at a "
 
45
  {
46
  "Target": ["cl1", "cl2", "cl1", "cl2"],
47
  "Prediction": ["cl1", "cl2", "cl2", "cl1"],
48
+ "Sub*": [
49
+ "(57/60)",
50
+ "(46/50)",
51
+ "(12/15)",
52
+ "(23/25)",
53
+ ],
54
  "N": [12, 10, 3, 5],
55
  }
56
  )
 
155
  "2) A `predicted classes` column. \n\n"
156
  "3) A `combination count` column for the "
157
  "combination frequency of 1 and 2. \n\n"
158
+ "4) (\\***Optionally**) a `sub` column with text "
159
+ "that replaces the bottom text in the middle of tiles. \n\n"
160
  "Other columns are currently ignored. "
161
  "In the next step, you will be asked to select the names of these two columns. "
162
  )
163
  with col2:
164
  st.write("Example of such a file:")
165
+ st.dataframe(get_example_counts(), hide_index=True)
166
 
167
 
168
  def upload_predictions_text():
 
179
  )
180
  with col2:
181
  st.write("Example of such a file:")
182
+ st.dataframe(get_example_data(), hide_index=True)
183
 
184
 
185
  def columns_text():
 
192
  def design_text():
193
  st.subheader("Design your plot")
194
  st.write("This is where you customize the design of your confusion matrix plot.")
195
+ st.markdown(
196
+ "We suggest you go directly to `Generate plot` to see the starting point. Then go back and tweak to your liking!"
197
+ )
198
  st.markdown(
199
  "The *width* and *height* settings are usually necessary to adjust as they "
200
  "change the relative size of the elements. Try adjusting 100px at a "
utils.py CHANGED
@@ -1,5 +1,13 @@
1
  import subprocess
2
  import re
 
 
 
 
 
 
 
 
3
 
4
 
5
  def call_subprocess(call_, message, return_output=False, encoding="UTF-8"):
@@ -8,6 +16,27 @@ def call_subprocess(call_, message, return_output=False, encoding="UTF-8"):
8
  try:
9
  out = subprocess.check_output(call_, shell=True, encoding=encoding)
10
  except subprocess.CalledProcessError as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  print(f"{message}: {call_}")
12
  raise e
13
  return out
 
1
  import subprocess
2
  import re
3
+ import streamlit as st
4
+ import json
5
+
6
+
7
+ def show_error(msg, action):
8
+ st.error(
9
+ f"Failed to {action}:\n\n...{msg}\n\nPlease [report](https://github.com/LudvigOlsen/plot_confusion_matrix/issues) this issue."
10
+ )
11
 
12
 
13
  def call_subprocess(call_, message, return_output=False, encoding="UTF-8"):
 
16
  try:
17
  out = subprocess.check_output(call_, shell=True, encoding=encoding)
18
  except subprocess.CalledProcessError as e:
19
+ if "Failed to create plot from confusion matrix." in e.output:
20
+ msg = e.output.split("Failed to create plot from confusion matrix.")[-1]
21
+ show_error(msg=msg, action="plot confusion matrix")
22
+ elif "Failed to read design settings as a json file" in e.output:
23
+ msg = e.output.split("Failed to read design settings as a json file")[
24
+ -1
25
+ ]
26
+ show_error(msg=msg, action="read design settings")
27
+ elif "Failed to read data from" in e.output:
28
+ msg = e.output.split("Failed to read data from")[-1]
29
+ show_error(msg=msg, action="read data")
30
+ elif "Failed to ggsave plot to:" in e.output:
31
+ msg = e.output.split("Failed to ggsave plot to:")[-1]
32
+ show_error(msg=msg, action="save plot")
33
+ else:
34
+ msg = e.output.split("\n\n")[-1]
35
+ st.error(
36
+ f"Unknown type of error: {msg}.\n\n"
37
+ "Please [report](https://github.com/LudvigOlsen/plot_confusion_matrix/issues) this issue."
38
+ )
39
+ print(e.output)
40
  print(f"{message}: {call_}")
41
  raise e
42
  return out