Ludvig commited on
Commit
9c2558f
·
1 Parent(s): 6c86e78

Adds zoom and greyscale toggle to result viewer

Browse files
Files changed (6) hide show
  1. README.md +0 -3
  2. app.py +47 -14
  3. components.py +25 -0
  4. data.py +40 -9
  5. design.py +20 -44
  6. utils.py +27 -0
README.md CHANGED
@@ -14,11 +14,8 @@ Streamlit application for plotting a confusion matrix.
14
 
15
 
16
  ## TODOs
17
- - Add option to preview plot in black and white (for printed papers)
18
  - ggsave only uses DPI for scaling? We would expect output files to have the given DPI?
19
  - Allow svg, pdf?
20
- - Add full reset button (empty cache on different files) - callback?
21
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
22
  - Add option to format total-count tile in sum tiles
23
- - Selectable templates (for 2,3,4,5 classes - one selects num classes and pick a color scheme and other common defaults)
24
  - Allow handling tick text - e.g. for long class names or many classes.
 
14
 
15
 
16
  ## TODOs
 
17
  - ggsave only uses DPI for scaling? We would expect output files to have the given DPI?
18
  - Allow svg, pdf?
 
19
  - Add option to change zero-tile background (e.g. to black for black backgrounds)
20
  - Add option to format total-count tile in sum tiles
 
21
  - Allow handling tick text - e.g. for long class names or many classes.
app.py CHANGED
@@ -11,7 +11,12 @@ import pandas as pd
11
  from pandas.api.types import is_float_dtype
12
  from itertools import combinations
13
 
14
- from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
 
 
 
 
 
15
  from data import read_data, read_data_cached, DownloadHeader, generate_data
16
  from design import design_section
17
  from text_sections import (
@@ -414,11 +419,38 @@ if st.session_state["step"] >= 2:
414
  encoding="UTF-8",
415
  )
416
 
417
- DownloadHeader.header_and_image_download(
418
- "", filepath=conf_mat_path, label="Download plot"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  )
420
 
421
- col1, col2, col3 = st.columns([2, 8, 2])
422
  with col2:
423
  st.write(" ")
424
  st.write(" ")
@@ -431,16 +463,17 @@ if st.session_state["step"] >= 2:
431
  output_format="auto",
432
  )
433
 
434
- # Convert the image to grayscale
435
- st.write(" ")
436
- image = image.convert("CMYK").convert("L")
437
- st.image(
438
- image,
439
- caption="Greyscale version for assessing colors in print",
440
- clamp=False,
441
- channels="RGB",
442
- output_format="auto",
443
- )
 
444
  st.write(" ")
445
  st.write("Note: The downloadable file has a transparent background.")
446
 
 
11
  from pandas.api.types import is_float_dtype
12
  from itertools import combinations
13
 
14
+ from utils import (
15
+ call_subprocess,
16
+ clean_string_for_non_alphanumerics,
17
+ clean_str_column,
18
+ min_max_scale_list,
19
+ )
20
  from data import read_data, read_data_cached, DownloadHeader, generate_data
21
  from design import design_section
22
  from text_sections import (
 
419
  encoding="UTF-8",
420
  )
421
 
422
+ (
423
+ image_col_size,
424
+ st.session_state["show_greyscale"],
425
+ ) = DownloadHeader.slider_and_image_download(
426
+ filepath=conf_mat_path,
427
+ download_label="Download plot",
428
+ slider_label="Zoom",
429
+ toggle_label="Show greyscale",
430
+ toggle_value=True,
431
+ toggle_cols=[10, 1],
432
+ slider_help="Zoom in/out to better match the size you expect to have in a paper etc. "
433
+ "This affects the font sizes and will likely lead to adjustments of `height` and `width`.",
434
+ )
435
+ st.session_state["image_col_size"] = (
436
+ min_max_scale_list(
437
+ x=[image_col_size],
438
+ new_min=2.0,
439
+ new_max=8.0,
440
+ old_min=0.0,
441
+ old_max=1.0,
442
+ )[0]
443
+ if image_col_size <= 1
444
+ else min_max_scale_list(
445
+ x=[image_col_size],
446
+ new_min=8.0,
447
+ new_max=23.0,
448
+ old_min=1.0,
449
+ old_max=2.0,
450
+ )[0]
451
  )
452
 
453
+ col1, col2, col3 = st.columns([2, st.session_state["image_col_size"], 2])
454
  with col2:
455
  st.write(" ")
456
  st.write(" ")
 
463
  output_format="auto",
