Ludvig
Adds log-scaled intensities. Disables DEV mode
9856d69
raw
history blame
22.4 kB
from typing import List, Callable, Any, Tuple
import json
import streamlit as st
from text_sections import (
design_text,
)
def _add_select_box(
    key: str,
    label: str,
    default: Any,
    options: List[Any],
    get_setting_fn: Callable,
    type_=str,
):
    """
    Render a Streamlit selectbox whose preselected entry comes from a lookup.

    ``get_setting_fn`` is called with the keyword arguments ``key``,
    ``default``, ``type_`` and ``options``; the value it returns must be one
    of ``options`` (otherwise ``options.index`` raises ``ValueError``). That
    entry is preselected, the selectbox is registered under ``key``, and the
    current selection is returned.
    """
    preselected = get_setting_fn(
        key=key,
        default=default,
        type_=type_,
        options=options,
    )
    preselected_index = options.index(preselected)
    return st.selectbox(
        label,
        options=options,
        index=preselected_index,
        key=key,
    )
def _read_uploaded_settings(uploaded_file) -> dict:
    """
    Parse an uploaded design-settings JSON file.

    Returns the parsed settings when the file holds a JSON object. Otherwise a
    Streamlit error is shown and an empty dict is returned, so a malformed
    upload falls back to the default settings instead of crashing the app.
    """
    try:
        settings = json.load(uploaded_file)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        st.error("The uploaded settings file is not valid JSON. Using default values.")
        return {}
    if not isinstance(settings, dict):
        st.error(
            "The uploaded settings file must contain a JSON object "
            "(key-value pairs). Using default values."
        )
        return {}
    return settings


def _lookup_setting(settings: dict, key: str, default: Any, type_=None, options=None):
    """
    Return ``settings[key]`` when it is present and valid, else ``default``.

    A present value is rejected (with a Streamlit warning) when it does not
    match ``type_`` or is not one of ``options``. Because JSON does not
    distinguish ``1`` from ``1.0``, integers are accepted (and converted)
    where a float is expected. Booleans are only accepted for ``type_=bool``,
    even though ``bool`` is a subclass of ``int``.
    """
    if key not in settings:
        return default
    value = settings[key]
    if type_ is not None:
        if type_ is float and isinstance(value, int) and not isinstance(value, bool):
            value = float(value)
        bool_for_non_bool = isinstance(value, bool) and type_ is not bool
        if bool_for_non_bool or not isinstance(value, type_):
            st.warning(
                f"An uploaded setting ({key}) had the wrong type. Using default value."
            )
            return default
    if options is not None and value not in options:
        st.warning(
            f"An uploaded setting ({key}) was not a valid choice. Using default value."
        )
        return default
    return value


def _clamp(value, min_value, max_value):
    """Limit ``value`` to the closed range [``min_value``, ``max_value``]."""
    return min(max(value, min_value), max_value)


