Spaces:
Runtime error
Runtime error
"""
This script evaluates the contribution of a technique from the ablation study for
improving the masker evaluation metrics. The differences in the metrics are computed
for all images of paired models, that is those which only differ in the inclusion or
not of the given technique. Then, statistical inference is performed through the
percentile bootstrap to obtain robust estimates of the differences in the metrics and
confidence intervals. The script plots the distribution of the bootstrapped estimates.

For each key metric, the 20 % trimmed mean of the per-image differences is
bootstrapped; a density plot (PNG) and a YAML summary (bootstrap mean, SD, CI
bounds and p-value) are written to the output directory, together with a YAML
file holding the command-line arguments.
"""
# Progress message printed before the imports below are executed.
print("Imports...", end="")
| from argparse import ArgumentParser | |
| import yaml | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| from scipy.stats import trim_mean | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
# -----------------------
# ----- Constants -----
# -----------------------
# Display names of the masker evaluation metrics, and the subset of metrics
# ("key_metrics") on which the bootstrap analysis is run.
dict_metrics = dict(
    names=dict(
        tpr="TPR, Recall, Sensitivity",
        tnr="TNR, Specificity, Selectivity",
        fpr="FPR",
        fpt="False positives relative to image size",
        fnr="FNR, Miss rate",
        fnt="False negatives relative to image size",
        mpr="May positive rate (MPR)",
        mnr="May negative rate (MNR)",
        accuracy="Accuracy (ignoring may)",
        error="Error",
        f05="F05 score",
        precision="Precision",
        edge_coherence="Edge coherence",
        accuracy_must_may="Accuracy (ignoring cannot)",
    ),
    key_metrics=["f05", "error", "edge_coherence"],
)
# Accepted --technique keywords mapped to the canonical CSV column name.
# Built from (canonical name, accepted aliases) pairs.
dict_techniques = {
    alias: canonical
    for canonical, aliases in (
        ("depth", ("depth",)),
        ("seg", ("segmentation", "seg")),
        ("dada_seg", ("dada_s", "dada_seg", "dada_segmentation")),
        ("dada_masker", ("dada_m", "dada_masker")),
        ("spade", ("spade",)),
        ("pseudo", ("pseudo", "pseudo-labels", "pseudo_labels")),
    )
    for alias in aliases
}
# Model features: CSV columns describing a model's configuration. Two models
# form a pair when they agree on all of them except the technique under study.
model_feats = "masker seg depth dada_seg dada_masker spade pseudo ground instagan".split()
# Colors: two categorical base colors taken from seaborn's "colorblind"
# palette, each expanded into four shades (lightest, light, medium, dark).
palette_colorblind = sns.color_palette("colorblind")
color_cat1, color_cat2 = palette_colorblind[0], palette_colorblind[1]
palette_lightest = [
    sns.light_palette(base, n_colors=20)[3] for base in (color_cat1, color_cat2)
]
palette_light = [
    sns.light_palette(base, n_colors=3)[1] for base in (color_cat1, color_cat2)
]
palette_medium = [color_cat1, color_cat2]
palette_dark = [
    sns.dark_palette(base, n_colors=3)[1] for base in (color_cat1, color_cat2)
]
# Per-category shade ramps, ordered lightest, light, medium, dark.
palette_cat1 = [
    shades[0]
    for shades in (palette_lightest, palette_light, palette_medium, palette_dark)
]
palette_cat2 = [
    shades[1]
    for shades in (palette_lightest, palette_light, palette_medium, palette_dark)
]
color_cat1_light, color_cat2_light = palette_light
def parsed_args():
    """
    Parse and return the command-line arguments.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes ``input_csv``,
        ``output_dir``, ``technique``, ``dpi``, ``n_bs``, ``alpha`` and
        ``bs_seed``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--technique",
        default=None,
        type=str,
        help="Keyword specifying the technique. One of: pseudo, depth, segmentation, dada_seg, dada_masker, spade",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    # argparse only applies ``type`` to *string* defaults, so a float default
    # such as 1e6 would be kept as a float and break np.zeros()/range() later.
    # The default must therefore already be an int.
    parser.add_argument(
        "--n_bs",
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    return parser.parse_args()
def add_ci_mean(
    ax, sample_measure, bs_mean, bs_std, ci, color, alpha, fontsize, invert=False
):
    """
    Annotate a bootstrap density plot with its confidence interval and mean.

    The distribution curve is taken to be the first line on ``ax``
    (``ax.lines[0]``). The function:

    - shades the area under the curve between the CI bounds;
    - draws vertical lines at both bounds and labels them "L = ..." and
      "U = ...";
    - draws a solid black line at the bootstrap mean with a text box;
    - draws a dotted black line at the sample measure;
    - adds a double-arrow text box with the bootstrap SD.

    Args:
        ax: matplotlib Axes whose first line is the distribution curve.
        sample_measure (float): statistic computed on the original sample.
        bs_mean (float): mean of the bootstrap estimates.
        bs_std (float): standard deviation of the bootstrap estimates.
        ci (sequence of 2 floats): lower and upper bounds of the CI.
        color: color of the CI shading and of the CI bound lines.
        alpha (float): transparency of the CI shading.
        fontsize (float): font size of the annotations.
        invert (bool): if False (default), the bound labels extend inwards
            into the CI; if True, they extend outwards.
    """
    dist = ax.lines[0]
    dist_x = np.asarray(dist.get_xdata())
    dist_y = np.asarray(dist.get_ydata())
    linewidth = dist.get_linewidth()

    def nearest_idx(x):
        # Index of the curve's grid point closest to x
        return int(np.argmin(np.abs(dist_x - x)))

    # Fill the area under the curve between the CI bounds. The upper index is
    # included so that the shading reaches the upper bound and the slice is
    # never empty, even when both bounds map to the same grid point.
    idx_low = nearest_idx(ci[0])
    idx_high = nearest_idx(ci[1])
    ax.fill_between(
        dist_x[idx_low : idx_high + 1],
        0,
        dist_y[idx_low : idx_high + 1],
        facecolor=color,
        alpha=alpha,
    )

    # Vertical lines of the CI, up to the curve height at each bound
    ax.vlines(
        x=ci[0],
        ymin=0.0,
        ymax=dist_y[idx_low],
        color=color,
        linewidth=linewidth,
        label="ci_low",
    )
    ax.vlines(
        x=ci[1],
        ymin=0.0,
        ymax=dist_y[idx_high],
        color=color,
        linewidth=linewidth,
        label="ci_high",
    )

    # Annotations of the CI bounds at the bottom of the plot
    bbox_round = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
    ha_low, ha_high = ("right", "left") if invert else ("left", "right")
    ax.text(
        ci[0],
        0.0,
        s="L = {:.4f}".format(ci[0]),
        ha=ha_low,
        va="bottom",
        fontsize=fontsize,
        bbox=bbox_round,
    )
    ax.text(
        ci[1],
        0.0,
        s="U = {:.4f}".format(ci[1]),
        ha=ha_high,
        va="bottom",
        fontsize=fontsize,
        bbox=bbox_round,
    )

    # Vertical line of the bootstrap mean, annotated at 60 % of the curve height
    idx_mean = nearest_idx(bs_mean)
    ax.vlines(
        x=bs_mean, ymin=0.0, ymax=dist_y[idx_mean], color="k", linewidth=linewidth
    )
    ax.text(
        bs_mean,
        0.6 * dist_y[idx_mean],
        s="Bootstrap mean = {:.4f}".format(bs_mean),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=bbox_round,
    )

    # Dotted vertical line of the measure computed on the original sample
    idx_smeas = nearest_idx(sample_measure)
    ax.vlines(
        x=sample_measure,
        ymin=0.0,
        ymax=dist_y[idx_smeas],
        color="k",
        linewidth=linewidth,
        linestyles="dotted",
    )

    # Bootstrap SD (reported as the standard error), at 40 % of the mean's height
    ax.text(
        bs_mean,
        0.4 * dist_y[idx_mean],
        s="SD = {:.4f} = SE".format(bs_std),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=dict(boxstyle="darrow, pad=0.4", fc="w", ec="k", lw=2),
    )
def add_null_pval(ax, null, color, alpha, fontsize):
    """
    Annotate a bootstrap density plot with the null-hypothesis value.

    The distribution curve is taken to be the first line on ``ax``
    (``ax.lines[0]``). The area under the curve on one side of ``null`` is
    shaded:

    - the right side (from ``null`` to the end of the curve) when ``null``
      falls in the right half of the curve's x grid;
    - otherwise the left side (from the start of the curve up to ``null``).

    A dotted vertical line spanning the full y range is then drawn at
    ``null``, with a "Null hypothesis = ..." text box at 75 % of the height.

    NOTE(review): the shaded side is chosen by grid position, presumably as a
    proxy for "the tail beyond the null, away from the bulk of the
    distribution"; confirm this matches the p-value definition used.

    Args:
        ax: matplotlib Axes whose first line is the distribution curve.
        null (float): value of the statistic under the null hypothesis.
        color: color of the tail shading.
        alpha (float): transparency of the tail shading.
        fontsize (float): font size of the annotation.
    """
    dist = ax.lines[0]
    dist_x = np.asarray(dist.get_xdata())
    dist_y = np.asarray(dist.get_ydata())
    linewidth = dist.get_linewidth()

    # Both tails include the grid point closest to the null, so the shading
    # meets the null line from either side.
    idx_null = int(np.argmin(np.abs(dist_x - null)))
    if idx_null >= len(dist_x) / 2.0:
        tail = slice(idx_null, None)
    else:
        tail = slice(None, idx_null + 1)
    ax.fill_between(dist_x[tail], 0, dist_y[tail], facecolor=color, alpha=alpha)

    # Dotted vertical line of the null, spanning the current y range
    y_max = ax.get_ylim()[1]
    ax.vlines(
        x=null,
        ymin=0.0,
        ymax=y_max,
        color="k",
        linewidth=linewidth,
        linestyles="dotted",
    )

    # Annotation
    bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
    ax.text(
        null,
        0.75 * y_max,
        s="Null hypothesis = {:.1f}".format(null),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=bbox_props,
    )
def plot_bootstrap_distr(
    sample_measure,
    bs_samples,
    alpha,
    color_ci,
    color_pval=None,
    null=None,
    dpi=None,
):
    """
    Plot the density of bootstrap estimates with CI, mean and optional p-value.

    Args:
        sample_measure (float): statistic computed on the original sample.
        bs_samples (array-like): bootstrap estimates of the statistic.
        alpha (float): confidence level of the percentile CI (e.g. 0.99).
        color_ci: color of the CI shading.
        color_pval: color of the p-value shading. The p-value is computed and
            drawn only when both ``color_pval`` and ``null`` are given.
        null (float): value of the statistic under the null hypothesis.
        dpi (int): figure DPI. When None, falls back to ``args.dpi`` from the
            module-level ``args`` defined by the script entry point.

    Returns:
        tuple: ``(fig, bs_mean, bs_std, ci, pval)``.

        - ``ci`` is the array ``[lower, upper]``.
        - ``pval`` is the two-sided bootstrap p-value
          ``2 * min(P(est > null), P(est < null))``, or None when no null
          hypothesis was requested.

    Side effects:
        Sets the seaborn style and the matplotlib font rc parameters globally.
    """
    samples = np.asarray(bs_samples)

    # Percentile bootstrap: the CI is the central ``alpha`` mass of the estimates
    q_low = (1.0 - alpha) / 2.0
    q_high = 1.0 - q_low
    ci = np.quantile(samples, [q_low, q_high])
    bs_mean = np.mean(samples)
    bs_std = np.std(samples)

    # ``pval`` stays None when no null hypothesis is requested, so the return
    # statement is always valid.
    pval_flag = null is not None and color_pval is not None
    pval = None
    if pval_flag:
        # Two-sided: twice the smaller of the two tail proportions
        pval = 2.0 * min(np.mean(samples > null), np.mean(samples < null))

    # Set up plot
    sns.set(style="whitegrid")
    fontsize = 24
    font = {"family": "DejaVu Sans", "weight": "normal", "size": fontsize}
    plt.rc("font", **font)
    alpha_plot = 0.5

    if dpi is None:
        dpi = args.dpi
    fig, ax = plt.subplots(figsize=(30, 12), dpi=dpi)

    # Plot distribution of bootstrap estimates
    sns.kdeplot(samples, color="b", linewidth=5, gridsize=1000, ax=ax)
    # Keep the KDE's y-limits: the annotations below may extend them
    y_lim = ax.get_ylim()
    sns.despine(left=True, bottom=True)

    # Annotations
    add_ci_mean(
        ax,
        sample_measure,
        bs_mean,
        bs_std,
        ci,
        color=color_ci,
        alpha=alpha_plot,
        fontsize=fontsize,
    )
    if pval_flag:
        add_null_pval(ax, null, color=color_pval, alpha=alpha_plot, fontsize=fontsize)

    # Legend. ``{:g}`` keeps non-integer levels such as 99.5 and avoids int()
    # truncation of values like 100 * 0.29 = 28.999...
    handles = [
        mpatches.Patch(
            facecolor=color_ci,
            edgecolor=None,
            alpha=alpha_plot,
            label="{:g} % confidence interval".format(100 * alpha),
        )
    ]
    if pval_flag:
        half_pval = pval / 2.0
        if half_pval == 0.0:
            pval_label = "P value / 2 = {:.1f}".format(half_pval)
        elif np.around(half_pval, decimals=4) > 0.0:
            pval_label = "P value / 2 = {:.4f}".format(half_pval)
        else:
            # Too small to show with 4 decimals: report an upper bound as a
            # power of ten. The exponent must be wrapped in braces, otherwise
            # mathtext only superscripts its first character (the minus sign).
            exponent = int(np.ceil(np.log10(half_pval)))
            pval_label = "P value / 2 < $10^{{{}}}$".format(exponent)
        handles.append(
            mpatches.Patch(
                facecolor=color_pval,
                edgecolor=None,
                alpha=alpha_plot,
                label=pval_label,
            )
        )
    leg = ax.legend(
        handles=handles,
        ncol=1,
        loc="upper right",
        frameon=True,
        framealpha=1.0,
        title="",
        fontsize=fontsize,
        columnspacing=1.0,
        labelspacing=0.2,
        markerfirst=True,
    )
    plt.setp(leg.get_title(), fontsize=fontsize, horizontalalignment="left")

    # Axis labels and ticks
    ax.set_xlabel("Bootstrap estimates", rotation=0, fontsize=fontsize, labelpad=10.0)
    ax.set_ylabel("Density", rotation=90, fontsize=fontsize, labelpad=10.0)
    plt.setp(ax.get_xticklabels(), fontsize=0.8 * fontsize, verticalalignment="top")
    plt.setp(ax.get_yticklabels(), fontsize=0.8 * fontsize)

    # Restore the KDE's y-limits
    ax.set_ylim(y_lim)

    return fig, bs_mean, bs_std, ci, pval
if __name__ == "__main__":
    # -----------------------------
    # ----- Parse arguments -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine technique. Validated first, so that a missing or invalid
    # --technique fails before anything is written to the output directory.
    if args.technique is None or args.technique.lower() not in dict_techniques:
        raise ValueError("{} is not a valid technique".format(args.technique))
    technique = dict_techniques[args.technique.lower()]

    # Number of bootstrap samples. Coerced because argparse does not pass a
    # non-string default through ``type=int``.
    n_bs = int(args.n_bs)

    # Determine output dir
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    elif "SLURM_TMPDIR" in os.environ:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        raise ValueError("--output_dir must be given when SLURM_TMPDIR is not set")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "{}_bootstrap.yml".format(args.technique)
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Find relevant model pairs. For every model that uses the technique, take
    # the first model that does not use it and agrees on every other feature.
    all_models = df.model_feats.unique()
    model_pairs = []
    for mi in df.loc[df[technique]].model_feats.unique():
        rows_i = df.model_feats == mi
        for mj in all_models:
            if mj == mi:
                continue
            rows_j = df.model_feats == mj
            if df.loc[rows_j, technique].unique()[0]:
                continue
            is_pair = all(
                df.loc[rows_j, f].unique()[0] == df.loc[rows_i, f].unique()[0]
                for f in model_feats
                if f != technique
            )
            if is_pair:
                model_pairs.append((mi, mj))
                break
    print("\nModel pairs identified:\n")
    for m_with, m_without in model_pairs:
        print("{} & {}".format(m_with, m_without))
    if not model_pairs:
        raise ValueError("No model pairs found for technique {}".format(technique))

    # Record, for both models of each pair, the model without the technique
    df["base"] = "N/A"
    for pair in model_pairs:
        df.loc[df.model_feats.isin(pair), "base"] = pair[1]

    # Build bootstrap data: per-image differences (with - without) of each key
    # metric, pooled over all pairs
    data = {m: [] for m in dict_metrics["key_metrics"]}
    for m_with, m_without in model_pairs:
        df_with = df.loc[df.model_feats == m_with].sort_values(by="img_idx")
        df_without = df.loc[df.model_feats == m_without].sort_values(by="img_idx")
        # The subtraction below is element-wise, so both models must have
        # been evaluated on exactly the same images
        if not np.array_equal(df_with["img_idx"].values, df_without["img_idx"].values):
            raise ValueError(
                "Models {} and {} were not evaluated on the same images".format(
                    m_with, m_without
                )
            )
        for metric in data:
            diff = df_with[metric].values - df_without[metric].values
            data[metric].extend(diff.tolist())

    # Run bootstrap
    measures = ["mean", "median", "20_trimmed_mean"]
    bs_data = {meas: {m: np.zeros(n_bs) for m in data} for meas in measures}
    np.random.seed(args.bs_seed)
    for m, data_m in data.items():
        values = np.asarray(data_m)
        for idx in tqdm(range(n_bs)):
            # Sample with replacement
            bs_sample = np.random.choice(values, size=len(values), replace=True)
            bs_data["mean"][m][idx] = np.mean(bs_sample)
            bs_data["median"][m][idx] = np.median(bs_sample)
            # 20 % trimmed mean
            bs_data["20_trimmed_mean"][m][idx] = trim_mean(bs_sample, 0.2)

    for metric in dict_metrics["key_metrics"]:
        sample_measure = trim_mean(data[metric], 0.2)
        fig, bs_mean, bs_std, ci, pval = plot_bootstrap_distr(
            sample_measure,
            bs_data["20_trimmed_mean"][metric],
            alpha=args.alpha,
            color_ci=color_cat1_light,
            color_pval=color_cat2_light,
            null=0.0,
        )

        # Save figure, then release it (30x12 in. figures add up otherwise)
        output_fig = output_dir / "{}_bootstrap_{}_{}.png".format(
            args.technique, metric, "20_trimmed_mean"
        )
        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
        plt.close(fig)

        # Store results
        output_results = output_dir / "{}_bootstrap_{}_{}.yml".format(
            args.technique, metric, "20_trimmed_mean"
        )
        results_dict = {
            "measure": "20_trimmed_mean",
            "sample_measure": float(sample_measure),
            "bs_mean": float(bs_mean),
            "bs_std": float(bs_std),
            "ci_left": float(ci[0]),
            "ci_right": float(ci[1]),
            "pval": float(pval),
        }
        with open(output_results, "w") as f:
            yaml.dump(results_dict, f)