464
  )
465
 
466
+ if st.session_state["show_greyscale"]:
467
+ # Convert the image to grayscale
468
+ st.write(" ")
469
+ image = image.convert("CMYK").convert("L")
470
+ st.image(
471
+ image,
472
+ caption="Greyscale version for assessing colors in print",
473
+ clamp=False,
474
+ channels="RGB",
475
+ output_format="auto",
476
+ )
477
  st.write(" ")
478
  st.write("Note: The downloadable file has a transparent background.")
479
 
components.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_toggle import st_toggle_switch
3
+
4
+
5
+ def add_toggle_vertical(label, key, default, cols=[2, 5]):
6
+ st.markdown(f"<p style='font-size:0.85em;'>{label}</p>", unsafe_allow_html=True)
7
+ col1, _ = st.columns(cols)
8
+ with col1:
9
+ return st_toggle_switch(
10
+ " ",
11
+ default_value=default,
12
+ key=key,
13
+ label_after=True,
14
+ inactive_color="#eb5a53",
15
+ )
16
+
17
+
18
+ def add_toggle_horizontal(label, key, default):
19
+ return st_toggle_switch(
20
+ label=label,
21
+ default_value=default,
22
+ key=key,
23
+ label_after=True,
24
+ inactive_color="#eb5a53",
25
+ )
data.py CHANGED
@@ -4,6 +4,8 @@ import pandas as pd
4
  import streamlit as st
5
  from utils import call_subprocess
6
 
 
 
7
 
8
  def read_data(data):
9
  if data is not None:
@@ -33,23 +35,52 @@ class DownloadHeader:
33
  """
34
 
35
  @staticmethod
36
- def header_and_image_download(
37
- header, filepath, key=None, label="Download", help="Download plot"
38
- ):
39
- col1, col2 = st.columns([11, 3])
40
- with col1:
41
- st.subheader(header)
 
 
 
 
 
 
 
 
 
 
42
  with col2:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  st.write("")
44
  with open(filepath, "rb") as img:
45
  st.download_button(
46
- label=label,
47
  data=img,
48
  file_name=pathlib.Path(filepath).name,
49
  mime="image/png",
50
- key=key,
51
- help=help,
52
  )
 
53
 
54
  @staticmethod
55
  def _convert_df_to_csv(data, **kwargs):
 
4
  import streamlit as st
5
  from utils import call_subprocess
6
 
7
+ from components import add_toggle_vertical
8
+
9
 
10
  def read_data(data):
11
  if data is not None:
 
35
  """
36
 
37
  @staticmethod
38
+ def slider_and_image_download(
39
+ filepath,
40
+ slider_label,
41
+ toggle_label,
42
+ download_label="Download",
43
+ slider_min=0.0,
44
+ slider_max=2.0,
45
+ slider_value=1.0,
46
+ slider_step=0.1,
47
+ slider_help=None,
48
+ toggle_value=False,
49
+ toggle_cols=[2, 5],
50
+ download_help="Download plot",
51
+ key=None,
52
+ ) -> int:
53
+ col1, col2, col3, col4 = st.columns([2, 6, 3, 3])
54
  with col2:
55
+ # Image viewing size slider
56
+ image_col_size = st.slider(
57
+ slider_label,
58
+ min_value=slider_min,
59
+ max_value=slider_max,
60
+ value=slider_value,
61
+ step=slider_step,
62
+ help=slider_help,
63
+ key=key + "_slider" if key is not None else key,
64
+ )
65
+ with col3:
66
+ toggle_state = add_toggle_vertical(
67
+ label=toggle_label,
68
+ key=key + "_toggle" if key is not None else key,
69
+ default=toggle_value,
70
+ cols=toggle_cols,
71
+ )
72
+ with col4:
73
  st.write("")
74
  with open(filepath, "rb") as img:
75
  st.download_button(
76
+ label=download_label,
77
  data=img,
78
  file_name=pathlib.Path(filepath).name,
79
  mime="image/png",
80
+ key=key + "_download" if key is not None else key,
81
+ help=download_help,
82
  )
83
+ return image_col_size, toggle_state
84
 
85
  @staticmethod
86
  def _convert_df_to_csv(data, **kwargs):
design.py CHANGED
@@ -2,38 +2,14 @@ from typing import List, Callable, Any, Tuple
2
  import json
3
  import streamlit as st
4
  from PIL import Image
5
- from streamlit_toggle import st_toggle_switch
6
-
7
 
 
8
  from text_sections import (
9
  design_text,
10
  )