def design_section(
    num_classes,
    design_settings_store_path,
):
    """
    Render the plot-design settings UI and collect the chosen settings.

    Two Streamlit forms are rendered: one for uploading a previously saved
    design-settings JSON file (its values seed the widget defaults) and one
    containing all design widgets. Pressing "Generate plot" sets
    ``st.session_state["step"]`` to 3 and, unless the tile and sum-tile
    palettes clash, writes ``output`` as JSON to ``design_settings_store_path``.

    ``st.session_state`` must already contain ``"classes"`` and ``"step"``
    (presumably initialized by the caller — TODO confirm).

    Parameters
    ----------
    num_classes : int
        Number of classes; scales the default plot width/height and decides
        whether row/column percentages are shown by default (< 6 classes).
    design_settings_store_path : str or path-like
        File the settings are written to when the plot is generated.

    Returns
    -------
    output : dict
        Setting name -> chosen value (this dict is what gets saved as JSON).
    design_ready : bool
        True when ``st.session_state["step"] >= 3`` and no validation error
        (clashing palettes, fewer than 2 classes) was found.
    selected_classes : list
        The selected classes, in the order to use for the plot.
    """
    output = {}

    # --- Optional upload of previously saved settings ---
    # Initialized up front so lookups work whether or not a file was applied
    # on this run (no need to inspect `locals()`).
    uploaded_design_settings = {}
    with st.form(key="settings_upload_form"):
        design_text()
        uploaded_settings_path = st.file_uploader(
            "Upload design settings", type=["json"]
        )
        # TODO: Allow resetting settings!
        if st.form_submit_button(label="Apply settings"):
            if uploaded_settings_path is not None:
                uploaded_design_settings = _read_uploaded_settings(
                    uploaded_settings_path
                )
            else:
                st.warning("No settings were uploaded. Uploading settings is optional.")

    # NOTE(review): the uploaded settings only exist on the run where
    # "Apply settings" was pressed; on later reruns the `value=` defaults
    # passed to the widgets revert to the built-in defaults, which may reset
    # the widgets. Consider persisting them in `st.session_state` — confirm.

    def get_uploaded_setting(key, default, type_=None, options=None):
        # Lookup function handed to the widget helpers; see `_lookup_setting`.
        return _lookup_setting(
            uploaded_design_settings,
            key=key,
            default=default,
            type_=type_,
            options=options,
        )

    default_size = 1200 + 100 * (num_classes - 2)

    with st.form(key="settings_form"):
        col1, col2, col3 = st.columns([4, 2, 2])
        with col1:
            selected_classes = st.multiselect(
                "Select classes (min=2, order is respected)",
                options=st.session_state["classes"],
                default=st.session_state["classes"],
                help="Select the classes to create the confusion matrix for. "
                "Any observation with either a target or prediction "
                "of another class is excluded.",
            )
            # Reverse by default (further flips are applied after the form)
            selected_classes.reverse()
        with col2:
            # Spacers, so the checkbox lines up with the multiselect input
            st.write(" ")
            st.write(" ")
            reverse_class_order = st.checkbox(
                "Reverse order", value=False, help="Reverse the order of the classes."
            )
        with col3:
            # Color palette
            output["palette"] = _add_select_box(
                key="palette",
                label="Color Palette",
                default="Blues",
                options=["Blues", "Greens", "Oranges", "Greys", "Purples", "Reds"],
                get_setting_fn=get_uploaded_setting,
                type_=str,
            )

        # Ask for output parameters
        col1, col2, col3 = st.columns(3)
        with col1:
            output["width"] = st.number_input(
                "Width (px)",
                value=get_uploaded_setting(
                    key="width", default=default_size, type_=int
                ),
                step=50,
            )
        with col2:
            output["height"] = st.number_input(
                "Height (px)",
                value=get_uploaded_setting(
                    key="height", default=default_size, type_=int
                ),
                step=50,
            )
        with col3:
            output["dpi"] = st.number_input(
                "DPI (scaling)",
                value=get_uploaded_setting(key="dpi", default=320, type_=int),
                step=10,
                help="While the output file *currently* won't have this DPI, "
                "the DPI setting affects scaling of elements. ",
            )

        st.write(" ")  # Slightly bigger gap between the two sections

        col1, col2, col3 = st.columns(3)
        with col1:
            output["show_counts"] = st.checkbox(
                "Show Counts",
                value=get_uploaded_setting(key="show_counts", default=True, type_=bool),
            )
        with col2:
            output["show_normalized"] = st.checkbox(
                "Show Normalized (%)",
                value=get_uploaded_setting(
                    key="show_normalized", default=True, type_=bool
                ),
            )
        with col3:
            output["show_sums"] = st.checkbox(
                "Show Sum Tiles",
                value=get_uploaded_setting(key="show_sums", default=False, type_=bool),
                help="Show extra row and column with the "
                "totals for that row/column.",
            )

        st.markdown("""---""")
        st.markdown("**Advanced**:")

        with st.expander("Labels"):
            col1, col2 = st.columns(2)
            with col1:
                output["x_label"] = st.text_input(
                    "x-axis",
                    value=get_uploaded_setting(
                        key="x_label", default="True Class", type_=str
                    ),
                )
            with col2:
                output["y_label"] = st.text_input(
                    "y-axis",
                    value=get_uploaded_setting(
                        key="y_label", default="Predicted Class", type_=str
                    ),
                )

            st.markdown("---")
            col1, col2 = st.columns(2)
            with col1:
                output["title_label"] = st.text_input(
                    "Title",
                    value=get_uploaded_setting(
                        key="title_label", default="", type_=str
                    ),
                )
            with col2:
                output["caption_label"] = st.text_input(
                    "Caption",
                    value=get_uploaded_setting(
                        key="caption_label", default="", type_=str
                    ),
                )
            st.info(
                "Note: When adding a title or caption, "
                "you may need to adjust the height and "
                "width of the plot as well."
            )

        with st.expander("Elements"):
            col1, col2 = st.columns(2)
            with col1:
                output["rotate_y_text"] = st.checkbox(
                    "Rotate y-axis text",
                    value=get_uploaded_setting(
                        key="rotate_y_text", default=True, type_=bool
                    ),
                )
                output["place_x_axis_above"] = st.checkbox(
                    "Place x-axis on top",
                    value=get_uploaded_setting(
                        key="place_x_axis_above", default=True, type_=bool
                    ),
                )
                output["counts_on_top"] = st.checkbox(
                    "Counts on top",
                    value=get_uploaded_setting(
                        key="counts_on_top", default=False, type_=bool
                    ),
                    help="Whether to switch the positions of the counts and normalized counts (%). "
                    "The counts become the big centralized numbers and the "
                    "normalized counts go below with a smaller font size.",
                )
            with col2:
                output["num_digits"] = st.number_input(
                    "Digits",
                    value=get_uploaded_setting(key="num_digits", default=2, type_=int),
                    help="Number of digits to round percentages to.",
                )

            st.markdown("""---""")
            col1, col2 = st.columns(2)
            with col1:
                st.write("Row and column percentages:")
                output["show_row_percentages"] = st.checkbox(
                    "Show row percentages",
                    value=get_uploaded_setting(
                        key="show_row_percentages", default=num_classes < 6, type_=bool
                    ),
                )
                output["show_col_percentages"] = st.checkbox(
                    "Show column percentages",
                    value=get_uploaded_setting(
                        key="show_col_percentages", default=num_classes < 6, type_=bool
                    ),
                )
                output["show_arrows"] = st.checkbox(
                    "Show arrows",
                    value=get_uploaded_setting(
                        key="show_arrows", default=True, type_=bool
                    ),
                )
                output["diag_percentages_only"] = st.checkbox(
                    "Diagonal row/column percentages only",
                    value=get_uploaded_setting(
                        key="diag_percentages_only", default=False, type_=bool
                    ),
                )
            with col2:
                # These sliders operate on a x10 scale, while `output` (and so
                # the saved JSON) stores the unscaled value. Uploaded values are
                # therefore scaled up by 10 before seeding the slider, and
                # clamped so an out-of-range value cannot break the widget.
                arrow_size = get_uploaded_setting(
                    key="arrow_size", default=0.048, type_=float
                )
                output["arrow_size"] = (
                    st.slider(
                        "Arrow size",
                        value=_clamp(arrow_size * 10, 0.03 * 10, 0.06 * 10),
                        min_value=0.03 * 10,
                        max_value=0.06 * 10,
                        step=0.001 * 10,
                    )
                    / 10
                )
                arrow_nudge = get_uploaded_setting(
                    key="arrow_nudge_from_text", default=0.065, type_=float
                )
                output["arrow_nudge_from_text"] = (
                    st.slider(
                        "Arrow nudge from text",
                        value=_clamp(arrow_nudge * 10, 0.00, 0.1 * 10),
                        min_value=0.00,
                        max_value=0.1 * 10,
                        step=0.001 * 10,
                    )
                    / 10
                )

        with st.expander("Tiles"):
            col1, col2 = st.columns(2)
            with col1:
                output["intensity_by"] = _add_select_box(
                    key="intensity_by",
                    label="Intensity based on",
                    default="Counts",
                    options=[
                        "Counts",
                        "Normalized (%)",
                        "Log Counts",
                        "Log2 Counts",
                        "Log10 Counts",
                        "Arcsinh Counts",
                    ],
                    get_setting_fn=get_uploaded_setting,
                    type_=str,
                )
            with col2:
                output["darkness"] = st.slider(
                    "Darkness",
                    min_value=0.0,
                    max_value=1.0,
                    value=_clamp(
                        get_uploaded_setting(key="darkness", default=0.8, type_=float),
                        0.0,
                        1.0,
                    ),
                    step=0.01,
                    help="How dark the darkest colors should be, between 0 and 1, where 1 is darkest.",
                )

            st.markdown("""---""")
            output["show_tile_border"] = st.checkbox(
                "Add tile borders",
                value=get_uploaded_setting(
                    key="show_tile_border", default=False, type_=bool
                ),
            )
            col1, col2, col3 = st.columns(3)
            with col1:
                output["tile_border_color"] = st.color_picker(
                    "Border color",
                    value=get_uploaded_setting(
                        key="tile_border_color", default="#000000", type_=str
                    ),
                )
            with col2:
                output["tile_border_size"] = st.slider(
                    "Border size",
                    min_value=0.0,
                    max_value=3.0,
                    value=_clamp(
                        get_uploaded_setting(
                            key="tile_border_size", default=0.1, type_=float
                        ),
                        0.0,
                        3.0,
                    ),
                    step=0.01,
                )
            with col3:
                output["tile_border_linetype"] = _add_select_box(
                    key="tile_border_linetype",
                    label="Border linetype",
                    default="solid",
                    options=[
                        "solid",
                        "dashed",
                        "dotted",
                        "dotdash",
                        "longdash",
                        "twodash",
                    ],
                    get_setting_fn=get_uploaded_setting,
                    type_=str,
                )

            st.markdown("""---""")
            st.write("Sum tile settings:")
            col1, col2 = st.columns(2)
            with col1:
                output["sum_tile_palette"] = _add_select_box(
                    key="sum_tile_palette",
                    label="Color Palette",
                    default="Greens",
                    options=["Greens", "Oranges", "Greys", "Purples", "Reds", "Blues"],
                    get_setting_fn=get_uploaded_setting,
                    type_=str,
                )
            with col2:
                output["sum_tile_label"] = st.text_input(
                    "Label",
                    value=get_uploaded_setting(
                        key="sum_tile_label", default="Σ", type_=str
                    ),
                    key="sum_tiles_label",
                )
            # TODO: Sum tile options not exposed yet: tile_fill, font_color,
            # tile_border_color, tile_border_size, tile_border_linetype,
            # tc_tile_fill, tc_font_color, tc_tile_border_color,
            # tc_tile_border_size, tc_tile_border_linetype.

        with st.expander("Zero Counts"):
            st.write("Special settings for tiles where the count is 0:")
            col1, col2, col3 = st.columns(3)
            with col1:
                output["show_zero_shading"] = st.checkbox(
                    "Add shading",
                    value=get_uploaded_setting(
                        key="show_zero_shading", default=True, type_=bool
                    ),
                )
            with col2:
                output["show_zero_text"] = st.checkbox(
                    "Show text",
                    value=get_uploaded_setting(
                        key="show_zero_text", default=False, type_=bool
                    ),
                    help="Whether to show counts, normalized (%), etc.",
                )
            with col3:
                output["show_zero_percentages"] = st.checkbox(
                    "Show row/column percentages",
                    value=get_uploaded_setting(
                        key="show_zero_percentages", default=False, type_=bool
                    ),
                    help="Only relevant when row/column percentages are enabled.",
                )

        with st.expander("Fonts"):
            # Specify available settings and defaults per font
            font_types = {
                "Top Font": {
                    "key_prefix": "font_top",
                    "description": "The big text in the middle (normalized (%) by default).",
                    "settings": {
                        "size": 4.3,  # 2.8
                        "color": "#000000",
                        "alpha": 1.0,
                        "bold": False,
                        "italic": False,
                    },
                },
                "Bottom Font": {
                    "key_prefix": "font_bottom",
                    "description": "The text just below the top font (counts by default).",
                    "settings": {
                        "size": 2.8,
                        "color": "#000000",
                        "alpha": 1.0,
                        "bold": False,
                        "italic": False,
                    },
                },
                "Percentages Font": {
                    "key_prefix": "font_percentage",
                    "description": "The row and column percentages.",
                    "settings": {
                        "size": 2.35,
                        "color": "#000000",
                        "alpha": 0.85,
                        "bold": False,
                        "italic": True,
                        "suffix": "%",
                        "prefix": "",
                    },
                },
                "Normalized (%)": {
                    "key_prefix": "font_normalized",
                    "description": "Special settings for the normalized (%) text.",
                    "settings": {"suffix": "%", "prefix": ""},
                },
                "Counts": {
                    "key_prefix": "font_counts",
                    "description": "Special settings for the counts text.",
                    "settings": {"suffix": "", "prefix": ""},
                },
            }
            num_cols = 3
            last_font_idx = len(font_types) - 1
            for font_idx, (font_type_title, font_type_spec) in enumerate(
                font_types.items()
            ):
                st.markdown(f"**{font_type_title}**")
                st.markdown(font_type_spec["description"])
                key_prefix = font_type_spec["key_prefix"]
                font_settings = create_font_settings(
                    key_prefix=key_prefix,
                    get_setting_fn=get_uploaded_setting,
                    settings_to_get=list(font_type_spec["settings"].keys()),
                )
                for i, (setting_name, setting_widget) in enumerate(
                    font_settings.items()
                ):
                    # Start a new row of columns every `num_cols` widgets
                    if i % num_cols == 0:
                        cols = st.columns(num_cols)
                    with cols[i % num_cols]:
                        # Widget keys are "<key_prefix>_<setting>"; strip the
                        # prefix to look up this font type's default.
                        default = font_type_spec["settings"][
                            setting_name[len(key_prefix) + 1 :]
                        ]
                        output[setting_name] = setting_widget(
                            k=setting_name, d=default
                        )
                if font_idx != last_font_idx:
                    st.markdown("""---""")

        st.markdown("""---""")

        # Identical palettes would make sum tiles indistinguishable from the
        # regular tiles; such settings are neither saved nor accepted below.
        palettes_clash = (
            output["show_sums"] and output["sum_tile_palette"] == output["palette"]
        )
        if st.form_submit_button(label="Generate plot"):
            st.session_state["step"] = 3
            if not palettes_clash:
                # Save settings as json
                with open(design_settings_store_path, "w", encoding="utf-8") as f:
                    json.dump(output, f)

    # The classes were reversed by default above; placing the x-axis below the
    # plot and the "Reverse order" checkbox each flip the order once more.
    if not output["place_x_axis_above"]:
        selected_classes.reverse()
    if reverse_class_order:
        selected_classes.reverse()

    design_ready = False
    if st.session_state["step"] >= 3:
        design_ready = True
        if palettes_clash:
            st.error(
                "The color palettes (background colors) "
                "for the tiles and sum tiles are identical. "
                "Please select a different color palette for "
                "the sum tiles under **Tiles** >> *Sum tile settings*."
            )
            design_ready = False
        if len(selected_classes) < 2:
            st.error("At least 2 classes must be selected.")
            design_ready = False

    return output, design_ready, selected_classes
