How to visualize the attention map of my Segformer model?

Hello, I would like to visualize the attention map at every head. I found a notebook showing how to do this with the DINO model, but I would like to know whether any examples exist for Segformer. I have tried lots of code but none of it works for me, and there is no patch_size in the Segformer config, only patch_sizes, so I am not sure what to do.
Here's my SegformerFineTuner class:

    """Class of instance pytorch lightning
    to Train & fine tune the model
    """

    def __init__(
        self,
        id2label,
        pretrained_model_name,
        learning_rate,
        metrics_interval=100,
    ):
        super().__init__()

        self.id2label = id2label
        self.learning_rate = learning_rate
        self.metrics_interval = metrics_interval
        self.num_classes = len(id2label.keys())
        self.label2id = {v: k for k, v in id2label.items()}
        self.pretrained_model_name = pretrained_model_name

        self.train_mean_iou = evaluate.load("mean_iou")
        self.valid_mean_iou = evaluate.load("mean_iou")
        self.test_mean_iou = evaluate.load("mean_iou")

        self.save_hyperparameters()

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            self.pretrained_model_name,
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )

    def get_attention_map(self, images, masks=None):
        """Returning attention maps
        it's by doing the output_attentions and the return dict
        set to True and then fetch it from the outputs !
        """
        outputs = self.model.forward(
            pixel_values=images,
            labels=masks,
            output_attentions=True,
            interpolate_pos_encoding=True,
        )
        attention_maps = outputs.attentions

        return attention_maps

    def forward(self, images, masks=None):
        """Forward the model takes images and mask"""
        outputs = self.model(pixel_values=images, output_attentions=True)
        return outputs

    def training_step(self, batch, batch_idx):
        """Training step"""
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images=images, masks=masks)
        predictions = outputs[0]

        predictions = nn.functional.interpolate(
            predictions,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

        dloss = DiceLoss(mode="multiclass")
        loss = dloss(predictions, masks)

        predictions = predictions.argmax(dim=1)

        self.train_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )

        metrics = self.train_mean_iou.compute(
            num_labels=self.num_classes, ignore_index=255, reduce_labels=False
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "mean_iou", metrics["mean_iou"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "mean_accuracy",
            metrics["mean_accuracy"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch["pixel_values"], batch["labels"]

        outputs = self(images, masks)
        predictions = outputs[0]
        predictions = nn.functional.interpolate(
            predictions,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        dloss = DiceLoss(mode="multiclass")

        loss = dloss(predictions, masks)

        predictions = predictions.argmax(dim=1)

        self.valid_mean_iou.add_batch(
            predictions=predictions.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )

        self.log("valid_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_nb):
        images, masks = batch["pixel_values"], batch["labels"]

        outputs = self.model(images, masks)

        loss, logits = outputs[0], outputs[1]

        upsampled_logits = nn.functional.interpolate(
            logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )

        predicted = upsampled_logits.argmax(dim=1)

        self.test_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        self.log("test_loss", loss, on_step=True, prog_bar=True)

        return loss


    def configure_optimizers(self):
        opt = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.learning_rate,
            eps=1e-08,
            weight_decay=0.1,
            amsgrad=True,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=opt, mode="min", patience=3, verbose=True
        )

        return {"optimizer": opt, "lr_schedulter": scheduler, "monitor": "valid_loss"}

Any progress on this that you would like to share, please?


Wow, Meta AI had some ideas.

Your code seems mostly fine for visualizing attention maps with Segformer. However, there are a few potential issues and improvements that can be made:

  1. Patch Size: You’re right that Segformer’s config doesn’t have a patch_size parameter but rather patch_sizes. This is because Segformer uses a hierarchical architecture with a different patch size for each encoder stage. You can access them using self.model.config.patch_sizes (see the short snippet after this list).
  2. Attention Map Visualization: To visualize attention maps, you’ll need to modify the get_attention_map method to return the attention maps in a format that can be visualized. You can use libraries like Matplotlib or Seaborn to visualize the attention maps.
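
For reference, the per-stage settings can be read directly off the config. This is just a quick inspection snippet; the values in the comments are the mit-b0 defaults and may differ for your checkpoint:

    from transformers import SegformerConfig

    # Each of these attributes is a list with one entry per encoder stage.
    config = SegformerConfig.from_pretrained("nvidia/mit-b0")  # any Segformer checkpoint works here
    print(config.patch_sizes)          # overlapping patch-embedding kernel sizes, e.g. [7, 3, 3, 3]
    print(config.strides)              # patch-embedding strides, e.g. [4, 2, 2, 2]
    print(config.sr_ratios)            # spatial-reduction ratios of the efficient self-attention
    print(config.num_attention_heads)  # attention heads per stage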

Here’s an updated version of your code with some improvements:
class SegformerFineTuner(pl.LightningModule):
    def __init__(
        self,
        id2label,
        pretrained_model_name,
        learning_rate,
        metrics_interval=100,
    ):
        super().__init__()
        self.id2label = id2label
        self.learning_rate = learning_rate
        self.metrics_interval = metrics_interval
        self.num_classes = len(id2label.keys())
        self.label2id = {v: k for k, v in id2label.items()}
        self.pretrained_model_name = pretrained_model_name
        self.train_mean_iou = evaluate.load("mean_iou")
        self.valid_mean_iou = evaluate.load("mean_iou")
        self.test_mean_iou = evaluate.load("mean_iou")
        self.save_hyperparameters()
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            self.pretrained_model_name,
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )

def get_attention_map(self, images, masks=None):
    outputs = self.model.forward(
        pixel_values=images,
        labels=masks,
        output_attentions=True,
        interpolate_pos_encoding=True,
    )
    attention_maps = outputs.attentions
    return attention_maps

def forward(self, images, masks=None):
    outputs = self.model(pixel_values=images, output_attentions=True)
    return outputs

# ... (rest of your code remains the same)

def visualize_attention_maps(self, images, masks=None):
    attention_maps = self.get_attention_map(images, masks)
    # Use a library like Matplotlib or Seaborn to visualize the attention maps
    import matplotlib.pyplot as plt

    for i, attention_map in enumerate(attention_maps):
        plt.subplot(1, len(attention_maps), i + 1)
        plt.imshow(attention_map.detach().cpu().numpy())
        plt.title(f"Attention Map {i}")
    plt.show()

You can call the visualize_attention_maps method to visualize the attention maps for a given input image.

Assuming you have a batch of images and masks:

    images, masks = batch["pixel_values"], batch["labels"]
    model.visualize_attention_maps(images, masks)
This code should help you visualize the attention maps for Segformer. Let me know if you have any further questions or need more assistance!


This does not work.
The output attentions are a stack of 4D tensors which cannot be visualized directly unless they are processed first.
Please stop spamming AI-generated solutions that do not work.
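
For anyone who still wants a heatmap: each entry of outputs.attentions has shape (batch, num_heads, query_len, key_len), where key_len is the (possibly spatially reduced) key sequence of that encoder block, so you have to collapse the head and query dimensions and reshape the key dimension back to a spatial grid before plotting. Here is a minimal sketch of that processing, assuming a square input so the reduced key sequence forms a square grid:

    import math

    import matplotlib.pyplot as plt
    import torch

    @torch.no_grad()
    def plot_attention_block(attentions, block=0, image_index=0):
        """Collapse one Segformer attention tensor into a 2D heatmap.

        attentions: the tuple returned when output_attentions=True; each entry
        has shape (batch, num_heads, query_len, key_len).
        """
        attn = attentions[block][image_index]  # (num_heads, query_len, key_len)
        attn = attn.mean(dim=(0, 1))           # average over heads and queries -> (key_len,)
        side = int(math.sqrt(attn.shape[0]))   # assumes a square reduced key grid
        heatmap = attn[: side * side].reshape(side, side).cpu().numpy()
        plt.imshow(heatmap, cmap="viridis")
        plt.title(f"Mean attention, block {block}")
        plt.axis("off")
        plt.show()

    # Usage, assuming the model was run with output_attentions=True and
    # return_dict=True so that outputs.attentions is available:
    # plot_attention_block(outputs.attentions, block=0)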