11
  from templates import get_templates
12
 
13
 
14
- def _add_toggle_vertical(label, key, default, cols=[2, 5]):
15
- st.markdown(f"<p style='font-size:0.85em;'>{label}</p>", unsafe_allow_html=True)
16
- col1, _ = st.columns(cols)
17
- with col1:
18
- return st_toggle_switch(
19
- " ",
20
- default_value=default,
21
- key=key,
22
- label_after=True,
23
- inactive_color="#eb5a53",
24
- )
25
-
26
-
27
- def _add_toggle_horizontal(label, key, default):
28
- return st_toggle_switch(
29
- label=label,
30
- default_value=default,
31
- key=key,
32
- label_after=True,
33
- inactive_color="#eb5a53",
34
- )
35
-
36
-
37
  def _add_select_box(
38
  key: str,
39
  label: str,
@@ -91,7 +67,7 @@ def select_settings():
91
  options=[-1] + templates_available_num_classes,
92
  )
93
  with col2:
94
- has_sums = _add_toggle_vertical(
95
  label="With sum tiles",
96
  key="filter_sum_tiles",
97
  default=False,
@@ -199,7 +175,7 @@ def design_section(
199
  # Reverse by default
200
  selected_classes.reverse()
201
  with col2:
202
- reverse_class_order = _add_toggle_vertical(
203
  label="Reverse class order",
204
  key="reverse_order",
205
  default=False,
@@ -230,7 +206,7 @@ def design_section(
230
  with col2:
231
  st.session_state["selected_design_settings"][
232
  "palette_use_custom"
233
- ] = _add_toggle_vertical(
234
  label="Use custom gradient", key="custom_gradient", default=False
235
  )
236
  with col3:
@@ -278,7 +254,7 @@ def design_section(
278
  with col1:
279
  st.session_state["selected_design_settings"][
280
  "show_counts"
281
- ] = _add_toggle_vertical(
282
  label="Show counts",
283
  key="show_counts",
284
  default=get_uploaded_setting(
@@ -289,7 +265,7 @@ def design_section(
289
  with col2:
290
  st.session_state["selected_design_settings"][
291
  "show_normalized"
292
- ] = _add_toggle_vertical(
293
  label="Show normalized (%)",
294
  key="show_normalized",
295
  default=get_uploaded_setting(
@@ -300,7 +276,7 @@ def design_section(
300
  with col3:
301
  st.session_state["selected_design_settings"][
302
  "show_sums"
303
- ] = _add_toggle_vertical(
304
  label="Show sum tiles",
305
  key="show_sum_tiles",
306
  default=get_uploaded_setting(
@@ -366,7 +342,7 @@ def design_section(
366
  with col1:
367
  st.session_state["selected_design_settings"][
368
  "rotate_y_text"
369
- ] = _add_toggle_horizontal(
370
  label="Rotate y-axis text",
371
  key="rotate_y_text",
372
  default=get_uploaded_setting(
@@ -376,7 +352,7 @@ def design_section(
376
 
377
  st.session_state["selected_design_settings"][
378
  "place_x_axis_above"
379
- ] = _add_toggle_horizontal(
380
  label="Place x-axis on top",
381
  default=get_uploaded_setting(
382
  key="place_x_axis_above", default=True, type_=bool
@@ -385,7 +361,7 @@ def design_section(
385
  )
386
  st.session_state["selected_design_settings"][
387
  "counts_on_top"
388
- ] = _add_toggle_horizontal(
389
  label="Counts on top",
390
  default=get_uploaded_setting(
391
  key="counts_on_top", default=False, type_=bool
@@ -415,7 +391,7 @@ def design_section(
415
  st.write("Row and column percentages:")
416
  st.session_state["selected_design_settings"][
417
  "show_row_percentages"
418
- ] = _add_toggle_horizontal(
419
  label="Show row percentages",
420
  default=get_uploaded_setting(
421
  key="show_row_percentages",
@@ -426,7 +402,7 @@ def design_section(
426
  )
427
  st.session_state["selected_design_settings"][
428
  "show_col_percentages"
429
- ] = _add_toggle_horizontal(
430
  label="Show column percentages",
431
  default=get_uploaded_setting(
432
  key="show_col_percentages",
@@ -437,7 +413,7 @@ def design_section(
437
  )
438
  st.session_state["selected_design_settings"][
439
  "show_arrows"
440
- ] = _add_toggle_horizontal(
441
  label="Show arrows",
442
  default=get_uploaded_setting(
443
  key="show_arrows", default=True, type_=bool
@@ -446,7 +422,7 @@ def design_section(
446
  )
447
  st.session_state["selected_design_settings"][
448
  "diag_percentages_only"
449
- ] = _add_toggle_horizontal(
450
  label="Diagonal percentages only",
451
  default=get_uploaded_setting(
452
  key="diag_percentages_only",
@@ -525,7 +501,7 @@ def design_section(
525
 
526
  st.session_state["selected_design_settings"][
527
  "show_tile_border"
528
- ] = _add_toggle_horizontal(
529
  label="Add tile borders",
530
  default=get_uploaded_setting(
531
  key="show_tile_border", default=False, type_=bool
@@ -628,7 +604,7 @@ def design_section(
628
  with col1:
629
  st.session_state["selected_design_settings"][
630
  "show_zero_shading"
631
- ] = _add_toggle_vertical(
632
  label="Add shading",
633
  default=get_uploaded_setting(
634
  key="show_zero_shading", default=True, type_=bool
@@ -638,7 +614,7 @@ def design_section(
638
  with col2:
639
  st.session_state["selected_design_settings"][
640
  "show_zero_text"
641
- ] = _add_toggle_vertical(
642
  label="Show main text",
643
  default=get_uploaded_setting(
644
  key="show_zero_text", default=False, type_=bool
@@ -648,7 +624,7 @@ def design_section(
648
  with col3:
649
  st.session_state["selected_design_settings"][
650
  "show_zero_percentages"
651
- ] = _add_toggle_vertical(
652
  label="Show row/column percentages",
653
  default=get_uploaded_setting(
654
  key="show_zero_percentages",
@@ -794,12 +770,12 @@ def create_font_settings(
794
  key=k,
795
  value=get_setting_fn(key=k, default=d, type_=str),
796
  ),
797
- make_key("bold"): lambda k, d: _add_toggle_horizontal(
798
  label="Bold",
799
  key=k,
800
  default=get_setting_fn(key=k, default=d, type_=bool),
801
  ),
802
- make_key("italic"): lambda k, d: _add_toggle_horizontal(
803
  label="Italic",
804
  key=k,
805
  default=get_setting_fn(key=k, default=d, type_=bool),
 
2
  import json
3
  import streamlit as st
4
  from PIL import Image
 
 
5
 
6
+ from components import add_toggle_horizontal, add_toggle_vertical
7
  from text_sections import (
8
  design_text,
9
  )
10
  from templates import get_templates
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def _add_select_box(
14
  key: str,
15
  label: str,
 
67
  options=[-1] + templates_available_num_classes,
68
  )
69
  with col2:
70
+ has_sums = add_toggle_vertical(
71
  label="With sum tiles",
72
  key="filter_sum_tiles",
73
  default=False,
 
175
  # Reverse by default
176
  selected_classes.reverse()
177
  with col2:
178
+ reverse_class_order = add_toggle_vertical(
179
  label="Reverse class order",
180
  key="reverse_order",
181
  default=False,
 
206
  with col2:
207
  st.session_state["selected_design_settings"][
208
  "palette_use_custom"
209
+ ] = add_toggle_vertical(
210
  label="Use custom gradient", key="custom_gradient", default=False
211
  )
212
  with col3:
 
254
  with col1:
255
  st.session_state["selected_design_settings"][
256
  "show_counts"
257
+ ] = add_toggle_vertical(
258
  label="Show counts",
259
  key="show_counts",
260
  default=get_uploaded_setting(
 
265
  with col2:
266
  st.session_state["selected_design_settings"][
267
  "show_normalized"
268
+ ] = add_toggle_vertical(
269
  label="Show normalized (%)",
270
  key="show_normalized",
271
  default=get_uploaded_setting(
 
276
  with col3:
277
  st.session_state["selected_design_settings"][
278
  "show_sums"
279
+ ] = add_toggle_vertical(
280
  label="Show sum tiles",
281
  key="show_sum_tiles",
282
  default=get_uploaded_setting(
 
342
  with col1:
343
  st.session_state["selected_design_settings"][
344
  "rotate_y_text"
345
+ ] = add_toggle_horizontal(
346
  label="Rotate y-axis text",
347
  key="rotate_y_text",
348
  default=get_uploaded_setting(
 
352
 
353
  st.session_state["selected_design_settings"][
354
  "place_x_axis_above"
355
+ ] = add_toggle_horizontal(
356
  label="Place x-axis on top",
357
  default=get_uploaded_setting(
358
  key="place_x_axis_above", default=True, type_=bool
 
361
  )
362
  st.session_state["selected_design_settings"][
363
  "counts_on_top"
364
+ ] = add_toggle_horizontal(
365
  label="Counts on top",
366
  default=get_uploaded_setting(
367
  key="counts_on_top", default=False, type_=bool
 
391
  st.write("Row and column percentages:")
392
  st.session_state["selected_design_settings"][
393
  "show_row_percentages"
394
+ ] = add_toggle_horizontal(
395
  label="Show row percentages",
396
  default=get_uploaded_setting(
397
  key="show_row_percentages",
 
402
  )
403
  st.session_state["selected_design_settings"][
404
  "show_col_percentages"
405
+ ] = add_toggle_horizontal(
406
  label="Show column percentages",
407
  default=get_uploaded_setting(
408
  key="show_col_percentages",
 
413
  )
414
  st.session_state["selected_design_settings"][
415
  "show_arrows"
416
+ ] = add_toggle_horizontal(
417
  label="Show arrows",
418
  default=get_uploaded_setting(
419
  key="show_arrows", default=True, type_=bool
 
422
  )
423
  st.session_state["selected_design_settings"][
424
  "diag_percentages_only"
425
+ ] = add_toggle_horizontal(
426
  label="Diagonal percentages only",
427
  default=get_uploaded_setting(
428
  key="diag_percentages_only",
 
501
 
502
  st.session_state["selected_design_settings"][
503
  "show_tile_border"
504
+ ] = add_toggle_horizontal(
505
  label="Add tile borders",
506
  default=get_uploaded_setting(
507
  key="show_tile_border", default=False, type_=bool
 
604
  with col1:
605
  st.session_state["selected_design_settings"][
606
  "show_zero_shading"
607
+ ] = add_toggle_vertical(
608
  label="Add shading",
609
  default=get_uploaded_setting(
610
  key="show_zero_shading", default=True, type_=bool
 
614
  with col2:
615
  st.session_state["selected_design_settings"][
616
  "show_zero_text"
617
+ ] = add_toggle_vertical(
618
  label="Show main text",
619
  default=get_uploaded_setting(
620
  key="show_zero_text", default=False, type_=bool
 
624
  with col3:
625
  st.session_state["selected_design_settings"][
626
  "show_zero_percentages"
627
+ ] = add_toggle_vertical(
628
  label="Show row/column percentages",
629
  default=get_uploaded_setting(
630
  key="show_zero_percentages",
 
770
  key=k,
771
  value=get_setting_fn(key=k, default=d, type_=str),
772
  ),
773
+ make_key("bold"): lambda k, d: add_toggle_horizontal(
774
  label="Bold",
775
  key=k,
776
  default=get_setting_fn(key=k, default=d, type_=bool),
777
  ),
778
+ make_key("italic"): lambda k, d: add_toggle_horizontal(
779
  label="Italic",
780
  key=k,
781
  default=get_setting_fn(key=k, default=d, type_=bool),
utils.py CHANGED
@@ -2,6 +2,7 @@ import subprocess
2
  import re
3
  import streamlit as st
4
  import json
 
5
 
6
 
7
  def show_error(msg, action):
@@ -63,3 +64,29 @@ def clean_string_for_non_alphanumerics(s):
63
 
64
  def clean_str_column(x):
65
  return x.astype(str).apply(lambda x: clean_string_for_non_alphanumerics(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import re
3
  import streamlit as st
4
  import json
5
+ from typing import Optional
6
 
7
 
8
  def show_error(msg, action):
 
64
 
65
  def clean_str_column(x):
66
  return x.astype(str).apply(lambda x: clean_string_for_non_alphanumerics(x))
67
+
68
+
69
+ def min_max_scale_list(
70
+ x: list,
71
+ new_min: float,
72
+ new_max: float,
73
+ old_min: Optional[float] = None,
74
+ old_max: Optional[float] = None,
75
+ ) -> list:
76
+ """
77
+ MinMax scaler for lists.
78
+ Why: Currently we don't require numpy as dependency.
79
+ """
80
+ if old_min is None:
81
+ old_min = min(x)
82
+ if old_max is None:
83
+ old_max = max(x)
84
+
85
+ diff = old_max - old_min
86
+
87
+ # Avoiding zero-division
88
+ if diff == 0:
89
+ diff = 1
90
+
91
+ x = [(xi - old_min) / diff for xi in x]
92
+ return [xi * (new_max - new_min) + new_min for xi in x]