
CASWiT: Context-Aware Stage Wise Transformer for Ultra-High Resolution Semantic Segmentation

License: MIT · SOTA: FLAIR-HUB (RGB) · SOTA: URUR · SOTA: ISIC · SOTA: CRAG

Official implementation of CASWiT, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages stage-wise cross-attention fusion between high-resolution and low-resolution branches.

📋 Table of Contents

  • Overview
  • Architecture
  • Installation
  • Docker
  • Dataset Preparation
  • Usage
  • Configuration
  • Reproducing experiments
  • Results
  • Self-Supervised Learning
  • Project Structure
  • Citation
  • License
  • Acknowledgments

🎯 Overview

CASWiT addresses semantic segmentation on ultra-high resolution (UHR) imagery with a dual-resolution design:

  • HR Branch: processes high-resolution crops (e.g., 512×512) for fine details
  • LR Branch: processes low-resolution context (typically downsampled) for global information
  • Stage-wise Cross-Attention Fusion: HR features attend to LR context at each encoder stage

CASWiT generalizes beyond aerial imagery without architectural changes, and we provide configs/scripts for several UHR datasets (FLAIR-HUB, URUR, SWISSIMAGE inference, and additional medical benchmarks).
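The pairing of the two inputs is easy to picture in code. Below is a minimal sketch of how an HR crop and its downsampled context window could be cut from a UHR image; the crop size matches the example above, while `context` (the extent of the surrounding window) and the function itself are illustrative assumptions, not the repo's actual data pipeline:

```python
import torch
import torch.nn.functional as F

def make_dual_inputs(uhr_image: torch.Tensor, y: int, x: int,
                     crop: int = 512, context: int = 2048):
    """Cut one HR crop and its surrounding LR context from a (C, H, W) image."""
    hr = uhr_image[:, y:y + crop, x:x + crop]              # fine-detail branch input
    cy, cx = y + crop // 2, x + crop // 2                  # centre of the crop
    half = context // 2
    ctx = uhr_image[:, max(cy - half, 0):cy + half,
                       max(cx - half, 0):cx + half]        # wide surrounding window
    lr = F.interpolate(ctx.unsqueeze(0), size=(crop, crop),
                       mode="bilinear", align_corners=False)[0]  # downsampled context
    return hr, lr
```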

πŸ—οΈ Architecture

*(Figure: CASWiT architecture)*

Key components:

  • Dual Swin Transformer Backbones
  • Cross-Attention Fusion Blocks at each encoder stage (a minimal sketch follows this list)
  • Auxiliary LR Supervision (optional, weighted by model.lr_supervision_weight)
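To make the fusion mechanism concrete, here is a minimal cross-attention fusion block in the spirit of the description above. It is a hypothetical module, not the repository's implementation; `heads` and `mlp_ratio` mirror the `cross_attention_heads` and `fusion_mlp_ratio` config keys:

```python
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    """HR tokens (queries) attend to LR tokens (keys/values) at one stage."""

    def __init__(self, dim: int, heads: int = 1, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm_hr = nn.LayerNorm(dim)
        self.norm_lr = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, hr_tokens, lr_tokens):
        # hr_tokens: (B, N_hr, C), lr_tokens: (B, N_lr, C)
        ctx, _ = self.attn(self.norm_hr(hr_tokens),
                           self.norm_lr(lr_tokens),
                           self.norm_lr(lr_tokens))   # HR queries, LR keys/values
        hr_tokens = hr_tokens + ctx                   # residual context injection
        return hr_tokens + self.mlp(hr_tokens)        # transformer-style MLP
```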

📦 Installation

Requirements

  • Python 3.12+
  • PyTorch 2+
  • CUDA 12+ (for GPU training)

Setup

git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
pip install -r requirements.txt

🐳 Docker

A Dockerfile is provided for a reproducible environment.

Build

docker build -t caswit:latest .

Run (GPU)

docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  caswit:latest

If your datasets/checkpoints are outside the repo, mount them too, e.g.:

docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  -v /path/to/checkpoints:/checkpoints \
  caswit:latest

📊 Dataset Preparation

FLAIR-HUB

  1. Download the FLAIR-HUB dataset
  2. Merge GeoTIFF tiles into mosaics:
python dataset/prepareFlairHub.py

URUR

Expected structure:

URUR/
├── train/
│   ├── image/
│   └── label/
├── val/
│   ├── image/
│   └── label/
└── test/
    ├── image/
    └── label/

A re-hosted copy is available here: https://huggingface.co/datasets/heig-vd-geo/URUR
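Before training, a quick sanity check of the layout can save a failed run. This is a convenience snippet, not part of the repository:

```python
from pathlib import Path

root = Path("URUR")  # adjust to your paths.data_path
for split in ("train", "val", "test"):
    for sub in ("image", "label"):
        d = root / split / sub
        assert d.is_dir(), f"missing directory: {d}"
        print(f"{d}: {sum(1 for _ in d.iterdir())} files")
```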

SWISSIMAGE

Download images using the provided CSV:

python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv

🚀 Usage

All commands below work either directly with the scripts in train/ or via the unified main.py.

Training

Single GPU:

python train/train.py configs/config_FlairHub.yaml

Multi-GPU (DDP):

torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml
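torchrun spawns one process per GPU and sets the rendezvous environment variables. The per-process setup inside train.py presumably follows the standard PyTorch DDP pattern, sketched here (build_model is a hypothetical placeholder):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")                  # torchrun provides rank/world size
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)           # hypothetical model factory
model = DDP(model, device_ids=[local_rank])
# ...wrap the dataset in a DistributedSampler and train as usual...
```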

Evaluation

python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test

Inference (single image)

python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png

Inference on a VRT (DDP)

This runs tiled inference on a VRT using multiple GPUs:

torchrun --nproc_per_node=5 --master_port=29501 -m train.inference_vrt_ddp \
  --config configs/config_SWISSIMAGE_inf.yaml \
  --checkpoint weights/CASWiT-Base-SSL_FLAIRHUB_UN.pth \
  --vrt file.vrt \
  --out_dir output/ \
  --tile 1024 \
  --stride 512 \
  --lr_side 2048
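Conceptually, the script slides a tile-sized window with step --stride over the mosaic and averages logits where tiles overlap, while --lr_side sets the extent of the LR context read around each tile. A simplified single-tensor sketch of the sliding-window part (the real script additionally feeds the LR context branch, handles ragged edges, and shards tiles across DDP ranks):

```python
import torch

@torch.no_grad()
def tiled_inference(model, image, num_classes, tile=1024, stride=512):
    """Sliding-window inference with overlap averaging; image is (C, H, W).
    Assumes H and W are compatible with tile/stride for brevity."""
    _, H, W = image.shape
    logits = torch.zeros(num_classes, H, W)
    counts = torch.zeros(1, H, W)
    for y in range(0, H - tile + 1, stride):
        for x in range(0, W - tile + 1, stride):
            patch = image[:, y:y + tile, x:x + tile].unsqueeze(0)
            out = model(patch)[0]                      # (num_classes, tile, tile)
            logits[:, y:y + tile, x:x + tile] += out
            counts[:, y:y + tile, x:x + tile] += 1
    return (logits / counts).argmax(0)                 # (H, W) class map
```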

Using main.py

# Train
python main.py train --config configs/config_FlairHub.yaml

# Eval (the split defaults to test if not specified)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference (single image)
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png

βš™οΈ Configuration

Configs are YAML files in configs/.

Model selection (decoder head)

You can select the decoder head directly in the config:

model:
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

Notes:

  • mask2former may require training.amp: false depending on your environment/precision settings.
  • ssl is the SimMIM-style pretraining model and is not a drop-in replacement for the segmentation inference/eval scripts.
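Under the hood this is presumably a plain dispatch on the head key. A hedged sketch of the pattern, mapping each option to the matching file under model/ (the repo's actual factory may differ):

```python
import yaml

# Assumed mapping from model.head to the modules listed under Project Structure.
HEADS = {
    "upernet": "model.CASWiT_upernet",
    "segformer": "model.CASWiT_segformer",
    "mask2former": "model.CASWiT_m2f",
    "fusion_last_stage_add": "model.CASWiT_fusion_last_stage_add",
    "ssl": "model.CASWiT_ssl",
}

with open("configs/config_FlairHub.yaml") as f:
    cfg = yaml.safe_load(f)
print(f"head '{cfg['model']['head']}' -> {HEADS[cfg['model']['head']]}")
```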

Example config structure

paths:
  data_path: "/path/to/dataset"
  dataset_name: "FLAIRHUB"
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "valid/img"
  val_msk_subdir: "valid/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"
  num_classes: 15
  cross_attention_heads: 1
  ignore_index: 255
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "CASWiT"
  entity: "your-entity"
  run_name: "caswit_experiment"

augmentations:
  enable: false
  p_hflip: 0.5
  p_vflip: 0.5
  p_rot90: 0.5
  color_jitter:
    brightness: 0.2
    contrast: 0.2
    saturation: 0.2
    hue: 0.05
  blur:
    p: 0.1
    kernel: 3
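The training block maps onto a standard AdamW + cosine-annealing recipe with optional AMP, and lr_supervision_weight weights the auxiliary LR-branch loss. A sketch of how those keys are plausibly consumed (the dual-output model call and the dataloader are assumptions, not the repo's exact loop):

```python
import torch

opt = torch.optim.AdamW(model.parameters(), lr=cfg["training"]["learning_rate"])
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=cfg["training"]["num_epochs"], eta_min=cfg["training"]["eta_min"])
scaler = torch.amp.GradScaler("cuda", enabled=cfg["training"]["amp"])
ce = torch.nn.CrossEntropyLoss(ignore_index=cfg["model"]["ignore_index"])

for hr, lr_ctx, mask in loader:                  # hypothetical dataloader
    opt.zero_grad()
    with torch.autocast("cuda", enabled=cfg["training"]["amp"]):
        hr_logits, lr_logits = model(hr, lr_ctx) # assumed dual outputs
        loss = ce(hr_logits, mask) \
             + cfg["model"]["lr_supervision_weight"] * ce(lr_logits, mask)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
sched.step()                                     # once per epoch
```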

🧪 Reproducing experiments (all provided configs)

Below are ready-to-run commands for each config. Replace --checkpoint with your own file when evaluating or running inference.

FLAIR-HUB

# Train (no extra aug)
python main.py train --config configs/config_FlairHub.yaml

# Train (with augmentations)
python main.py train --config configs/config_FlairHub_aug.yaml

# Eval
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

URUR

python main.py train --config configs/config_URUR.yaml
python main.py train --config configs/config_URUR_aug.yaml

python main.py eval --config configs/config_URUR.yaml --checkpoint weights/checkpoint.pth

ISIC (medical)

python main.py train --config configs/config_ISIC_aug.yaml
python main.py eval  --config configs/config_ISIC_aug.yaml --checkpoint weights/checkpoint.pth

CRAG (medical)

python main.py train --config configs/config_CRAG_aug.yaml
python main.py eval  --config configs/config_CRAG_aug.yaml --checkpoint weights/checkpoint.pth

📈 Results

FLAIR-HUB (RGB-only UHR protocol)

| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
|---|---|---|---|---|---|
| **RGB Baselines (official)** | | | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | – | – | – |
| Swin-S + UPerNet | 61.87 | 75.11 | – | – | – |
| Swin-B + UPerNet | 64.05 | 76.88 | – | – | – |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | – | – |
| Swin-L + UPerNet | 63.36 | 76.35 | – | 420 | 27.8 |
| **Dual-branch baselines** | | | | | |
| Dual Swin-Base (late fusion, no CA) | 64.25 | – | – | 398 | 19.4 |
| CASWiT-Base + UPerNet | 65.11 | 77.71 | 35.87 | 489 | 15.4 |
| CASWiT-Base-SSL + UPerNet | 65.35 | 77.87 | 35.99 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + UPerNet | 65.83 | 78.22 | 36.90 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + SegFormer | 66.37 | 78.58 | 36.51 | 298 | 17.9 |

CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug + SegFormer pushes performance further to **66.37 mIoU** and **78.58 mF1**.
On mean boundary IoU, CASWiT-Base-SSL-aug + SegFormer reaches **36.51 mBIoU**, a gain of 3.94 points over the retrained Swin-B baseline (32.57).

ISIC / CRAG (test sets)

| Method | ISIC | CRAG |
|---|---|---|
| GPWFormer | 80.7 | 89.9 |
| Boosting Dual-Branch | 83.4 | 90.3 |
| CASWiT-Base-SSL-aug + UPerNet (ours) | 85.4 | 90.3 |
| CASWiT-Base-SSL-aug + SegFormer (ours) | 86.5 | 90.7 |

URUR

We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing to both generic and UHR-specific segmentation models.

| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| **Generic Models** | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| **UHR Models** | | |
| GLNet | 41.2 | 3063 |
| FCtL | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-stream | 48.2 | 3682 |
| CASWiT-Base | 48.7 | 3530 |
| CASWiT-Base-SSL | 49.1 | 3530 |

On URUR, CASWiT-Base already edges past prior UHR-specific methods, and CASWiT-Base-SSL reaches 49.1 mIoU, i.e. +2.2 mIoU over WSDNet and +0.9 mIoU over Boosting Dual-stream, while remaining competitive in memory usage.

🔬 Self-Supervised Learning

CASWiT also supports self-supervised pretraining with SimMIM-style SSL (Simple Masked Image Modeling). We pretrained CASWiT with this setup on the full SWISSIMAGE dataset.
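For orientation, the SimMIM objective masks random patches of the input and regresses the missing pixels with an L1 loss computed on the masked area only. A condensed, input-space variant of that loss (the canonical recipe masks in embedding space with a learned mask token; model here is assumed to return a full-resolution reconstruction):

```python
import torch

def simmim_loss(model, images, patch=32, mask_ratio=0.6):
    """Mask random patches, reconstruct, penalize only the masked pixels (L1)."""
    B, C, H, W = images.shape
    keep = torch.rand(B, 1, H // patch, W // patch, device=images.device) > mask_ratio
    mask = keep.repeat_interleave(patch, 2).repeat_interleave(patch, 3).float()
    recon = model(images * mask)                 # assumed (B, C, H, W) reconstruction
    masked = 1.0 - mask                          # 1 where pixels were hidden
    return (torch.abs(recon - images) * masked).sum() / (masked.sum() * C + 1e-8)
```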

πŸ› οΈ Project Structure

CASWiT/
├── model/
│   ├── CASWiT_upernet.py
│   ├── CASWiT_segformer.py
│   ├── CASWiT_m2f.py
│   ├── CASWiT_fusion_last_stage_add.py
│   └── CASWiT_ssl.py
├── dataset/
│   ├── definition_dataset.py
│   ├── fusion_augment.py
│   ├── download_swissimage.py
│   └── prepareFlairHub.py
├── configs/
│   ├── config_FlairHub.yaml
│   ├── config_FlairHub_aug.yaml
│   ├── config_URUR.yaml
│   ├── config_URUR_aug.yaml
│   ├── config_SWISSIMAGE.yaml
│   ├── config_SWISSIMAGE_inf.yaml
│   ├── config_ISIC_aug.yaml
│   └── config_CRAG_aug.yaml
├── utils/
│   ├── metrics.py
│   └── attention_viz.py
├── train/
│   ├── train.py
│   ├── eval.py
│   ├── inference.py
│   └── inference_vrt_ddp.py
├── weights/
├── Dockerfile
├── main.py
├── requirements.txt
└── README.md

πŸ“ Citation

@misc{caswit,
      title={Context-Aware Semantic Segmentation via Stage-Wise Attention}, 
      author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
      year={2026},
      eprint={2601.11310},
      url={https://arxiv.org/abs/2601.11310}, 
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments

