Ludvig committed on
Commit
f8cfd7b
·
1 Parent(s): 747851f

Many improvements and fixes

Browse files
Files changed (3) hide show
  1. app.py +21 -27
  2. design.py +49 -10
  3. plot.R +44 -1
app.py CHANGED
@@ -10,7 +10,6 @@ import streamlit as st # Import last
10
  import pandas as pd
11
  from pandas.api.types import is_float_dtype
12
  from itertools import combinations
13
- from collections import OrderedDict
14
 
15
  from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
16
  from data import read_data, read_data_cached, DownloadHeader, generate_data
@@ -157,7 +156,6 @@ elif input_choice == "Upload counts":
157
  st.session_state["step"] = 2
158
 
159
  if st.session_state["step"] >= 2:
160
- print(st.session_state["count_data"])
161
  # Ensure targets and predictions are clean strings
162
  st.session_state["count_data"][target_col] = clean_str_column(
163
  st.session_state["count_data"][target_col]
@@ -242,47 +240,43 @@ elif input_choice == "Enter counts":
242
  st.session_state["count_data"] = pd.DataFrame(
243
  all_pairs, columns=["Target", "Prediction"]
244
  )
 
245
 
246
  st.session_state["step"] = 1
247
 
248
  if st.session_state["step"] >= 1:
249
  with st.form(key="enter_counts_form"):
250
- st.write("Fill in the counts for `N(Target, Prediction)` pairs.")
251
- count_input_fields = OrderedDict()
252
-
253
- num_cols = 3
254
- cols = st.columns(num_cols)
255
- for i, (targ, pred) in enumerate(
256
- zip(
257
- st.session_state["count_data"]["Target"],
258
- st.session_state["count_data"]["Prediction"],
259
- )
260
- ):
261
- count_input_fields[f"{targ}____{pred}"] = cols[
262
- i % num_cols
263
- ].number_input(f"N({targ}, {pred})", step=1)
 
264
 
265
  if st.form_submit_button(
266
  label="Generate data",
267
  ):
268
- st.session_state["count_data"]["N"] = [
269
- int(val) for val in count_input_fields.values()
270
- ]
271
  st.session_state["step"] = 2
272
 
273
  if st.session_state["step"] >= 2:
274
  DownloadHeader.header_and_data_download(
275
- "Entered counts",
276
  data=st.session_state["count_data"],
277
  file_name="confusion_matrix_counts.csv",
 
278
  help="Download counts",
279
- col_sizes=[10, 2],
280
  )
281
- col1, col2, col3 = st.columns([4, 5, 4])
282
- with col2:
283
- st.write(st.session_state["count_data"])
284
- st.write(f"{st.session_state['count_data'].shape}")
285
-
286
  target_col = "Target"
287
  prediction_col = "Prediction"
288
  n_col = "N"
@@ -318,7 +312,7 @@ if st.session_state["step"] >= 2:
318
  st.subheader("The Data")
319
  col1, col2, col3 = st.columns([2, 2, 2])
320
  with col2:
321
- st.write(df.head(5))
322
  st.write(f"{df.shape} (Showing first 5 rows)")
323
 
324
  else:
 
10
  import pandas as pd
11
  from pandas.api.types import is_float_dtype
12
  from itertools import combinations
 
13
 
14
  from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
15
  from data import read_data, read_data_cached, DownloadHeader, generate_data
 
156
  st.session_state["step"] = 2
157
 
158
  if st.session_state["step"] >= 2:
 
159
  # Ensure targets and predictions are clean strings
160
  st.session_state["count_data"][target_col] = clean_str_column(
161
  st.session_state["count_data"][target_col]
 
240
  st.session_state["count_data"] = pd.DataFrame(
241
  all_pairs, columns=["Target", "Prediction"]
242
  )
243
+ st.session_state["count_data"]["N"] = 0
244
 
245
  st.session_state["step"] = 1
246
 
247
  if st.session_state["step"] >= 1:
248
  with st.form(key="enter_counts_form"):
249
+ st.write(
250
+ "Fill in the counts by pressing each cell in the `N` column and inputting the counts."
251
+ )
252
+
253
+ new_counts = st.data_editor(
254
+ st.session_state["count_data"],
255
+ hide_index=True,
256
+ column_config={
257
+ "Target": st.column_config.TextColumn(disabled=True),
258
+ "Prediction": st.column_config.TextColumn(disabled=True),
259
+ "N": st.column_config.NumberColumn(
260
+ disabled=False, min_value=0, step=1
261
+ ),
262
+ },
263
+ )
264
 
265
  if st.form_submit_button(
266
  label="Generate data",
267
  ):
268
+ st.session_state["count_data"] = new_counts
 
 
269
  st.session_state["step"] = 2
270
 
271
  if st.session_state["step"] >= 2:
272
  DownloadHeader.header_and_data_download(
273
+ "",
274
  data=st.session_state["count_data"],
275
  file_name="confusion_matrix_counts.csv",
276
+ label="Download counts",
277
  help="Download counts",
278
+ col_sizes=[10, 3],
279
  )
 
 
 
 
 
280
  target_col = "Target"
281
  prediction_col = "Prediction"
282
  n_col = "N"
 
312
  st.subheader("The Data")
313
  col1, col2, col3 = st.columns([2, 2, 2])
314
  with col2:
315
+ st.dataframe(df.head(5), hide_index=True)
316
  st.write(f"{df.shape} (Showing first 5 rows)")
317
 
318
  else:
design.py CHANGED
@@ -78,16 +78,14 @@ def design_section(
78
  "Any observation with either a target or prediction "
79
  "of another class is excluded.",
80
  )
81
- # TODO: Once the arrow bug in cvms is fixed, enable reversing!
82
  # Reverse by default
83
- # selected_classes.reverse()
84
  with col2:
85
- # st.write(" ")
86
- # st.write(" ")
87
- # reverse_class_order = st.checkbox(
88
- # "Reverse order", value=False, help="Reverse the order of the classes."
89
- # )
90
- pass
91
 
92
  # Color palette
93
  output["palette"] = _add_select_box(
@@ -151,6 +149,45 @@ def design_section(
151
  st.markdown("""---""")
152
  st.markdown("**Advanced**:")
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  with st.expander("Elements"):
155
  col1, col2 = st.columns(2)
156
  with col1:
@@ -454,8 +491,10 @@ def design_section(
454
  # Save settings as json
455
  with open(design_settings_store_path, "w") as f:
456
  json.dump(output, f)
457
- # if reverse_class_order:
458
- # selected_classes.reverse()
 
 
459
 
460
  design_ready = False
461
  if st.session_state["step"] >= 3:
 
78
  "Any observation with either a target or prediction "
79
  "of another class is excluded.",
80
  )
 
81
  # Reverse by default
82
+ selected_classes.reverse()
83
  with col2:
84
+ st.write(" ")
85
+ st.write(" ")
86
+ reverse_class_order = st.checkbox(
87
+ "Reverse order", value=False, help="Reverse the order of the classes."
88
+ )
 
89
 
90
  # Color palette
91
  output["palette"] = _add_select_box(
 
149
  st.markdown("""---""")
150
  st.markdown("**Advanced**:")
151
 
152
+ with st.expander("Labels"):
153
+ col1, col2 = st.columns(2)
154
+ with col1:
155
+ output["x_label"] = st.text_input(
156
+ "x-axis",
157
+ value=get_uploaded_setting(
158
+ key="x_label", default="True Class", type_=str
159
+ ),
160
+ )
161
+ with col2:
162
+ output["y_label"] = st.text_input(
163
+ "y-axis",
164
+ value=get_uploaded_setting(
165
+ key="y_label", default="Predicted Class", type_=str
166
+ ),
167
+ )
168
+
169
+ st.markdown("---")
170
+ col1, col2 = st.columns(2)
171
+ with col1:
172
+ output["title_label"] = st.text_input(
173
+ "Title",
174
+ value=get_uploaded_setting(
175
+ key="title_label", default="", type_=str
176
+ ),
177
+ )
178
+ with col2:
179
+ output["caption_label"] = st.text_input(
180
+ "Caption",
181
+ value=get_uploaded_setting(
182
+ key="caption_label", default="", type_=str
183
+ ),
184
+ )
185
+ st.info(
186
+ "Note: When adding a title or caption, "
187
+ "you may need to adjust the height and "
188
+ "width of the plot as well."
189
+ )
190
+
191
  with st.expander("Elements"):
192
  col1, col2 = st.columns(2)
193
  with col1:
 
491
  # Save settings as json
492
  with open(design_settings_store_path, "w") as f:
493
  json.dump(output, f)
494
+ if not output["place_x_axis_above"]:
495
+ selected_classes.reverse()
496
+ if reverse_class_order:
497
+ selected_classes.reverse()
498
 
499
  design_ready = False
500
  if st.session_state["step"] >= 3:
plot.R CHANGED
@@ -267,6 +267,16 @@ if (isTRUE(design_settings$counts_on_top) ||
267
  )
268
  }
269
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  confusion_matrix_plot <- tryCatch(
272
  {
@@ -282,7 +292,13 @@ confusion_matrix_plot <- tryCatch(
282
  rm_zero_text = !design_settings$show_zero_text,
283
  add_zero_shading = design_settings$show_zero_shading,
284
  add_arrows = design_settings$show_arrows,
 
 
 
 
285
  counts_on_top = design_settings$counts_on_top,
 
 
286
  diag_percentages_only = design_settings$diag_percentages_only,
287
  digits = as.integer(design_settings$num_digits),
288
  palette = design_settings$palette,
@@ -290,7 +306,10 @@ confusion_matrix_plot <- tryCatch(
290
  font_counts = do.call("font", counts_font_args),
291
  font_normalized = do.call("font", normalized_font_args),
292
  font_row_percentages = do.call("font", percentages_font_args),
293
- font_col_percentages = do.call("font", percentages_font_args)
 
 
 
294
  )
295
  },
296
  error = function(e) {
@@ -301,6 +320,30 @@ confusion_matrix_plot <- tryCatch(
301
  }
302
  )
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  tryCatch(
305
  {
306
  ggplot2::ggsave(
 
267
  )
268
  }
269
 
270
+ tile_border_color <- NA
271
+ if (isTRUE(design_settings$show_tile_border)) {
272
+ tile_border_color <- design_settings$tile_border_color
273
+ }
274
+
275
+ intensity_by <- ifelse(
276
+ tolower(design_settings$intensity_by) == "counts",
277
+ "counts",
278
+ "normalized"
279
+ )
280
 
281
  confusion_matrix_plot <- tryCatch(
282
  {
 
292
  rm_zero_text = !design_settings$show_zero_text,
293
  add_zero_shading = design_settings$show_zero_shading,
294
  add_arrows = design_settings$show_arrows,
295
+ arrow_size = design_settings$arrow_size,
296
+ arrow_nudge_from_text = design_settings$arrow_nudge_from_text,
297
+ intensity_by = intensity_by,
298
+ darkness = design_settings$darkness,
299
  counts_on_top = design_settings$counts_on_top,
300
+ place_x_axis_above = design_settings$place_x_axis_above,
301
+ rotate_y_text = design_settings$rotate_y_text,
302
  diag_percentages_only = design_settings$diag_percentages_only,
303
  digits = as.integer(design_settings$num_digits),
304
  palette = design_settings$palette,
 
306
  font_counts = do.call("font", counts_font_args),
307
  font_normalized = do.call("font", normalized_font_args),
308
  font_row_percentages = do.call("font", percentages_font_args),
309
+ font_col_percentages = do.call("font", percentages_font_args),
310
+ tile_border_color = tile_border_color,
311
+ tile_border_size = design_settings$tile_border_size,
312
+ tile_border_linetype = design_settings$tile_border_linetype
313
  )
314
  },
315
  error = function(e) {
 
320
  }
321
  )
322
 
323
+ # Add labels on x and y axes
324
+ confusion_matrix_plot <- confusion_matrix_plot +
325
+ ggplot2::labs(
326
+ x = design_settings$x_label,
327
+ y = design_settings$y_label
328
+ )
329
+
330
+ # Add title
331
+ if (nchar(design_settings$title_label) > 0) {
332
+ confusion_matrix_plot <- confusion_matrix_plot +
333
+ ggplot2::labs(
334
+ title = design_settings$title_label
335
+ )
336
+ }
337
+
338
+ # Add caption
339
+ if (nchar(design_settings$caption_label) > 0) {
340
+ confusion_matrix_plot <- confusion_matrix_plot +
341
+ ggplot2::labs(
342
+ caption = design_settings$caption_label
343
+ )
344
+ }
345
+
346
+
347
  tryCatch(
348
  {
349
  ggplot2::ggsave(