# CASWiT: Context-Aware Stage-Wise Transformer for Ultra-High Resolution Semantic Segmentation
Official implementation of CASWiT, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages stage-wise cross-attention fusion between high-resolution and low-resolution branches.
## Table of Contents
- Overview
- Architecture
- Installation
- Docker
- Dataset Preparation
- Usage
- Configuration
- Results
- Citation
- License
## Overview
CASWiT addresses semantic segmentation on ultra-high resolution (UHR) imagery with a dual-resolution design:
- HR Branch: processes high-resolution crops (e.g., 512×512) for fine details
- LR Branch: processes a downsampled view of a larger context window for global information
- Stage-wise Cross-Attention Fusion: HR features attend to LR context at each encoder stage
CASWiT generalizes beyond aerial imagery without architectural changes, and we provide configs/scripts for several UHR datasets (FLAIR-HUB, URUR, SWISSIMAGE inference, and additional medical benchmarks).
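To make the dual-resolution inputs concrete, here is a minimal sketch of how an HR crop and its downsampled LR context can be built from a UHR image. The helper name and the exact crop/context sizes are illustrative assumptions, not the dataset code shipped in `dataset/`:

```python
import torch
import torch.nn.functional as F

def make_dual_inputs(uhr_image, crop_size=512, context_size=2048, lr_size=512):
    """Build an (HR crop, LR context) pair from a UHR tensor.

    uhr_image: (C, H, W) tensor covering the full ultra-high-resolution tile.
    The HR crop keeps native resolution; the LR context covers a much larger
    window around the same location, downsampled to a small tensor.
    (Hypothetical helper for illustration; sizes follow the README defaults.)
    """
    _, h, w = uhr_image.shape
    # Center of the HR crop (random in training, on a grid at inference).
    cy, cx = h // 2, w // 2
    half = crop_size // 2
    hr = uhr_image[:, cy - half:cy + half, cx - half:cx + half]

    # Larger context window around the same center, clamped to image bounds.
    chalf = context_size // 2
    y0, y1 = max(0, cy - chalf), min(h, cy + chalf)
    x0, x1 = max(0, cx - chalf), min(w, cx + chalf)
    context = uhr_image[:, y0:y1, x0:x1]

    # Downsample the context so both branches see small tensors.
    lr = F.interpolate(context.unsqueeze(0), size=(lr_size, lr_size),
                       mode="bilinear", align_corners=False).squeeze(0)
    return hr, lr
```

The LR branch thus sees a much wider footprint at the same tensor size, which is what the stage-wise cross-attention fusion later exploits.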
## Architecture
Key components:
- Dual Swin Transformer Backbones
- Cross-Attention Fusion Blocks at each encoder stage
- Auxiliary LR Supervision (optional, weighted by `model.lr_supervision_weight`)
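As a rough illustration of the fusion mechanism, the block below lets HR tokens attend to LR tokens with a residual connection. The real blocks in `model/` also include an MLP (`fusion_mlp_ratio`) and stochastic depth (`fusion_drop_path`), so treat this as a sketch of the idea rather than the shipped implementation:

```python
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    """Illustrative fusion block: HR tokens attend to LR context tokens.

    One such block sits after each encoder stage; the head count mirrors
    the config key `cross_attention_heads`. Not the exact implementation
    from model/ -- a sketch of the mechanism only.
    """
    def __init__(self, dim, num_heads=1):
        super().__init__()
        self.norm_hr = nn.LayerNorm(dim)
        self.norm_lr = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, hr_tokens, lr_tokens):
        # hr_tokens: (B, N_hr, C) from the HR branch at this stage
        # lr_tokens: (B, N_lr, C) from the LR branch at the same stage
        q = self.norm_hr(hr_tokens)
        kv = self.norm_lr(lr_tokens)
        ctx, _ = self.attn(q, kv, kv)   # HR queries, LR keys/values
        return hr_tokens + ctx          # residual: inject global context
```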
## Installation

### Requirements
- Python 3.12+
- PyTorch 2+
- CUDA 12+ (for GPU training)
### Setup
```bash
git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
pip install -r requirements.txt
```
## Docker
A Dockerfile is provided for a reproducible environment.
### Build

```bash
docker build -t caswit:latest .
```
### Run (GPU)

```bash
docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  caswit:latest
```
If your datasets/checkpoints are outside the repo, mount them too, e.g.:
```bash
docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  -v /path/to/checkpoints:/checkpoints \
  caswit:latest
```
## Dataset Preparation

### FLAIR-HUB

- Download the FLAIR-HUB dataset
- Merge the GeoTIFF tiles into mosaics:

```bash
python dataset/prepareFlairHub.py
```
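Under the hood, merging tiles into a mosaic amounts to something like the following (a sketch using `rasterio`; `dataset/prepareFlairHub.py` is the supported path and may differ in details):

```python
import glob
import rasterio
from rasterio.merge import merge

# Merge all GeoTIFF tiles in a folder into one mosaic (illustration only;
# dataset/prepareFlairHub.py is the supported way to do this).
tiles = [rasterio.open(p) for p in glob.glob("tiles/*.tif")]
mosaic, transform = merge(tiles)

meta = tiles[0].meta.copy()
meta.update(height=mosaic.shape[1], width=mosaic.shape[2], transform=transform)
with rasterio.open("mosaic.tif", "w", **meta) as dst:
    dst.write(mosaic)
```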
### URUR
Expected structure:
```text
URUR/
├── train/
│   ├── image/
│   └── label/
├── val/
│   ├── image/
│   └── label/
└── test/
    ├── image/
    └── label/
```
A re-hosted copy is available here: https://huggingface.co/datasets/heig-vd-geo/URUR
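As a quick sanity check of this layout, a small (hypothetical) helper can pair images with labels by filename:

```python
from pathlib import Path

def list_pairs(root, split):
    """Pair each image with its label by matching filenames.
    Assumes the URUR layout shown above; adjust if your copy differs."""
    img_dir = Path(root) / split / "image"
    lbl_dir = Path(root) / split / "label"
    pairs = []
    for img in sorted(img_dir.iterdir()):
        lbl = lbl_dir / img.name
        if lbl.exists():
            pairs.append((img, lbl))
    return pairs

print(len(list_pairs("URUR", "train")), "training pairs found")
```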
### SWISSIMAGE
Download images using the provided CSV:
```bash
python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv
```
## Usage

All commands below work either directly with the scripts in `train/` or via the unified `main.py`.
### Training
Single GPU:
```bash
python train/train.py configs/config_FlairHub.yaml
```
Multi-GPU (DDP):
```bash
torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml
```
### Evaluation

```bash
python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test
```
### Inference (single image)

```bash
python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png
```
### Inference on a VRT (DDP)
This runs tiled inference on a VRT using multiple GPUs:
```bash
torchrun --nproc_per_node=5 --master_port=29501 -m train.inference_vrt_ddp \
  --config configs/config_SWISSIMAGE_inf.yaml \
  --checkpoint weights/CASWiT-Base-SSL_FLAIRHUB_UN.pth \
  --vrt file.vrt \
  --out_dir output/ \
  --tile 1024 \
  --stride 512 \
  --lr_side 2048
```
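The `--tile`/`--stride` pair implies overlapping windows whose logits are accumulated and averaged. A single-GPU sketch of that sliding-window scheme (assuming `model` maps a tile to per-pixel logits of the same size; the real `train/inference_vrt_ddp.py` additionally reads tiles from a VRT and shards the work across GPUs):

```python
import torch

@torch.no_grad()
def sliding_window_logits(model, image, tile=1024, stride=512, num_classes=15):
    """Average overlapping tile predictions over a large image.

    image: (1, C, H, W) tensor; assumes H and W are multiples of stride.
    """
    _, _, h, w = image.shape
    logits = torch.zeros(1, num_classes, h, w, device=image.device)
    counts = torch.zeros(1, 1, h, w, device=image.device)
    for y in range(0, h - tile + 1, stride):
        for x in range(0, w - tile + 1, stride):
            patch = image[:, :, y:y + tile, x:x + tile]
            logits[:, :, y:y + tile, x:x + tile] += model(patch)
            counts[:, :, y:y + tile, x:x + tile] += 1
    return logits / counts.clamp(min=1)  # avoid division by zero at edges
```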
### Using `main.py`

```bash
# Train
python main.py train --config configs/config_FlairHub.yaml

# Eval (defaults to the test split if none is specified)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference (single image)
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png
```
## Configuration

Configs are YAML files in `configs/`.
### Model selection (decoder head)
You can select the decoder head directly in the config:
```yaml
model:
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet
```
Notes:
- `mask2former` may require `training.amp: false`, depending on your environment/precision settings.
- `ssl` is the SimMIM-style pretraining model and is not a drop-in replacement for the segmentation inference/eval scripts.
### Example config structure
```yaml
paths:
  data_path: "/path/to/dataset"
  dataset_name: "FLAIRHUB"
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "valid/img"
  val_msk_subdir: "valid/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"
  num_classes: 15
  cross_attention_heads: 1
  ignore_index: 255
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "CASWiT"
  entity: "your-entity"
  run_name: "caswit_experiment"

augmentations:
  enable: false
  p_hflip: 0.5
  p_vflip: 0.5
  p_rot90: 0.5
  color_jitter:
    brightness: 0.2
    contrast: 0.2
    saturation: 0.2
    hue: 0.05
  blur:
    p: 0.1
    kernel: 3
```
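If you want to script against these configs, a plain PyYAML round-trip is enough (a sketch; the training scripts may parse configs differently):

```python
import yaml

# Load an existing config and tweak it programmatically.
with open("configs/config_FlairHub.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["head"])           # e.g. "upernet"
print(cfg["training"]["batch_size"])  # e.g. 4
cfg["model"]["head"] = "segformer"    # swap the decoder head

# Write the variant out for a new run.
with open("configs/my_experiment.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```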
## Reproducing experiments (all provided configs)

Below are ready-to-run commands for each config. Replace `--checkpoint` with your own file when evaluating or running inference.
### FLAIR-HUB

```bash
# Train (no extra aug)
python main.py train --config configs/config_FlairHub.yaml

# Train (with augmentations)
python main.py train --config configs/config_FlairHub_aug.yaml

# Eval
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth
```
### URUR

```bash
python main.py train --config configs/config_URUR.yaml
python main.py train --config configs/config_URUR_aug.yaml
python main.py eval --config configs/config_URUR.yaml --checkpoint weights/checkpoint.pth
```
### ISIC (medical)

```bash
python main.py train --config configs/config_ISIC_aug.yaml
python main.py eval --config configs/config_ISIC_aug.yaml --checkpoint weights/checkpoint.pth
```
### CRAG (medical)

```bash
python main.py train --config configs/config_CRAG_aug.yaml
python main.py eval --config configs/config_CRAG_aug.yaml --checkpoint weights/checkpoint.pth
```
## Results

### FLAIR-HUB (RGB-only UHR protocol)
| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
|---|---|---|---|---|---|
| **RGB baselines (official)** | | | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | – | – | – |
| Swin-S + UPerNet | 61.87 | 75.11 | – | – | – |
| Swin-B + UPerNet | 64.05 | 76.88 | – | – | – |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | – | – |
| Swin-L + UPerNet | 63.36 | 76.35 | – | 420 | 27.8 |
| **Dual-branch baselines** | | | | | |
| Dual Swin-Base (late fusion, no CA) | 64.25 | – | – | 398 | 19.4 |
| CASWiT-Base + UPerNet | 65.11 | 77.71 | 35.87 | 489 | 15.4 |
| CASWiT-Base-SSL + UPerNet | 65.35 | 77.87 | 35.99 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + UPerNet | 65.83 | 78.22 | 36.90 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + SegFormer | 66.37 | 78.58 | 36.51 | 298 | 17.9 |
CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug + SegFormer further pushes performance to **66.37 mIoU** and **78.58 mF1**.
On mean boundary IoU, CASWiT-Base-SSL-aug + SegFormer reaches **36.51 mBIoU**, a gain of 3.94 points over the retrained Swin-B baseline (32.57).
### ISIC / CRAG (test sets)
| Method | ISIC | CRAG |
|---|---|---|
| GPWFormer | 80.7 | 89.9 |
| Boosting Dual-Branch | 83.4 | 90.3 |
| CASWiT-Base-SSL-aug + UPerNet (ours) | 85.4 | 90.3 |
| CASWiT-Base-SSL-aug + SegFormer (ours) | 86.5 | 90.7 |
### URUR
We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing to both generic and UHR-specific segmentation models.
| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| **Generic models** | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| **UHR models** | | |
| GLNet | 41.2 | 3063 |
| FCtL | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-stream | 48.2 | 3682 |
| CASWiT-Base | 48.7 | 3530 |
| CASWiT-Base-SSL | 49.1 | 3530 |
On URUR, CASWiT-Base already surpasses prior UHR-specific methods, and CASWiT-Base-SSL achieves **49.1 mIoU**, i.e. +2.2 mIoU over WSDNet and +0.9 mIoU over Boosting Dual-stream, while remaining competitive in memory usage.
## Self-Supervised Learning

CASWiT also supports self-supervised pretraining via SimMIM-style SSL (Simple Masked Image Modeling). We used this setup to pretrain CASWiT on the entire SWISSIMAGE dataset.
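For intuition, here is a condensed sketch of a SimMIM-style objective: random patches are masked, the network regresses the hidden pixels, and an L1 loss is taken on masked locations only. This is illustrative and not the code in `model/CASWiT_ssl.py`; `encoder_decoder` is a placeholder for any model that maps a masked image back to full-resolution pixels.

```python
import torch
import torch.nn.functional as F

def simmim_loss(encoder_decoder, images, patch=32, mask_ratio=0.6):
    """SimMIM-style objective (sketch): mask random patches, reconstruct
    pixels, and take the L1 loss on masked locations only.
    Assumes H and W are divisible by `patch`."""
    b, c, h, w = images.shape
    # One boolean per patch, upsampled to pixel resolution.
    pm = torch.rand(b, 1, h // patch, w // patch, device=images.device) < mask_ratio
    mask = pm.float().repeat_interleave(patch, 2).repeat_interleave(patch, 3)

    masked = images * (1.0 - mask)          # zero out masked patches
    recon = encoder_decoder(masked, mask)   # predict the full image
    # L1 only where pixels were hidden from the encoder.
    per_pixel = F.l1_loss(recon, images, reduction="none") * mask
    return per_pixel.sum() / (mask.sum() * c + 1e-8)
```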
## Project Structure
```text
CASWiT/
├── model/
│   ├── CASWiT_upernet.py
│   ├── CASWiT_segformer.py
│   ├── CASWiT_m2f.py
│   ├── CASWiT_fusion_last_stage_add.py
│   └── CASWiT_ssl.py
├── dataset/
│   ├── definition_dataset.py
│   ├── fusion_augment.py
│   ├── download_swissimage.py
│   └── prepareFlairHub.py
├── configs/
│   ├── config_FlairHub.yaml
│   ├── config_FlairHub_aug.yaml
│   ├── config_URUR.yaml
│   ├── config_URUR_aug.yaml
│   ├── config_SWISSIMAGE.yaml
│   ├── config_SWISSIMAGE_inf.yaml
│   ├── config_ISIC_aug.yaml
│   └── config_CRAG_aug.yaml
├── utils/
│   ├── metrics.py
│   └── attention_viz.py
├── train/
│   ├── train.py
│   ├── eval.py
│   ├── inference.py
│   └── inference_vrt_ddp.py
├── weights/
├── Dockerfile
├── main.py
├── requirements.txt
└── README.md
```
## Citation
```bibtex
@misc{caswit,
  title={Context-Aware Semantic Segmentation via Stage-Wise Attention},
  author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
  year={2026},
  eprint={2601.11310},
  url={https://arxiv.org/abs/2601.11310},
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Acknowledgments
- UPerNet for the segmentation decoder
- Swin Transformer for the backbone architecture
- FLAIR-HUB for the dataset
- URUR for the dataset
- CRAG for the dataset
- ISIC for the dataset
