Ludvig commited on
Commit
3bf7bf3
·
1 Parent(s): 5a49222

Fixes and improvements

Browse files
Files changed (3) hide show
  1. app.py +26 -38
  2. data.py +19 -0
  3. design.py +338 -90
app.py CHANGED
@@ -3,12 +3,21 @@ App for plotting confusion matrix with `cvms::plot_confusion_matrix()`.
3
 
4
  TODO:
5
  - IMPORTANT! Allow specifying which class probabilities are of! (See plot prob_of_class)
 
 
 
 
6
  - Allow setting threshold - manual, max J, spec/sens
7
  - Add bg box around confusion matrix plot as text dissappears on dark mode!
8
  - ggsave does not use dpi??
9
  - allow svg, pdf?
10
  - entered count -> counts (upload as well)
11
- - Add full reset button (empty cache on different files)
 
 
 
 
 
12
 
13
  """
14
 
@@ -61,6 +70,7 @@ def set_tmp_dir():
61
  temp_dir, temp_dir_path = set_tmp_dir()
62
  gen_data_store_path = pathlib.Path(f"{temp_dir_path}/generated_data.csv")
63
  data_store_path = pathlib.Path(f"{temp_dir_path}/data.csv")
 
64
  conf_mat_path = pathlib.Path(f"{temp_dir_path}/confusion_matrix.png")
65
 
66
 
@@ -312,65 +322,43 @@ if st.session_state["step"] >= 2:
312
  )
313
 
314
  # Section for specifying design settings
315
- design_settings = design_section(
 
316
  num_classes=num_classes,
317
  predictions_are_probabilities=predictions_are_probabilities,
 
318
  )
319
 
320
- # design_ready tells us whether to proceed or wait
321
  # for user to fix issues
322
- if st.session_state["step"] >= 3 and design_settings["design_ready"]:
323
- # TODO Fix and update these flags
324
- element_flags = [
325
- key
326
- for key, val in {
327
- "--add_counts": design_settings["show_counts"],
328
- "--add_normalized": design_settings["show_normalized"],
329
- "--add_sums": design_settings["show_sums"],
330
- "--add_row_percentages": design_settings["show_row_percentages"],
331
- "--add_col_percentages": design_settings["show_col_percentages"],
332
- "--add_arrows": design_settings["show_arrows"],
333
- "--add_zero_percentages": design_settings["show_zero_percentages"],
334
- "--add_zero_text": design_settings["show_zero_text"],
335
- "--add_zero_shading": design_settings["show_zero_shading"],
336
- "--add_tile_border": design_settings["show_tile_border"],
337
- "--counts_on_top": design_settings["counts_on_top"],
338
- "--diag_percentages_only": design_settings["diag_percentages_only"],
339
- "--rotate_y_text": design_settings["rotate_y_text"],
340
- "--place_x_axis_above": design_settings["place_x_axis_above"],
341
- }.items()
342
- if val
343
- ]
344
 
345
  plotting_args = [
346
  "--data_path",
347
  f"'{data_store_path}'",
348
  "--out_path",
349
  f"'{conf_mat_path}'",
 
 
350
  "--target_col",
351
  f"'{target_col}'",
352
  "--prediction_col",
353
  f"'{prediction_col}'",
354
- "--width",
355
- f"{design_settings['width']}",
356
- "--height",
357
- f"{design_settings['height']}",
358
- "--dpi",
359
- f"{design_settings['dpi']}",
360
  "--classes",
361
- f"{','.join(design_settings['selected_classes'])}",
362
- "--digits",
363
- f"{design_settings['num_digits']}",
364
- "--palette",
365
- f"{design_settings['palette']}",
366
  ]
367
 
368
  if st.session_state["input_type"] == "counts":
369
  # The input data are counts
370
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
371
 
372
- plotting_args += element_flags
373
-
374
  plotting_args = " ".join(plotting_args)
375
 
376
  call_subprocess(
 
3
 
4
  TODO:
5
  - IMPORTANT! Allow specifying which class probabilities are of! (See plot prob_of_class)
6
+ - IMPORTANT! Use json/txt file to pass settings to r instead?
7
+ - IMPORTANT! Allow saving and uploading design settings - so many of them
8
+ that one shouldn't have to enter all the changes for every plot
9
+ when making multiple at a time!
10
  - Allow setting threshold - manual, max J, spec/sens
11
  - Add bg box around confusion matrix plot as text dissappears on dark mode!
12
  - ggsave does not use dpi??
13
  - allow svg, pdf?
14
  - entered count -> counts (upload as well)
15
+ - Add full reset button (empty cache on different files) - callback?
16
+ - Handle <2 classes in design box (add st.error)
17
+ - Handle classes with spaces in them?
18
+
19
+ NOTE:
20
+
21
 
22
  """
