Ludvig
committed on
Commit
·
f8cfd7b
1
Parent(s):
747851f
Many improvements and fixes
Browse files
app.py
CHANGED
|
@@ -10,7 +10,6 @@ import streamlit as st # Import last
|
|
| 10 |
import pandas as pd
|
| 11 |
from pandas.api.types import is_float_dtype
|
| 12 |
from itertools import combinations
|
| 13 |
-
from collections import OrderedDict
|
| 14 |
|
| 15 |
from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
|
| 16 |
from data import read_data, read_data_cached, DownloadHeader, generate_data
|
|
@@ -157,7 +156,6 @@ elif input_choice == "Upload counts":
|
|
| 157 |
st.session_state["step"] = 2
|
| 158 |
|
| 159 |
if st.session_state["step"] >= 2:
|
| 160 |
-
print(st.session_state["count_data"])
|
| 161 |
# Ensure targets and predictions are clean strings
|
| 162 |
st.session_state["count_data"][target_col] = clean_str_column(
|
| 163 |
st.session_state["count_data"][target_col]
|
|
@@ -242,47 +240,43 @@ elif input_choice == "Enter counts":
|
|
| 242 |
st.session_state["count_data"] = pd.DataFrame(
|
| 243 |
all_pairs, columns=["Target", "Prediction"]
|
| 244 |
)
|
|
|
|
| 245 |
|
| 246 |
st.session_state["step"] = 1
|
| 247 |
|
| 248 |
if st.session_state["step"] >= 1:
|
| 249 |
with st.form(key="enter_counts_form"):
|
| 250 |
-
st.write(
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
st.
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
| 264 |
|
| 265 |
if st.form_submit_button(
|
| 266 |
label="Generate data",
|
| 267 |
):
|
| 268 |
-
st.session_state["count_data"]
|
| 269 |
-
int(val) for val in count_input_fields.values()
|
| 270 |
-
]
|
| 271 |
st.session_state["step"] = 2
|
| 272 |
|
| 273 |
if st.session_state["step"] >= 2:
|
| 274 |
DownloadHeader.header_and_data_download(
|
| 275 |
-
"
|
| 276 |
data=st.session_state["count_data"],
|
| 277 |
file_name="confusion_matrix_counts.csv",
|
|
|
|
| 278 |
help="Download counts",
|
| 279 |
-
col_sizes=[10,
|
| 280 |
)
|
| 281 |
-
col1, col2, col3 = st.columns([4, 5, 4])
|
| 282 |
-
with col2:
|
| 283 |
-
st.write(st.session_state["count_data"])
|
| 284 |
-
st.write(f"{st.session_state['count_data'].shape}")
|
| 285 |
-
|
| 286 |
target_col = "Target"
|
| 287 |
prediction_col = "Prediction"
|
| 288 |
n_col = "N"
|
|
@@ -318,7 +312,7 @@ if st.session_state["step"] >= 2:
|
|
| 318 |
st.subheader("The Data")
|
| 319 |
col1, col2, col3 = st.columns([2, 2, 2])
|
| 320 |
with col2:
|
| 321 |
-
st.
|
| 322 |
st.write(f"{df.shape} (Showing first 5 rows)")
|
| 323 |
|
| 324 |
else:
|
|
|
|
| 10 |
import pandas as pd
|
| 11 |
from pandas.api.types import is_float_dtype
|
| 12 |
from itertools import combinations
|
|
|
|
| 13 |
|
| 14 |
from utils import call_subprocess, clean_string_for_non_alphanumerics, clean_str_column
|
| 15 |
from data import read_data, read_data_cached, DownloadHeader, generate_data
|
|
|
|
| 156 |
st.session_state["step"] = 2
|
| 157 |
|
| 158 |
if st.session_state["step"] >= 2:
|
|
|
|
| 159 |
# Ensure targets and predictions are clean strings
|
| 160 |
st.session_state["count_data"][target_col] = clean_str_column(
|
| 161 |
st.session_state["count_data"][target_col]
|
|
|
|
| 240 |
st.session_state["count_data"] = pd.DataFrame(
|
| 241 |
all_pairs, columns=["Target", "Prediction"]
|
| 242 |
)
|
| 243 |
+
st.session_state["count_data"]["N"] = 0
|
| 244 |
|
| 245 |
st.session_state["step"] = 1
|
| 246 |
|
| 247 |
if st.session_state["step"] >= 1:
|
| 248 |
with st.form(key="enter_counts_form"):
|
| 249 |
+
st.write(
|
| 250 |
+
"Fill in the counts by pressing each cell in the `N` column and inputting the counts."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
new_counts = st.data_editor(
|
| 254 |
+
st.session_state["count_data"],
|
| 255 |
+
hide_index=True,
|
| 256 |
+
column_config={
|
| 257 |
+
"Target": st.column_config.TextColumn(disabled=True),
|
| 258 |
+
"Prediction": st.column_config.TextColumn(disabled=True),
|
| 259 |
+
"N": st.column_config.NumberColumn(
|
| 260 |
+
disabled=False, min_value=0, step=1
|
| 261 |
+
),
|
| 262 |
+
},
|
| 263 |
+
)
|
| 264 |
|
| 265 |
if st.form_submit_button(
|
| 266 |
label="Generate data",
|
| 267 |
):
|
| 268 |
+
st.session_state["count_data"] = new_counts
|
|
|
|
|
|
|
| 269 |
st.session_state["step"] = 2
|
| 270 |
|
| 271 |
if st.session_state["step"] >= 2:
|
| 272 |
DownloadHeader.header_and_data_download(
|
| 273 |
+
"",
|
| 274 |
data=st.session_state["count_data"],
|
| 275 |
file_name="confusion_matrix_counts.csv",
|
| 276 |
+
label="Download counts",
|
| 277 |
help="Download counts",
|
| 278 |
+
col_sizes=[10, 3],
|
| 279 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
target_col = "Target"
|
| 281 |
prediction_col = "Prediction"
|
| 282 |
n_col = "N"
|
|
|
|
| 312 |
st.subheader("The Data")
|
| 313 |
col1, col2, col3 = st.columns([2, 2, 2])
|
| 314 |
with col2:
|
| 315 |
+
st.dataframe(df.head(5), hide_index=True)
|
| 316 |
st.write(f"{df.shape} (Showing first 5 rows)")
|
| 317 |
|
| 318 |
else:
|
design.py
CHANGED
|
@@ -78,16 +78,14 @@ def design_section(
|
|
| 78 |
"Any observation with either a target or prediction "
|
| 79 |
"of another class is excluded.",
|
| 80 |
)
|
| 81 |
-
# TODO: Once the arrow bug in cvms is fixed, enable reversing!
|
| 82 |
# Reverse by default
|
| 83 |
-
|
| 84 |
with col2:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
pass
|
| 91 |
|
| 92 |
# Color palette
|
| 93 |
output["palette"] = _add_select_box(
|
|
@@ -151,6 +149,45 @@ def design_section(
|
|
| 151 |
st.markdown("""---""")
|
| 152 |
st.markdown("**Advanced**:")
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
with st.expander("Elements"):
|
| 155 |
col1, col2 = st.columns(2)
|
| 156 |
with col1:
|
|
@@ -454,8 +491,10 @@ def design_section(
|
|
| 454 |
# Save settings as json
|
| 455 |
with open(design_settings_store_path, "w") as f:
|
| 456 |
json.dump(output, f)
|
| 457 |
-
|
| 458 |
-
|
|
|
|
|
|
|
| 459 |
|
| 460 |
design_ready = False
|
| 461 |
if st.session_state["step"] >= 3:
|
|
|
|
| 78 |
"Any observation with either a target or prediction "
|
| 79 |
"of another class is excluded.",
|
| 80 |
)
|
|
|
|
| 81 |
# Reverse by default
|
| 82 |
+
selected_classes.reverse()
|
| 83 |
with col2:
|
| 84 |
+
st.write(" ")
|
| 85 |
+
st.write(" ")
|
| 86 |
+
reverse_class_order = st.checkbox(
|
| 87 |
+
"Reverse order", value=False, help="Reverse the order of the classes."
|
| 88 |
+
)
|
|
|
|
| 89 |
|
| 90 |
# Color palette
|
| 91 |
output["palette"] = _add_select_box(
|
|
|
|
| 149 |
st.markdown("""---""")
|
| 150 |
st.markdown("**Advanced**:")
|
| 151 |
|
| 152 |
+
with st.expander("Labels"):
|
| 153 |
+
col1, col2 = st.columns(2)
|
| 154 |
+
with col1:
|
| 155 |
+
output["x_label"] = st.text_input(
|
| 156 |
+
"x-axis",
|
| 157 |
+
value=get_uploaded_setting(
|
| 158 |
+
key="x_label", default="True Class", type_=str
|
| 159 |
+
),
|
| 160 |
+
)
|
| 161 |
+
with col2:
|
| 162 |
+
output["y_label"] = st.text_input(
|
| 163 |
+
"y-axis",
|
| 164 |
+
value=get_uploaded_setting(
|
| 165 |
+
key="y_label", default="Predicted Class", type_=str
|
| 166 |
+
),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
st.markdown("---")
|
| 170 |
+
col1, col2 = st.columns(2)
|
| 171 |
+
with col1:
|
| 172 |
+
output["title_label"] = st.text_input(
|
| 173 |
+
"Title",
|
| 174 |
+
value=get_uploaded_setting(
|
| 175 |
+
key="title_label", default="", type_=str
|
| 176 |
+
),
|
| 177 |
+
)
|
| 178 |
+
with col2:
|
| 179 |
+
output["caption_label"] = st.text_input(
|
| 180 |
+
"Caption",
|
| 181 |
+
value=get_uploaded_setting(
|
| 182 |
+
key="caption_label", default="", type_=str
|
| 183 |
+
),
|
| 184 |
+
)
|
| 185 |
+
st.info(
|
| 186 |
+
"Note: When adding a title or caption, "
|
| 187 |
+
"you may need to adjust the height and "
|
| 188 |
+
"width of the plot as well."
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
with st.expander("Elements"):
|
| 192 |
col1, col2 = st.columns(2)
|
| 193 |
with col1:
|
|
|
|
| 491 |
# Save settings as json
|
| 492 |
with open(design_settings_store_path, "w") as f:
|
| 493 |
json.dump(output, f)
|
| 494 |
+
if not output["place_x_axis_above"]:
|
| 495 |
+
selected_classes.reverse()
|
| 496 |
+
if reverse_class_order:
|
| 497 |
+
selected_classes.reverse()
|
| 498 |
|
| 499 |
design_ready = False
|
| 500 |
if st.session_state["step"] >= 3:
|
plot.R
CHANGED
|
@@ -267,6 +267,16 @@ if (isTRUE(design_settings$counts_on_top) ||
|
|
| 267 |
)
|
| 268 |
}
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
confusion_matrix_plot <- tryCatch(
|
| 272 |
{
|
|
@@ -282,7 +292,13 @@ confusion_matrix_plot <- tryCatch(
|
|
| 282 |
rm_zero_text = !design_settings$show_zero_text,
|
| 283 |
add_zero_shading = design_settings$show_zero_shading,
|
| 284 |
add_arrows = design_settings$show_arrows,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
counts_on_top = design_settings$counts_on_top,
|
|
|
|
|
|
|
| 286 |
diag_percentages_only = design_settings$diag_percentages_only,
|
| 287 |
digits = as.integer(design_settings$num_digits),
|
| 288 |
palette = design_settings$palette,
|
|
@@ -290,7 +306,10 @@ confusion_matrix_plot <- tryCatch(
|
|
| 290 |
font_counts = do.call("font", counts_font_args),
|
| 291 |
font_normalized = do.call("font", normalized_font_args),
|
| 292 |
font_row_percentages = do.call("font", percentages_font_args),
|
| 293 |
-
font_col_percentages = do.call("font", percentages_font_args)
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
},
|
| 296 |
error = function(e) {
|
|
@@ -301,6 +320,30 @@ confusion_matrix_plot <- tryCatch(
|
|
| 301 |
}
|
| 302 |
)
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
tryCatch(
|
| 305 |
{
|
| 306 |
ggplot2::ggsave(
|
|
|
|
| 267 |
)
|
| 268 |
}
|
| 269 |
|
| 270 |
+
tile_border_color <- NA
|
| 271 |
+
if (isTRUE(design_settings$show_tile_border)) {
|
| 272 |
+
tile_border_color <- design_settings$tile_border_color
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
intensity_by <- ifelse(
|
| 276 |
+
tolower(design_settings$intensity_by) == "counts",
|
| 277 |
+
"counts",
|
| 278 |
+
"normalized"
|
| 279 |
+
)
|
| 280 |
|
| 281 |
confusion_matrix_plot <- tryCatch(
|
| 282 |
{
|
|
|
|
| 292 |
rm_zero_text = !design_settings$show_zero_text,
|
| 293 |
add_zero_shading = design_settings$show_zero_shading,
|
| 294 |
add_arrows = design_settings$show_arrows,
|
| 295 |
+
arrow_size = design_settings$arrow_size,
|
| 296 |
+
arrow_nudge_from_text = design_settings$arrow_nudge_from_text,
|
| 297 |
+
intensity_by = intensity_by,
|
| 298 |
+
darkness = design_settings$darkness,
|
| 299 |
counts_on_top = design_settings$counts_on_top,
|
| 300 |
+
place_x_axis_above = design_settings$place_x_axis_above,
|
| 301 |
+
rotate_y_text = design_settings$rotate_y_text,
|
| 302 |
diag_percentages_only = design_settings$diag_percentages_only,
|
| 303 |
digits = as.integer(design_settings$num_digits),
|
| 304 |
palette = design_settings$palette,
|
|
|
|
| 306 |
font_counts = do.call("font", counts_font_args),
|
| 307 |
font_normalized = do.call("font", normalized_font_args),
|
| 308 |
font_row_percentages = do.call("font", percentages_font_args),
|
| 309 |
+
font_col_percentages = do.call("font", percentages_font_args),
|
| 310 |
+
tile_border_color = tile_border_color,
|
| 311 |
+
tile_border_size = design_settings$tile_border_size,
|
| 312 |
+
tile_border_linetype = design_settings$tile_border_linetype
|
| 313 |
)
|
| 314 |
},
|
| 315 |
error = function(e) {
|
|
|
|
| 320 |
}
|
| 321 |
)
|
| 322 |
|
| 323 |
+
# Add labels on x and y axes
|
| 324 |
+
confusion_matrix_plot <- confusion_matrix_plot +
|
| 325 |
+
ggplot2::labs(
|
| 326 |
+
x = design_settings$x_label,
|
| 327 |
+
y = design_settings$y_label
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Add title
|
| 331 |
+
if (nchar(design_settings$title_label) > 0) {
|
| 332 |
+
confusion_matrix_plot <- confusion_matrix_plot +
|
| 333 |
+
ggplot2::labs(
|
| 334 |
+
title = design_settings$title_label
|
| 335 |
+
)
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
# Add caption
|
| 339 |
+
if (nchar(design_settings$caption_label) > 0) {
|
| 340 |
+
confusion_matrix_plot <- confusion_matrix_plot +
|
| 341 |
+
ggplot2::labs(
|
| 342 |
+
caption = design_settings$caption_label
|
| 343 |
+
)
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
|
| 347 |
tryCatch(
|
| 348 |
{
|
| 349 |
ggplot2::ggsave(
|