|
|
""" |
|
|
App for plotting a confusion matrix with `cvms::plot_confusion_matrix()`.
|
|
|
|
|
""" |
|
|
|
|
|
import pathlib
import shlex
import tempfile
from collections import OrderedDict
from itertools import combinations

from PIL import Image
import streamlit as st
import pandas as pd
from pandas.api.types import is_float_dtype

from utils import call_subprocess, clean_string_for_non_alphanumerics
from data import read_data, read_data_cached, DownloadHeader, generate_data
from design import design_section
from text_sections import (
    intro_text,
    columns_text,
    upload_predictions_text,
    upload_counts_text,
    generate_data_text,
    enter_count_data_text,
)
|
|
|
|
|
# Inject a `.small-font` CSS class (text at 0.85em of the normal size) into
# the page. Raw HTML is needed for <style>, hence `unsafe_allow_html=True`.
_SMALL_FONT_STYLE = """
<style>
.small-font {
    font-size:0.85em !important;
}
</style>
"""
st.markdown(_SMALL_FONT_STYLE, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def set_tmp_dir():
    """
    Create the scratch directory shared by every rerun of the app.

    The function is wrapped in ``st.cache_resource`` so it is only executed
    once: the same directory must be reused throughout the iterations instead
    of being regenerated on each rerun. Returning the
    ``TemporaryDirectory`` object (not only its path) keeps it referenced;
    the stdlib removes the directory when that object is destroyed.

    Returns
    -------
    tuple
        ``(TemporaryDirectory, str)`` - the directory object and its path.
    """
    scratch = tempfile.TemporaryDirectory()
    scratch_path = scratch.name
    return scratch, scratch_path
|
|
|
|
|
|
|
|
# Files exchanged between this app and the R plotting script all live in the
# cached scratch directory.
temp_dir, temp_dir_path = set_tmp_dir()
gen_data_store_path = pathlib.Path(temp_dir_path) / "generated_data.csv"
data_store_path = pathlib.Path(temp_dir_path) / "data.csv"
design_settings_store_path = pathlib.Path(temp_dir_path) / "design_settings.json"
conf_mat_path = pathlib.Path(temp_dir_path) / "confusion_matrix.png"
|
|
|
|
|
|
|
|
def input_choice_callback():
    """
    Start over when the user switches to another input method.

    Sets the step counter back to 0, clears the input type (it is re-derived
    from the newly selected choice on the next rerun), drops session values
    tied to the previous input, and deletes files written for it.
    """
    st.session_state["step"] = 0
    st.session_state["input_type"] = None

    # Session values that belong to the previous input method
    for stale_key in ("classes", "count_data"):
        if stale_key in st.session_state:
            st.session_state.pop(stale_key)

    # Files produced from the previous input (design settings are kept)
    for stale_file in (gen_data_store_path, data_store_path, conf_mat_path):
        if stale_file.exists():
            stale_file.unlink()
|
|
|
|
|
|
|
|
|
|
|
intro_text()

# Initialize the step counter on the first run of the session
if st.session_state.get("step") is None:
    st.session_state["step"] = 0

input_choice = st.radio(
    label="Input",
    options=["Upload predictions", "Upload counts", "Generate", "Enter counts"],
    index=0,
    horizontal=True,
    on_change=input_choice_callback,
)

# Derive the input type from the selected method: the two count-based
# methods produce "counts", the others produce "data"
if st.session_state.get("input_type") is None:
    st.session_state["input_type"] = (
        "data" if input_choice in ("Upload predictions", "Generate") else "counts"
    )
|
|
|
|
|
|
|
|
# Each input method walks the user through the steps:
#   step 1: input received, step 2: data/counts ready for design and plotting.
if input_choice == "Upload predictions":
    with st.form(key="data_form"):
        upload_predictions_text()
        data_path = st.file_uploader("Upload a dataset", type=["csv"])
        if st.form_submit_button(label="Use data"):
            if data_path:
                st.session_state["step"] = 1
            else:
                st.session_state["step"] = 0
                st.markdown(
                    "Please upload a file first (or **generate** some random data to try the function)."
                )

    if st.session_state["step"] >= 1:
        # Load the uploaded predictions
        df = read_data_cached(data_path)
        with st.form(key="column_form"):
            columns_text()
            target_col = st.selectbox("Targets column", options=list(df.columns))
            prediction_col = st.selectbox(
                "Predictions column", options=list(df.columns)
            )

            if st.form_submit_button(label="Set columns"):
                st.session_state["step"] = 2

elif input_choice == "Upload counts":
    with st.form(key="data_form"):
        upload_counts_text()
        data_path = st.file_uploader("Upload your counts", type=["csv"])
        if st.form_submit_button(label="Use counts"):
            if data_path:
                st.session_state["step"] = 1
            else:
                st.session_state["step"] = 0
                st.write("Please upload a file first.")

    if st.session_state["step"] >= 1:
        # Load the uploaded counts and keep them in the session state
        st.session_state["count_data"] = read_data_cached(data_path)
        with st.form(key="column_form"):
            columns_text()
            target_col = st.selectbox(
                "Targets column", options=list(st.session_state["count_data"].columns)
            )
            prediction_col = st.selectbox(
                "Predictions column",
                options=list(st.session_state["count_data"].columns),
            )
            n_col = st.selectbox(
                "Counts column", options=list(st.session_state["count_data"].columns)
            )

            if st.form_submit_button(label="Set columns"):
                st.session_state["step"] = 2
                # Class names are taken from the targets column
                st.session_state["classes"] = sorted(
                    [
                        str(c)
                        for c in st.session_state["count_data"][target_col].unique()
                    ]
                )

elif input_choice == "Generate":

    def reset_generation_callback():
        """Delete previously generated data before new data is generated."""
        if gen_data_store_path.exists():
            gen_data_store_path.unlink()

    with st.form(key="generate_form"):
        generate_data_text()
        col1, col2, col3 = st.columns(3)
        with col1:
            num_classes = st.number_input(
                "# Classes",
                value=3,
                min_value=2,
                help="Number of classes to generate data for.",
            )
        with col2:
            num_observations = st.number_input(
                "# Observations",
                value=30,
                min_value=2,
                max_value=10000,
                help="Number of observations to generate data for.",
            )
        with col3:
            seed = st.number_input("Random Seed", value=42, min_value=0)
        if st.form_submit_button(
            label="Generate data", on_click=reset_generation_callback
        ):
            st.session_state["step"] = 2

    if st.session_state["step"] >= 2:
        generate_data(
            out_path=gen_data_store_path,
            num_classes=num_classes,
            num_observations=num_observations,
            seed=seed,
        )
        df = read_data(gen_data_store_path)
        target_col = "Target"
        prediction_col = "Predicted Class"

elif input_choice == "Enter counts":

    def repopulate_matrix_callback():
        """Drop previously entered counts before the matrix is re-populated."""
        if "count_data" in st.session_state:
            st.session_state.pop("count_data")

    with st.form(key="enter_classes_form"):
        enter_count_data_text()
        classes_joined = st.text_input("Classes (comma-separated)")

        if st.form_submit_button(
            label="Populate matrix", on_click=repopulate_matrix_callback
        ):
            # Clean the entered class names. Empty entries (e.g. from a
            # trailing comma) and duplicates are dropped - duplicates would
            # otherwise collapse count fields and misalign the counts with the
            # rows. The entered order is kept.
            cleaned_classes = (
                clean_string_for_non_alphanumerics(s.strip())
                for s in classes_joined.split(",")
            )
            classes = list(dict.fromkeys(c for c in cleaned_classes if c))

            if len(classes) < 2:
                st.session_state["step"] = 0
                st.write("Please enter at least 2 distinct classes.")
            else:
                st.session_state["classes"] = classes

                # Every ordered (target, prediction) pair: both orders of each
                # pair of distinct classes, plus the diagonal
                all_pairs = list(combinations(classes, 2))
                all_pairs += [(pair[1], pair[0]) for pair in all_pairs]
                all_pairs += [(c, c) for c in classes]

                # Count data without the counts; these are filled in below
                st.session_state["count_data"] = pd.DataFrame(
                    all_pairs, columns=["Target", "Prediction"]
                )

                st.session_state["step"] = 1

    if st.session_state["step"] >= 1:
        with st.form(key="enter_counts_form"):
            st.write("Fill in the counts for `N(Target, Prediction)` pairs.")
            count_input_fields = OrderedDict()

            # Spread the count inputs over a fixed number of columns
            num_cols = 3
            cols = st.columns(num_cols)
            for i, (targ, pred) in enumerate(
                zip(
                    st.session_state["count_data"]["Target"],
                    st.session_state["count_data"]["Prediction"],
                )
            ):
                # Counts cannot be negative
                count_input_fields[f"{targ}____{pred}"] = cols[
                    i % num_cols
                ].number_input(f"N({targ}, {pred})", min_value=0, step=1)

            if st.form_submit_button(
                label="Generate data",
            ):
                # Fields were added in row order, so values align with rows
                st.session_state["count_data"]["N"] = [
                    int(val) for val in count_input_fields.values()
                ]
                st.session_state["step"] = 2

        if st.session_state["step"] >= 2:
            DownloadHeader.header_and_data_download(
                "Entered counts",
                data=st.session_state["count_data"],
                file_name="confusion_matrix_counts.csv",
                help="Download counts",
                col_sizes=[10, 2],
            )
            col1, col2, col3 = st.columns([4, 5, 4])
            with col2:
                st.write(st.session_state["count_data"])
                st.write(f"{st.session_state['count_data'].shape}")

    target_col = "Target"
    prediction_col = "Prediction"
    n_col = "N"
|
|
|
|
|
if st.session_state["step"] >= 2:
    # Both select boxes default to the first column; one column cannot serve
    # as both targets and predictions
    if target_col == prediction_col:
        st.error(
            "`Targets column` and `Predictions column` must be different columns."
        )
        st.stop()

    if st.session_state["input_type"] == "data":
        # Remove unused columns
        df = df.loc[:, [target_col, prediction_col]]

        # Float predictions are treated as probabilities. Check before the
        # string conversion below changes the dtype.
        predictions_are_probabilities = is_float_dtype(df[prediction_col])

        # Targets become strings with spaces replaced by underscores.
        # Class-label predictions get the same treatment so they keep
        # matching the target labels.
        df[target_col] = (
            df[target_col].astype(str).str.replace(" ", "_", regex=False)
        )
        if not predictions_are_probabilities:
            df[prediction_col] = (
                df[prediction_col].astype(str).str.replace(" ", "_", regex=False)
            )

        # Save data for the plotting script
        df.to_csv(data_store_path)

        # Class names found in the targets
        st.session_state["classes"] = sorted(str(c) for c in df[target_col].unique())

        invalid_probabilities = (
            predictions_are_probabilities and len(st.session_state["classes"]) != 2
        )
        if invalid_probabilities:
            st.error(
                "Predictions can only be probabilities in binary classification. "
                f"Got {len(st.session_state['classes'])} classes."
            )

        st.subheader("The Data")
        col1, col2, col3 = st.columns([2, 2, 2])
        with col2:
            st.write(df.head(5))
            st.write(f"{df.shape} (Showing first 5 rows)")

        # Nothing valid can be plotted; stop after showing the data
        if invalid_probabilities:
            st.stop()

    else:
        predictions_are_probabilities = False
        st.session_state["count_data"].to_csv(data_store_path)

    # Check the number of classes
    num_classes = len(st.session_state["classes"])
    if num_classes < 2:
        st.error(
            "Uploaded data must contain 2 or more classes in `Targets column`. "
            f"Got {num_classes} target classes."
        )
        st.stop()

    design_settings, design_ready, selected_classes, prob_of_class = design_section(
        num_classes=num_classes,
        predictions_are_probabilities=predictions_are_probabilities,
        design_settings_store_path=design_settings_store_path,
    )

    # Plot only when the design step is completed
    if st.session_state["step"] >= 3 and design_ready:
        DownloadHeader.centered_json_download(
            data=design_settings,
            file_name="design_settings.json",
            label="Download design settings",
            help="Download the design settings to allow reusing settings in future plots.",
        )

        st.markdown("---")

        plotting_args = [
            "--data_path",
            str(data_store_path),
            "--out_path",
            str(conf_mat_path),
            "--settings_path",
            str(design_settings_store_path),
            "--target_col",
            str(target_col),
            "--prediction_col",
            str(prediction_col),
            "--classes",
            ",".join(selected_classes),
        ]

        if st.session_state["input_type"] == "counts":
            # Tell the plotting script the data are counts
            plotting_args += ["--n_col", str(n_col), "--data_are_counts"]

        # Column and class names come from user-provided data. Quote every
        # argument so quotes, spaces or shell metacharacters in them can
        # neither break nor inject into the command.
        plotting_command = " ".join(
            ["Rscript", "plot.R"] + [shlex.quote(arg) for arg in plotting_args]
        )

        # Remove a plot left over from a previous run so a failed run is
        # detected instead of silently showing the old plot
        if conf_mat_path.exists():
            conf_mat_path.unlink()

        call_subprocess(
            plotting_command,
            message="Plotting script",
            return_output=True,
            encoding="UTF-8",
        )

        if conf_mat_path.exists():
            DownloadHeader.header_and_image_download(
                "", filepath=conf_mat_path, label="Download Plot"
            )
            col1, col2, col3 = st.columns([2, 8, 2])
            with col2:
                # Close the file handle once the image has been rendered
                with Image.open(conf_mat_path) as image:
                    st.image(
                        image,
                        caption="Confusion Matrix",
                        clamp=False,
                        channels="RGB",
                        output_format="auto",
                    )
        else:
            st.error("The confusion matrix could not be plotted.")

else:
    st.write("Please upload data.")
|
|
|
|
|
|
|
|
# Vertical spacing before the footer
for _ in range(5):
    st.write(" ")

st.markdown("---")
st.write()

# Footer: author, issue tracker and source-code links
author_col, issues_col, source_col, _ = st.columns([6, 3, 3, 3])
with author_col:
    st.write("Developed by [Ludvig Renbo Olsen](http://ludvigolsen.dk)")
with issues_col:
    st.markdown("[Report issues](https://github.com/LudvigOlsen/cvms_plot_app/issues)")
with source_col:
    st.markdown("[Source code](https://github.com/LudvigOlsen/cvms_plot_app/)")
|
|
|