23
 
 
70
  temp_dir, temp_dir_path = set_tmp_dir()
71
  gen_data_store_path = pathlib.Path(f"{temp_dir_path}/generated_data.csv")
72
  data_store_path = pathlib.Path(f"{temp_dir_path}/data.csv")
73
+ design_settings_store_path = pathlib.Path(f"{temp_dir_path}/design_settings.json")
74
  conf_mat_path = pathlib.Path(f"{temp_dir_path}/confusion_matrix.png")
75
 
76
 
 
322
  )
323
 
324
  # Section for specifying design settings
325
+
326
+ design_settings, design_ready, selected_classes, prob_of_class = design_section(
327
  num_classes=num_classes,
328
  predictions_are_probabilities=predictions_are_probabilities,
329
+ design_settings_store_path=design_settings_store_path,
330
  )
331
 
332
+ # design_ready tells us whether to proceed or wait
333
  # for user to fix issues
334
+ if st.session_state["step"] >= 3 and design_ready:
335
+ DownloadHeader.header_and_json_download(
336
+ header="Confusion Matrix Plot",
337
+ data=design_settings,
338
+ file_name="design_settings.json",
339
+ label="Download design settings",
340
+ help="Download the design settings to allow reusing setttings in future plots.",
341
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  plotting_args = [
344
  "--data_path",
345
  f"'{data_store_path}'",
346
  "--out_path",
347
  f"'{conf_mat_path}'",
348
+ "--settings_path",
349
+ f"'{design_settings_store_path}'",
350
  "--target_col",
351
  f"'{target_col}'",
352
  "--prediction_col",
353
  f"'{prediction_col}'",
 
 
 
 
 
 
354
  "--classes",
355
+ f"{','.join(selected_classes)}",
 
 
 
 
356
  ]
357
 
358
  if st.session_state["input_type"] == "counts":
359
  # The input data are counts
360
  plotting_args += ["--n_col", f"{n_col}", "--data_are_counts"]
361
 
 
 
362
  plotting_args = " ".join(plotting_args)
363
 
364
  call_subprocess(
data.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pathlib
2
  import pandas as pd
3
  import streamlit as st
@@ -70,3 +71,21 @@ class DownloadHeader:
70
  key=key,
71
  help=help,
72
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import pathlib
3
  import pandas as pd
4
  import streamlit as st
 
71
  key=key,
72
  help=help,
73
  )
74
+
75
+ @staticmethod
76
+ def header_and_json_download(
77
+ header, data: dict, file_name, download_col_size=4, key=None, label="Download", help="Download json file"
78
+ ):
79
+ col1, col2 = st.columns([9, download_col_size])
80
+ with col1:
81
+ st.subheader(header)
82
+ with col2:
83
+ data_json = json.dumps(data)
84
+ st.download_button(
85
+ label=label,
86
+ data=data_json,
87
+ file_name=file_name,
88
+ key=key,
89
+ mime="application/json",
90
+ help=help,
91
+ )
design.py CHANGED
@@ -1,28 +1,78 @@
 
 
 
1
  import streamlit as st
2
 
3
  from text_sections import (
4
  design_text,
5
  )
6
 
7
- # arrow_size = 0.048,
8
- # arrow_nudge_from_text = 0.065,
9
- # sums_settings = sum_tile_settings(),
10
- # intensity_by
11
 
12
- # darkness = 0.8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  def design_section(
16
  num_classes,
17
  predictions_are_probabilities,
 
18
  ):
19
  output = {}
20
 
21
- with st.form(key="settings_form"):
22
  design_text()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  col1, col2 = st.columns(2)
24
  with col1:
25
- output["selected_classes"] = st.multiselect(
26
  "Select classes (min=2, order is respected)",
27
  options=st.session_state["classes"],
28
  default=st.session_state["classes"],
@@ -35,43 +85,67 @@ def design_section(
35
  st.session_state["input_type"] == "data"
36
  and predictions_are_probabilities
37
  ):
38
- output["prob_of_class"] = st.selectbox(
39
  "Probabilities are of (not working)",
40
  options=st.session_state["classes"],
41
  index=1,
42
  )
43
  else:
44
- output["prob_of_class"] = None
45
 
46
- output["palette"] = st.selectbox(
47
- "Color Palette",
 
 
 
48
  options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
 
 
49
  )
50
 
51
  # Ask for output parameters
52
- # TODO: Set default based on number of classes and sum tiles
53
  col1, col2, col3 = st.columns(3)
54
  with col1:
55
  output["width"] = st.number_input(
56
- "Width (px)", value=1200 + 100 * (num_classes - 2), step=50
 
 
 
 
57
  )
58
  with col2:
59
  output["height"] = st.number_input(
60
- "Height (px)", value=1200 + 100 * (num_classes - 2), step=50
 
 
 
 
61
  )
62
  with col3:
63
- output["dpi"] = st.number_input("DPI (not working)", value=320, step=10)
 
 
 
 
64
 
65
  st.write(" ") # Slightly bigger gap between the two sections
66
  col1, col2, col3 = st.columns(3)
67
  with col1:
68
- output["show_counts"] = st.checkbox("Show Counts", value=True)
 
 
 
69
  with col2:
70
- output["show_normalized"] = st.checkbox("Show Normalized (%)", value=True)
 
 
 
 
 
71
  with col3:
72
  output["show_sums"] = st.checkbox(
73
  "Show Sum Tiles",
74
- value=False,
75
  help="Show extra row and column with the "
76
  "totals for that row/column.",
77
  )
@@ -82,19 +156,32 @@ def design_section(
82
  with st.expander("Elements"):
83
  col1, col2 = st.columns(2)
84
  with col1:
85
- output["rotate_y_text"] = st.checkbox("Rotate y-axis text", value=True)
 
 
 
 
 
86
  output["place_x_axis_above"] = st.checkbox(
87
- "Place x-axis on top", value=True
 
 
 
88
  )
89
  output["counts_on_top"] = st.checkbox(
90
  "Counts on top (not working)",
 
 
 
91
  help="Whether to switch the positions of the counts and normalized counts (%). "
92
  "The counts become the big centralized numbers and the "
93
  "normalized counts go below with a smaller font size.",
94
  )
95
  with col2:
96
  output["num_digits"] = st.number_input(
97
- "Digits", value=2, help="Number of digits to round percentages to."
 
 
98
  )
99
 
100
  st.markdown("""---""")
@@ -103,20 +190,36 @@ def design_section(
103
  with col1:
104
  st.write("Row and column percentages:")
105
  output["show_row_percentages"] = st.checkbox(
106
- "Show row percentages", value=num_classes < 6
 
 
 
107
  )
108
  output["show_col_percentages"] = st.checkbox(
109
- "Show column percentages", value=num_classes < 6
 
 
 
 
 
 
 
 
 
110
  )
111
- output["show_arrows"] = st.checkbox("Show arrows", value=True)
112
  output["diag_percentages_only"] = st.checkbox(
113
- "Diagonal row/column percentages only"
 
 
 
114
  )
115
  with col2:
116
  output["arrow_size"] = (
117
  st.slider(
118
  "Arrow size",
119
- value=0.048 * 10,
 
 
120
  min_value=0.03 * 10,
121
  max_value=0.06 * 10,
122
  step=0.001 * 10,
@@ -126,7 +229,9 @@ def design_section(
126
  output["arrow_nudge_from_text"] = (
127
  st.slider(
128
  "Arrow nudge from text",
129
- value=0.065 * 10,
 
 
130
  min_value=0.00,
131
  max_value=0.1 * 10,
132
  step=0.001 * 10,
@@ -135,45 +240,60 @@ def design_section(
135
  )
136
 
137
  with st.expander("Tiles"):
138
- col1, col2, col3 = st.columns(3)
139
  with col1:
140
- pass
141
- with col2:
142
- output["intensity_by"] = st.selectbox(
143
- "Intensity based on", options=["Counts", "Normalized (%)"]
 
 
 
144
  )
145
- with col3:
146
  output["darkness"] = st.slider(
147
  "Darkness",
148
  min_value=0.0,
149
  max_value=1.0,
150
- value=0.8,
 
 
151
  step=0.01,
152
  help="How dark the darkest colors should be, between 0 and 1, where 1 is darkest.",
153
  )
154
 
155
  st.markdown("""---""")
156
 
157
- col1, col2, col3, col4 = st.columns(4)
 
 
 
 
 
 
 
158
  with col1:
159
- output["show_tile_border"] = st.checkbox(
160
- "Add tile borders", value=False
161
- )
162
- with col2:
163
  output["tile_border_color"] = st.color_picker(
164
- "Border color", value="#000000"
 
 
 
165
  )
166
- with col3:
167
  output["tile_border_size"] = st.slider(
168
  "Border size",
169
  min_value=0.0,
170
  max_value=3.0,
171
- value=0.1,
 
 
172
  step=0.01,
173
  )
174
- with col4:
175
- output["tile_border_linetype"] = st.selectbox(
176
- "Border linetype",
 
 
177
  options=[
178
  "solid",
179
  "dashed",
@@ -182,6 +302,8 @@ def design_section(
182
  "longdash",
183
  "twodash",
184
  ],
 
 
185
  )
186
 
187
  st.markdown("""---""")
@@ -190,19 +312,24 @@ def design_section(
190
 
191
  col1, col2 = st.columns(2)
192
  with col1:
193
- output["sum_tile_palette"] = st.selectbox(
194
- "Color Palette",
195
- key="sum_tiles_color_palette",
 
196
  options=["Greens", "Oranges", "Greys", "Purples", "Reds", "Blues"],
 
 
197
  )
 
198
  with col2:
199
  output["sum_tile_label"] = st.text_input(
200
  "Label",
201
- value="Σ",
 
 
202
  key="sum_tiles_label",
203
  )
204
 
205
- # label = NULL,
206
  # tile_fill = NULL,
207
  # font_color = NULL,
208
  # tile_border_color = NULL,
@@ -218,48 +345,118 @@ def design_section(
218
  st.write("Special settings for tiles where the count is 0:")
219
  col1, col2, col3 = st.columns(3)
220
  with col1:
221
- output["show_zero_shading"] = st.checkbox("Add shading", value=True)
 
 
 
 
 
222
  with col2:
223
  output["show_zero_text"] = st.checkbox(
224
  "Show text",
225
- value=False,
 
 
226
  help="Whether to show counts, normalized (%), etc.",
227
  )
228
  with col3:
229
  output["show_zero_percentages"] = st.checkbox(
230
  "Show row/column percentages",
231
- value=False,
 
 
232
  help="Only relevant when row/column percentages are enabled.",
233
  )
234
 
235
- with st.expander("Fonts"):
236
- font_dicts = {}
237
- font_types = [
238
- "Counts",
239
- "Normalized (%)",
240
- "Row Percentage",
241
- "Column Percentage",
242
- ]
243
- for font_type in font_types:
244
- st.subheader(font_type)
245
- num_cols = 3
246
- font_dicts[font_type] = font_inputs(key_prefix=font_type)
247
- for i, (_, setting_widget) in enumerate(font_dicts[font_type].items()):
248
- if i % num_cols == 0:
249
- cols = st.columns(num_cols)
250
- with cols[i % num_cols]:
251
- setting_widget()
252
-
253
- if font_type != font_types[-1]:
254
- st.markdown("""---""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  st.markdown("""---""")
257
 
258
  if st.form_submit_button(label="Generate plot"):
259
  st.session_state["step"] = 3
 
 
 
 
260
 
 
261
  if st.session_state["step"] >= 3:
262
- output["design_ready"] = True
263
  if output["show_sums"] and output["sum_tile_palette"] == output["palette"]:
264
  st.error(
265
  "The color palettes (background colors) "
@@ -267,26 +464,77 @@ def design_section(
267
  "Please select a different color palette for "
268
  "the sum tiles under **Tiles** >> *Sum tile settings*."
269
  )
270
- output["design_ready"] = False
 
 
271
 
272
- return output
273
 
 
 
 
 
 
 
 
 
 
274
 
275
- def font_inputs(key_prefix: str):
276
- return {
277
- "color": lambda: st.color_picker("Color", key=f"{key_prefix}_color"),
278
- "bold": lambda: st.checkbox("Bold", key=f"{key_prefix}_bold"),
279
- "cursive": lambda: st.checkbox("Italics", key=f"{key_prefix}_italics"),
280
- "size": lambda: st.number_input("Size", key=f"{key_prefix}_size"),
281
- "nudge_x": lambda: st.number_input(
282
- "Nudge on x-axis", key=f"{key_prefix}_nudge_x"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  ),
284
- "nudge_y": lambda: st.number_input(
285
- "Nudge on y-axis", key=f"{key_prefix}_nudge_y"
 
 
286
  ),
287
- "alpha": lambda: st.slider(
288
- "Transparency", min_value=0, max_value=1, value=1, key=f"{key_prefix}_alpha"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  ),
290
- "prefix": lambda: st.text_input("Prefix", key=f"{key_prefix}_prefix"),
291
- "suffix": lambda: st.text_input("Suffix", key=f"{key_prefix}_suffix"),
292
  }
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Callable, Any, Tuple
2
+ import json
3
+ import numpy as np
4
  import streamlit as st
5
 
6
  from text_sections import (
7
  design_text,
8
  )
9
 
 
 
 
 
10
 
11
+ def _add_select_box(
12
+ key: str,
13
+ label: str,
14
+ default: Any,
15
+ options: List[Any],
16
+ get_setting_fn: Callable,
17
+ type_=str,
18
+ ):
19
+ """
20
+ Add selectbox with selection of default value from setting function.
21
+ """
22
+ chosen_default = get_setting_fn(
23
+ key=key,
24
+ default=default,
25
+ type_=type_,
26
+ options=options,
27
+ )
28
+ return st.selectbox(
29
+ label, options=options, index=options.index(chosen_default), key=key
30
+ )
31
 
32
 
33
  def design_section(
34
  num_classes,
35
  predictions_are_probabilities,
36
+ design_settings_store_path,
37
  ):
38
  output = {}
39
 
40
+ with st.form(key="settings_upload_form"):
41
  design_text()
42
+ uploaded_settings_path = st.file_uploader(
43
+ "Upload design settings", type=["json"]
44
+ )
45
+ # TODO: Allow resetting settings!
46
+ if st.form_submit_button(label="Apply settings"):
47
+ if uploaded_settings_path is not None:
48
+ uploaded_design_settings = json.load(uploaded_settings_path)
49
+ else:
50
+ st.warning("No settings were uploaded. Uploading settings is optional.")
51
+
52
+ def get_uploaded_setting(key, default, type_=None, options=None):
53
+ # NOTE: Must be placed here, to have `uploaded_design_settings` in locals
54
+
55
+ if "uploaded_design_settings" in locals() and key in uploaded_design_settings:
56
+ out = uploaded_design_settings[key]
57
+ if type_ is not None:
58
+ if not isinstance(out, type_):
59
+ st.warning(
60
+ f"An uploaded setting ({key}) had the wrong type. Using default value."
61
+ )
62
+ return default
63
+ if options is not None:
64
+ if out not in options:
65
+ st.warning(
66
+ f"An uploaded setting ({key}) was not a valid choice. Using default value."
67
+ )
68
+ return default
69
+ return out
70
+ return default
71
+
72
+ with st.form(key="settings_form"):
73
  col1, col2 = st.columns(2)
74
  with col1:
75
+ selected_classes = st.multiselect(
76
  "Select classes (min=2, order is respected)",
77
  options=st.session_state["classes"],
78
  default=st.session_state["classes"],
 
85
  st.session_state["input_type"] == "data"
86
  and predictions_are_probabilities
87
  ):
88
+ prob_of_class = st.selectbox(
89
  "Probabilities are of (not working)",
90
  options=st.session_state["classes"],
91
  index=1,
92
  )
93
  else:
94
+ prob_of_class = None
95
 
96
+ # Color palette
97
+ output["palette"] = _add_select_box(
98
+ key="palette",
99
+ label="Color Palette",
100
+ default="Blues",
101
  options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
102
+ get_setting_fn=get_uploaded_setting,
103
+ type_=str,
104
  )
105
 
106
  # Ask for output parameters
 
107
  col1, col2, col3 = st.columns(3)
108
  with col1:
109
  output["width"] = st.number_input(
110
+ "Width (px)",
111
+ value=get_uploaded_setting(
112
+ key="width", default=1200 + 100 * (num_classes - 2), type_=int
113
+ ),
114
+ step=50,
115
  )
116
  with col2:
117
  output["height"] = st.number_input(
118
+ "Height (px)",
119
+ value=get_uploaded_setting(
120
+ key="width", default=1200 + 100 * (num_classes - 2), type_=int
121
+ ),
122
+ step=50,
123
  )
124
  with col3:
125
+ output["dpi"] = st.number_input(
126
+ "DPI (not working)",
127
+ value=get_uploaded_setting(key="dpi", default=320, type_=int),
128
+ step=10,
129
+ )
130
 
131
  st.write(" ") # Slightly bigger gap between the two sections
132
  col1, col2, col3 = st.columns(3)
133
  with col1:
134
+ output["show_counts"] = st.checkbox(
135
+ "Show Counts",
136
+ value=get_uploaded_setting(key="show_counts", default=True, type_=bool),
137
+ )
138
  with col2:
139
+ output["show_normalized"] = st.checkbox(
140
+ "Show Normalized (%)",
141
+ value=get_uploaded_setting(
142
+ key="show_normalized", default=True, type_=bool
143
+ ),
144
+ )
145
  with col3:
146
  output["show_sums"] = st.checkbox(
147
  "Show Sum Tiles",
148
+ value=get_uploaded_setting(key="show_sums", default=False, type_=bool),
149
  help="Show extra row and column with the "
150
  "totals for that row/column.",
151
  )
 
156
  with st.expander("Elements"):
157
  col1, col2 = st.columns(2)
158
  with col1:
159
+ output["rotate_y_text"] = st.checkbox(
160
+ "Rotate y-axis text",
161
+ value=get_uploaded_setting(
162
+ key="rotate_y_text", default=True, type_=bool
163
+ ),
164
+ )
165
  output["place_x_axis_above"] = st.checkbox(
166
+ "Place x-axis on top",
167
+ value=get_uploaded_setting(
168
+ key="place_x_axis_above", default=True, type_=bool
169
+ ),
170
  )
171
  output["counts_on_top"] = st.checkbox(
172
  "Counts on top (not working)",
173
+ value=get_uploaded_setting(
174
+ key="counts_on_top", default=False, type_=bool
175
+ ),
176
  help="Whether to switch the positions of the counts and normalized counts (%). "
177
  "The counts become the big centralized numbers and the "
178
  "normalized counts go below with a smaller font size.",
179
  )
180
  with col2:
181
  output["num_digits"] = st.number_input(
182
+ "Digits",
183
+ value=get_uploaded_setting(key="num_digits", default=2, type_=int),
184
+ help="Number of digits to round percentages to.",
185
  )
186
 
187
  st.markdown("""---""")
 
190
  with col1:
191
  st.write("Row and column percentages:")
192
  output["show_row_percentages"] = st.checkbox(
193
+ "Show row percentages",
194
+ value=get_uploaded_setting(
195
+ key="show_row_percentages", default=num_classes < 6, type_=bool
196
+ ),
197
  )
198
  output["show_col_percentages"] = st.checkbox(
199
+ "Show column percentages",
200
+ value=get_uploaded_setting(
201
+ key="show_col_percentages", default=num_classes < 6, type_=bool
202
+ ),
203
+ )
204
+ output["show_arrows"] = st.checkbox(
205
+ "Show arrows",
206
+ value=get_uploaded_setting(
207
+ key="show_arrows", default=True, type_=bool
208
+ ),
209
  )
 
210
  output["diag_percentages_only"] = st.checkbox(
211
+ "Diagonal row/column percentages only",
212
+ value=get_uploaded_setting(
213
+ key="diag_percentages_only", default=False, type_=bool
214
+ ),
215
  )
216
  with col2:
217
  output["arrow_size"] = (
218
  st.slider(
219
  "Arrow size",
220
+ value=get_uploaded_setting(
221
+ key="arrow_size", default=0.048 * 10, type_=float
222
+ ),
223
  min_value=0.03 * 10,
224
  max_value=0.06 * 10,
225
  step=0.001 * 10,
 
229
  output["arrow_nudge_from_text"] = (
230
  st.slider(
231
  "Arrow nudge from text",
232
+ value=get_uploaded_setting(
233
+ key="arrow_nudge_from_text", default=0.065 * 10, type_=float
234
+ ),
235
  min_value=0.00,
236
  max_value=0.1 * 10,
237
  step=0.001 * 10,
 
240
  )
241
 
242
  with st.expander("Tiles"):
243
+ col1, col2 = st.columns(2)
244
  with col1:
245
+ output["intensity_by"] = _add_select_box(
246
+ key="intensity_by",
247
+ label="Intensity based on",
248
+ default="Counts",
249
+ options=["Counts", "Normalized (%)"],
250
+ get_setting_fn=get_uploaded_setting,
251
+ type_=str,
252
  )
253
+ with col2:
254
  output["darkness"] = st.slider(
255
  "Darkness",
256
  min_value=0.0,
257
  max_value=1.0,
258
+ value=get_uploaded_setting(
259
+ key="darkness", default=0.8, type_=float
260
+ ),
261
  step=0.01,
262
  help="How dark the darkest colors should be, between 0 and 1, where 1 is darkest.",
263
  )
264
 
265
  st.markdown("""---""")
266
 
267
+ output["show_tile_border"] = st.checkbox(
268
+ "Add tile borders",
269
+ value=get_uploaded_setting(
270
+ key="show_tile_border", default=False, type_=bool
271
+ ),
272
+ )
273
+
274
+ col1, col2, col3 = st.columns(3)
275
  with col1:
 
 
 
 
276
  output["tile_border_color"] = st.color_picker(
277
+ "Border color",
278
+ value=get_uploaded_setting(
279
+ key="tile_border_color", default="#000000", type_=str
280
+ ),
281
  )
282
+ with col2:
283
  output["tile_border_size"] = st.slider(
284
  "Border size",
285
  min_value=0.0,
286
  max_value=3.0,
287
+ value=get_uploaded_setting(
288
+ key="tile_border_size", default=0.1, type_=float
289
+ ),
290
  step=0.01,
291
  )
292
+ with col3:
293
+ output["tile_border_linetype"] = _add_select_box(
294
+ key="tile_border_linetype",
295
+ label="Border linetype",
296
+ default="solid",
297
  options=[
298
  "solid",
299
  "dashed",
 
302
  "longdash",
303
  "twodash",
304
  ],
305
+ get_setting_fn=get_uploaded_setting,
306
+ type_=str,
307
  )
308
 
309
  st.markdown("""---""")
 
312
 
313
  col1, col2 = st.columns(2)
314
  with col1:
315
+ output["sum_tile_palette"] = _add_select_box(
316
+ key="sum_tile_palette",
317
+ label="Color Palette",
318
+ default="Greens",
319
  options=["Greens", "Oranges", "Greys", "Purples", "Reds", "Blues"],
320
+ get_setting_fn=get_uploaded_setting,
321
+ type_=str,
322
  )
323
+
324
  with col2:
325
  output["sum_tile_label"] = st.text_input(
326
  "Label",
327
+ value=get_uploaded_setting(
328
+ key="sum_tile_label", default="Σ", type_=str
329
+ ),
330
  key="sum_tiles_label",
331
  )
332
 
 
333
  # tile_fill = NULL,
334
  # font_color = NULL,
335
  # tile_border_color = NULL,
 
345
  st.write("Special settings for tiles where the count is 0:")
346
  col1, col2, col3 = st.columns(3)
347
  with col1:
348
+ output["show_zero_shading"] = st.checkbox(
349
+ "Add shading",
350
+ value=get_uploaded_setting(
351
+ key="show_zero_shading", default=True, type_=bool
352
+ ),
353
+ )
354
  with col2:
355
  output["show_zero_text"] = st.checkbox(
356
  "Show text",
357
+ value=get_uploaded_setting(
358
+ key="show_zero_text", default=False, type_=bool
359
+ ),
360
  help="Whether to show counts, normalized (%), etc.",
361
  )
362
  with col3:
363
  output["show_zero_percentages"] = st.checkbox(
364
  "Show row/column percentages",
365
+ value=get_uploaded_setting(
366
+ key="show_zero_percentages", default=False, type_=bool
367
+ ),
368
  help="Only relevant when row/column percentages are enabled.",
369
  )
370
 
371
+ if True:
372
+ with st.expander("Fonts"):
373
+ # Specify available settings and defaults per font
374
+ font_types = {
375
+ "Top Font": {
376
+ "key_prefix": "font_top",
377
+ "description": "The big text in the middle (normalized (%) by default).",
378
+ "settings": {
379
+ "size": 4.3, # 2.8
380
+ "color": "#000000",
381
+ "alpha": 1.0,
382
+ "bold": False,
383
+ "italics": False,
384
+ },
385
+ },
386
+ "Bottom Font": {
387
+ "key_prefix": "font_bottom",
388
+ "description": "The text just below the top font (counts by default).",
389
+ "settings": {
390
+ "size": 2.8,
391
+ "color": "#000000",
392
+ "alpha": 1.0,
393
+ "bold": False,
394
+ "italics": False,
395
+ },
396
+ },
397
+ "Percentages Font": {
398
+ "key_prefix": "font_percentage",
399
+ "description": "The row and column percentages.",
400
+ "settings": {
401
+ "size": 2.35,
402
+ "color": "#000000",
403
+ "alpha": 0.85,
404
+ "bold": False,
405
+ "italics": True,
406
+ "suffix": "%",
407
+ "prefix": "",
408
+ },
409
+ },
410
+ "Normalized (%)": {
411
+ "key_prefix": "font_normalized",
412
+ "description": "Special settings for the normalized (%) text.",
413
+ "settings": {"suffix": "%", "prefix": ""},
414
+ },
415
+ "Counts": {
416
+ "key_prefix": "font_counts",
417
+ "description": "Special settings for the counts text.",
418
+ "settings": {"suffix": "", "prefix": ""},
419
+ },
420
+ }
421
+
422
+ for font_type_title, font_type_spec in font_types.items():
423
+ st.markdown(f"**{font_type_title}**")
424
+ st.markdown(font_type_spec["description"])
425
+ num_cols = 3
426
+ font_settings = create_font_settings(
427
+ key_prefix=font_type_spec["key_prefix"],
428
+ get_setting_fn=get_uploaded_setting,
429
+ settings_to_get=list(font_type_spec["settings"].keys()),
430
+ )
431
+
432
+ for i, (setting_name, setting_widget) in enumerate(
433
+ font_settings.items()
434
+ ):
435
+ if i % num_cols == 0:
436
+ cols = st.columns(num_cols)
437
+ with cols[i % num_cols]:
438
+ default = font_type_spec["settings"][
439
+ setting_name[len(font_type_spec["key_prefix"]) + 1 :]
440
+ ]
441
+ output[setting_name] = setting_widget(
442
+ k=setting_name, d=default
443
+ )
444
+
445
+ if font_type_title != list(font_types.keys())[-1]:
446
+ st.markdown("""---""")
447
 
448
  st.markdown("""---""")
449
 
450
  if st.form_submit_button(label="Generate plot"):
451
  st.session_state["step"] = 3
452
+ if output["show_sums"] and output["sum_tile_palette"] == output["palette"]:
453
+ # Save settings as json
454
+ with open(design_settings_store_path, "w") as f:
455
+ json.dump(output, f)
456
 
457
+ design_ready = False
458
  if st.session_state["step"] >= 3:
459
+ design_ready = True
460
  if output["show_sums"] and output["sum_tile_palette"] == output["palette"]:
461
  st.error(
462
  "The color palettes (background colors) "
 
464
  "Please select a different color palette for "
465
  "the sum tiles under **Tiles** >> *Sum tile settings*."
466
  )
467
+ design_ready = False
468
+
469
+ return output, design_ready, selected_classes, prob_of_class
470
 
 
471
 
472
+ # defaults: dict,
473
+ def create_font_settings(
474
+ key_prefix: str, get_setting_fn: Callable, settings_to_get: List[str]
475
+ ) -> Tuple[dict, dict]:
476
+ # TODO: Defaults must be set based on font type! Also,
477
+ # we probably need to allow not setting the argument so the
478
+ # plotting function can handle the defaulting?
479
+ def make_key(key):
480
+ return f"{key_prefix}_{key}"
481
 
482
+ font_settings = {
483
+ make_key("color"): lambda k, d: st.color_picker(
484
+ "Color",
485
+ key=k,
486
+ value=get_setting_fn(key=k, default=d, type_=str),
487
+ ),
488
+ make_key("bold"): lambda k, d: st.checkbox(
489
+ "Bold",
490
+ key=k,
491
+ value=get_setting_fn(key=k, default=d, type_=bool),
492
+ ),
493
+ make_key("italics"): lambda k, d: st.checkbox(
494
+ "Italics",
495
+ key=k,
496
+ value=get_setting_fn(key=k, default=d, type_=bool),
497
+ ),
498
+ make_key("size"): lambda k, d: st.number_input(
499
+ "Size",
500
+ key=k,
501
+ value=get_setting_fn(key=k, default=float(d), type_=float),
502
+ ),
503
+ make_key("nudge_x"): lambda k, d: st.number_input(
504
+ "Nudge on x-axis",
505
+ key=k,
506
+ value=get_setting_fn(key=k, default=d, type_=float),
507
  ),
508
+ make_key("nudge_y"): lambda k, d: st.number_input(
509
+ "Nudge on y-axis",
510
+ key=k,
511
+ value=get_setting_fn(key=k, default=d, type_=float),
512
  ),
513
+ make_key("alpha"): lambda k, d: st.slider(
514
+ "Transparency",
515
+ min_value=0.0,
516
+ max_value=1.0,
517
+ value=get_setting_fn(key=k, default=d, type_=float),
518
+ step=0.01,
519
+ key=k,
520
+ ),
521
+ make_key("prefix"): lambda k, d: st.text_input(
522
+ "Prefix",
523
+ key=k,
524
+ value=get_setting_fn(key=k, default=d, type_=str),
525
+ ),
526
+ make_key("suffix"): lambda k, d: st.text_input(
527
+ "Suffix",
528
+ key=k,
529
+ value=get_setting_fn(key=k, default=d, type_=str),
530
  ),
 
 
531
  }
532
+
533
+ # Filter settings
534
+ font_settings = {
535
+ k: v
536
+ for k, v in font_settings.items()
537
+ if f"{k[len(key_prefix)+1:]}" in settings_to_get
538
+ }
539
+
540
+ return font_settings