Spaces:
Runtime error
Runtime error
| print("Imports...", end="", flush=True) | |
| import sys | |
| from pathlib import Path | |
| sys.path.append(str(Path(__file__).resolve().parent.parent)) | |
| import atexit | |
| import logging | |
| from argparse import ArgumentParser | |
| from copy import deepcopy | |
| import comet_ml | |
| import climategan | |
| from comet_ml.api import API | |
| from climategan.trainer import Trainer | |
| from climategan.utils import get_comet_rest_api_key | |
| logging.basicConfig() | |
| logging.getLogger().setLevel(logging.ERROR) | |
| import traceback | |
| print("Done.") | |
| def set_opts(opts, str_nested_key, value): | |
| """ | |
| Changes an opts with nested keys: | |
| set_opts(addict.Dict(), "a.b.c", 2) == Dict({"a":{"b": {"c": 2}}}) | |
| Args: | |
| opts (addict.Dict): opts whose values should be changed | |
| str_nested_key (str): nested keys joined on "." | |
| value (any): value to set to the nested keys of opts | |
| """ | |
| keys = str_nested_key.split(".") | |
| o = opts | |
| for k in keys[:-1]: | |
| o = o[k] | |
| o[keys[-1]] = value | |
| def set_conf(opts, conf): | |
| """ | |
| Updates opts according to a test scenario's configuration dict. | |
| Ignores all keys starting with "__" which are used for the scenario | |
| but outside the opts | |
| Args: | |
| opts (addict.Dict): trainer options | |
| conf (dict): scenario's configuration | |
| """ | |
| for k, v in conf.items(): | |
| if k.startswith("__"): | |
| continue | |
| set_opts(opts, k, v) | |
| class bcolors: | |
| HEADER = "\033[95m" | |
| OKBLUE = "\033[94m" | |
| OKGREEN = "\033[92m" | |
| WARNING = "\033[93m" | |
| FAIL = "\033[91m" | |
| ENDC = "\033[0m" | |
| BOLD = "\033[1m" | |
| UNDERLINE = "\033[4m" | |
| class Colors: | |
| def _r(self, key, *args): | |
| return f"{key}{' '.join(args)}{bcolors.ENDC}" | |
| def ob(self, *args): | |
| return self._r(bcolors.OKBLUE, *args) | |
| def w(self, *args): | |
| return self._r(bcolors.WARNING, *args) | |
| def og(self, *args): | |
| return self._r(bcolors.OKGREEN, *args) | |
| def f(self, *args): | |
| return self._r(bcolors.FAIL, *args) | |
| def b(self, *args): | |
| return self._r(bcolors.BOLD, *args) | |
| def u(self, *args): | |
| return self._r(bcolors.UNDERLINE, *args) | |
| def comet_handler(exp, api): | |
| def sub_handler(): | |
| p = Colors() | |
| print() | |
| print(p.b(p.w("Deleting comet experiment"))) | |
| api.delete_experiment(exp.get_key()) | |
| return sub_handler | |
| def print_start(desc): | |
| p = Colors() | |
| cdesc = p.b(p.ob(desc)) | |
| title = "| " + cdesc + " |" | |
| line = "-" * (len(desc) + 6) | |
| print(f"{line}\n{title}\n{line}") | |
| def print_end(desc=None, ok=None): | |
| p = Colors() | |
| if ok and desc is None: | |
| desc = "Done" | |
| cdesc = p.b(p.og(desc)) | |
| elif not ok and desc is None: | |
| desc = "! Fail !" | |
| cdesc = p.b(p.f(desc)) | |
| elif desc is not None: | |
| cdesc = p.b(p.og(desc)) | |
| else: | |
| desc = "Unknown" | |
| cdesc = desc | |
| title = "| " + cdesc + " |" | |
| line = "-" * (len(desc) + 6) | |
| print(f"{line}\n{title}\n{line}\n") | |
| def delete_on_exit(exp): | |
| """ | |
| Registers a callback to delete the comet exp at program exit | |
| Args: | |
| exp (comet_ml.Experiment): The exp to delete | |
| """ | |
| rest_api_key = get_comet_rest_api_key() | |
| api = API(api_key=rest_api_key) | |
| atexit.register(comet_handler(exp, api)) | |
| if __name__ == "__main__": | |
| # ----------------------------- | |
| # ----- Parse Arguments ----- | |
| # ----------------------------- | |
| parser = ArgumentParser() | |
| parser.add_argument("--no_delete", action="store_true", default=False) | |
| parser.add_argument("--no_end_to_end", action="store_true", default=False) | |
| parser.add_argument("--include", "-i", nargs="+", default=[]) | |
| parser.add_argument("--exclude", "-e", nargs="+", default=[]) | |
| args = parser.parse_args() | |
| assert not (args.include and args.exclude), "Choose 1: include XOR exclude" | |
| include = set(int(i) for i in args.include) | |
| exclude = set(int(i) for i in args.exclude) | |
| if include: | |
| print("Including exclusively tests", " ".join(args.include)) | |
| if exclude: | |
| print("Excluding tests", " ".join(args.exclude)) | |
| # -------------------------------------- | |
| # ----- Create global experiment ----- | |
| # -------------------------------------- | |
| print("Creating comet Experiment...", end="", flush=True) | |
| global_exp = comet_ml.Experiment( | |
| project_name="climategan-test", display_summary_level=0 | |
| ) | |
| print("Done.") | |
| if not args.no_delete: | |
| delete_on_exit(global_exp) | |
| # prompt util for colors | |
| prompt = Colors() | |
| # ------------------------------------- | |
| # ----- Base Test Scenario Opts ----- | |
| # ------------------------------------- | |
| print("Loading opts...", end="", flush=True) | |
| base_opts = climategan.utils.load_opts() | |
| base_opts.data.check_samples = False | |
| base_opts.train.fid.n_images = 5 | |
| base_opts.comet.display_size = 5 | |
| base_opts.tasks = ["m", "s", "d"] | |
| base_opts.domains = ["r", "s"] | |
| base_opts.data.loaders.num_workers = 4 | |
| base_opts.data.loaders.batch_size = 2 | |
| base_opts.data.max_samples = 9 | |
| base_opts.train.epochs = 1 | |
| if isinstance(base_opts.data.transforms[-1].new_size, int): | |
| base_opts.data.transforms[-1].new_size = 256 | |
| else: | |
| base_opts.data.transforms[-1].new_size.default = 256 | |
| print("Done.") | |
| # -------------------------------------- | |
| # ----- Configure Test Scenarios ----- | |
| # -------------------------------------- | |
| # override any nested key in opts | |
| # create scenario-specific variables with __key | |
| # ALWAYS specify a __doc key to describe your scenario | |
| test_scenarios = [ | |
| {"__use_comet": False, "__doc": "MSD no exp", "__verbose": 1}, # 0 | |
| {"__doc": "MSD with exp"}, # 1 | |
| { | |
| "__doc": "MSD no exp upsample_featuremaps", # 2 | |
| "__use_comet": False, | |
| "gen.d.upsample_featuremaps": True, | |
| "gen.s.upsample_featuremaps": True, | |
| }, | |
| {"tasks": ["p"], "domains": ["rf"], "__doc": "Painter"}, # 3 | |
| { | |
| "__doc": "M no exp low level feats", # 4 | |
| "__use_comet": False, | |
| "gen.m.use_low_level_feats": True, | |
| "gen.m.use_dada": False, | |
| "tasks": ["m"], | |
| }, | |
| { | |
| "__doc": "MSD no exp deeplabv2", # 5 | |
| "__use_comet": False, | |
| "gen.encoder.architecture": "deeplabv2", | |
| "gen.s.architecture": "deeplabv2", | |
| }, | |
| { | |
| "__doc": "MSDP no End-to-end", # 6 | |
| "domains": ["rf", "r", "s"], | |
| "tasks": ["m", "s", "d", "p"], | |
| }, | |
| { | |
| "__doc": "MSDP inference only no exp", # 7 | |
| "__inference": True, | |
| "__use_comet": False, | |
| "domains": ["rf", "r", "s"], | |
| "tasks": ["m", "s", "d", "p"], | |
| }, | |
| { | |
| "__doc": "MSDP with End-to-end", # 8 | |
| "__pl4m": True, | |
| "domains": ["rf", "r", "s"], | |
| "tasks": ["m", "s", "d", "p"], | |
| }, | |
| { | |
| "__doc": "Kitti pretrain", # 9 | |
| "train.epochs": 2, | |
| "train.kitti.pretrain": True, | |
| "train.kitti.epochs": 1, | |
| "domains": ["kitti", "r", "s"], | |
| "train.kitti.batch_size": 2, | |
| }, | |
| {"__doc": "Depth Dada archi", "gen.d.architecture": "dada"}, # 10 | |
| { | |
| "__doc": "Depth Base archi", | |
| "gen.d.architecture": "base", | |
| "gen.m.use_dada": False, | |
| "gen.s.use_dada": False, | |
| }, # 11 | |
| { | |
| "__doc": "Depth Base Classification", # 12 | |
| "gen.d.architecture": "base", | |
| "gen.d.classify.enable": True, | |
| "gen.m.use_dada": False, | |
| "gen.s.use_dada": False, | |
| }, | |
| { | |
| "__doc": "MSD Resnet V3+ backbone", | |
| "gen.deeplabv3.backbone": "resnet", | |
| }, # 13 | |
| { | |
| "__use_comet": False, | |
| "__doc": "MSD SPADE 12 (without x)", | |
| "__verbose": 1, | |
| "gen.m.use_spade": True, | |
| "gen.m.spade.cond_nc": 12, | |
| }, # 14 | |
| { | |
| "__use_comet": False, | |
| "__doc": "MSD SPADE 15 (with x)", | |
| "__verbose": 1, | |
| "gen.m.use_spade": True, | |
| "gen.m.spade.cond_nc": 15, | |
| }, # 15 | |
| { | |
| "__use_comet": False, | |
| "__doc": "Painter With Diff Augment", | |
| "__verbose": 1, | |
| "domains": ["rf"], | |
| "tasks": ["p"], | |
| "gen.p.diff_aug.use": True, | |
| }, # 15 | |
| { | |
| "__use_comet": False, | |
| "__doc": "MSD DADA_s", | |
| "__verbose": 1, | |
| "gen.s.use_dada": True, | |
| "gen.m.use_dada": False, | |
| }, # 16 | |
| { | |
| "__use_comet": False, | |
| "__doc": "MSD DADA_ms", | |
| "__verbose": 1, | |
| "gen.s.use_dada": True, | |
| "gen.m.use_dada": True, | |
| }, # 17 | |
| ] | |
| n_confs = len(test_scenarios) | |
| fails = [] | |
| successes = [] | |
| # -------------------------------- | |
| # ----- Run Test Scenarios ----- | |
| # -------------------------------- | |
| for test_idx, conf in enumerate(test_scenarios): | |
| if test_idx in exclude or (include and test_idx not in include): | |
| reason = ( | |
| "because it is in exclude" | |
| if test_idx in exclude | |
| else "because it is not in include" | |
| ) | |
| print("Ignoring test", test_idx, reason) | |
| continue | |
| # copy base scenario opts | |
| test_opts = deepcopy(base_opts) | |
| # update with scenario configuration | |
| set_conf(test_opts, conf) | |
| # print scenario description | |
| print_start( | |
| f"[{test_idx}/{n_confs - 1}] " | |
| + conf.get("__doc", "WARNING: no __doc for test scenario") | |
| ) | |
| print() | |
| comet = conf.get("__use_comet", True) | |
| pl4m = conf.get("__pl4m", False) | |
| inference = conf.get("__inference", False) | |
| verbose = conf.get("__verbose", 0) | |
| # set (or not) experiment | |
| test_exp = None | |
| if comet: | |
| test_exp = global_exp | |
| try: | |
| # create trainer | |
| trainer = Trainer( | |
| opts=test_opts, | |
| verbose=verbose, | |
| comet_exp=test_exp, | |
| ) | |
| trainer.functional_test_mode() | |
| # set (or not) painter loss for masker (= end-to-end) | |
| if pl4m: | |
| trainer.use_pl4m = True | |
| # test training procedure | |
| trainer.setup(inference=inference) | |
| if not inference: | |
| trainer.train() | |
| successes.append(test_idx) | |
| ok = True | |
| except Exception as e: | |
| print(e) | |
| print(traceback.format_exc()) | |
| fails.append(test_idx) | |
| ok = False | |
| finally: | |
| print_end(ok=ok) | |
| print_end(desc=" ----- Summary ----- ") | |
| if len(fails) == 0: | |
| print("•• All scenarios were successful") | |
| else: | |
| print(f"•• {len(successes)}/{len(test_scenarios)} successful tests") | |
| print(f"•• Failed test indices: {', '.join(map(str, fails))}") | |