# defaults: dict,
def create_font_settings(
    key_prefix: str, get_setting_fn: Callable, settings_to_get: List[str]
) -> dict:
    """
    Build widget factories for the requested font settings.

    Each entry of the returned dict maps a widget key of the form
    ``"<key_prefix>_<setting>"`` (e.g. ``"font_top_size"``) to a callable
    ``widget(k, d)``. Calling it renders the matching Streamlit input under
    key ``k``, seeded with ``get_setting_fn(key=k, default=d, type_=...)``,
    and returns the widget's value. No widgets are rendered by this function
    itself.

    Only settings named in ``settings_to_get`` are included; unknown names
    are ignored. Entries keep a fixed order (color, bold, italic, size,
    nudge_x, nudge_y, alpha, prefix, suffix), which determines the layout
    when the caller renders them in sequence.

    Parameters
    ----------
    key_prefix : str
        Prefix for the widget keys, e.g. ``"font_top"``.
    get_setting_fn : Callable
        Lookup called with ``key``, ``default`` and ``type_`` keyword
        arguments; returns the initial value for a widget.
    settings_to_get : list of str
        Setting names (without prefix) to create widgets for.

    Returns
    -------
    dict
        Mapping from full widget key to widget factory ``(k, d) -> value``.
    """
    # TODO: We probably need to allow not setting an argument, so the
    # plotting function can handle the defaulting? (Per-font-type defaults
    # are already supplied by the caller through `d`.)
    wanted = set(settings_to_get)

    widget_factories = {
        "color": lambda k, d: st.color_picker(
            "Color",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=str),
        ),
        "bold": lambda k, d: st.checkbox(
            "Bold",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=bool),
        ),
        "italic": lambda k, d: st.checkbox(
            "Italic",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=bool),
        ),
        "size": lambda k, d: st.number_input(
            "Size",
            key=k,
            # Float default even if `d` is an int, matching `type_=float`
            value=get_setting_fn(key=k, default=float(d), type_=float),
        ),
        "nudge_x": lambda k, d: st.number_input(
            "Nudge on x-axis",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=float),
        ),
        "nudge_y": lambda k, d: st.number_input(
            "Nudge on y-axis",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=float),
        ),
        "alpha": lambda k, d: st.slider(
            "Transparency",
            min_value=0.0,
            max_value=1.0,
            value=get_setting_fn(key=k, default=d, type_=float),
            step=0.01,
            key=k,
        ),
        "prefix": lambda k, d: st.text_input(
            "Prefix",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=str),
        ),
        "suffix": lambda k, d: st.text_input(
            "Suffix",
            key=k,
            value=get_setting_fn(key=k, default=d, type_=str),
        ),
    }

    # Filter to the requested settings, keeping the fixed order above and
    # keying each factory by its full widget key.
    return {
        f"{key_prefix}_{name}": factory
        for name, factory in widget_factories.items()
        if name in wanted
    }