Diffusers

Diffusers is a code framework that wraps diffusion models; its largest unit is the pipeline. All you need is a pipeline and a weights file (or folder), and you can generate images in a few lines of code. An example:

from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5", torch_dtype=torch.float16, variant='fp16')
pipe.to("cuda")
pipe("A dog").images[0]

Nowadays some papers (for example layout-guidance, tflcg for short) build on the diffusers library, combine it with their own algorithm, and ship a pipeline for everyone to use. Here is the example code from its repo:

from tflcg.layout_guidance_pipeline import LayoutGuidanceStableDiffusionPipeline
from diffusers import EulerDiscreteScheduler
import transformers
import torch
transformers.utils.move_cache()
pipe = LayoutGuidanceStableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5", torch_dtype=torch.float16, variant='fp16')
pipe = pipe.to("mps")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
prompt = "A cat playing with a ball"
bboxes = [[0.55, 0.4, 0.95, 0.8]]
image = pipe(prompt, num_inference_steps=20,
             token_indices=[[2]],
             bboxes=bboxes).images[0]
image = pipe.draw_box(image, bboxes)
image.save("output.png")

To be able to write this kind of code in future work, I figured it was worth digging into the relevant details of the diffusers library. I mainly followed an excellent article on Zhihu and filled in some details of my own.

Pipeline

A pipeline contains:

  • three large modules: the VAE, the UNet, and the Text Encoder;

  • four small modules: the Tokenizer, the Scheduler, the Safety checker, and the Feature extractor.

The last two of the small modules are there to check whether the output is, shall we say, wholesome (you know what I mean); together they are also called the NSFW (Not Safe For Work) checker. We load a pipeline with the from_pretrained function, which builds a pipeline class according to the model_index.json in the folder. For example, the model_index.json of SD 1.5 looks like this:

{
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.6.0",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

_class_name tells you which pipeline class to instantiate; for SD it is StableDiffusionPipeline. Each remaining entry says which class to load from which library for that component of the pipeline, e.g. the vae component maps to the AutoencoderKL class from the diffusers library.
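
A quick way to see this mapping in action (a minimal sketch; it assumes the SD 1.5 weights live in a local folder called stable-diffusion-v1-5, as in the snippet above):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5")
print(type(pipe).__name__)            # StableDiffusionPipeline, picked from _class_name
print(type(pipe.vae).__name__)        # AutoencoderKL
print(type(pipe.scheduler).__name__)  # PNDMScheduler
print(list(pipe.components.keys()))   # the modules listed in model_index.json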

StableDiffusionPipeline

Let's look at StableDiffusionPipeline's __init__:

class StableDiffusionPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
): # inherits a bunch of mixins on top of the base DiffusionPipeline
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
        # (omitted) warn if certain pieces of the config are missing
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # the VAE's spatial downsampling factor (8 for SD)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

A variable followed by a colon like this is a type annotation. It is purely documentation and does not affect how the code runs (that is, Python will not raise an error if the argument passed at call time does not match the annotated type).
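
A one-liner to convince yourself (a hypothetical function, just to show that annotations are not enforced):

def square(x: int) -> int:
    return x * x

print(square(2.5))  # 6.25 -- runs fine even though 2.5 is not an int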

Then a pile of hasattr checks verify that the config file is not missing anything, and emit warnings if it is.

Finally, register_modules is a built-in function of DiffusionPipeline:

    def register_modules(self, **kwargs):
        for name, module in kwargs.items():
            # retrieve library
            if module is None or isinstance(module, (tuple, list)) and module[0] is None:
                register_dict = {name: (None, None)}
            else:
                library, class_name = _fetch_class_library_tuple(module)
                register_dict = {name: (library, class_name)}
            # save model index config
            self.register_to_config(**register_dict)
            # set models
            setattr(self, name, module)

Skipping the boring parts, next comes __call__, i.e. the part that actually generates images.

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):

There are a lot of parameters here. If you have read the paper, most of them are self-explanatory (except the negative prompt part, which I have not read about, and the callback ones, which I will ignore for now).

Then it checks the input arguments; clicking into that function is quite entertaining, it is basically a pile of raise statements.

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

This snippet sets the batch size to the number of prompts, so they are processed in parallel.
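
Concretely (a small sketch, reusing the pipe from the first example):

images = pipe(["A dog", "A cat"]).images  # batch_size = 2, both prompts go through the UNet together
print(len(images))                        # 2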

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

This encodes the prompt.
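
Roughly, encode_prompt tokenizes the prompt and runs the CLIP text encoder, and under classifier-free guidance it also encodes the negative prompt (an empty string by default) so that both can later go through the UNet in one batch. A simplified sketch of what it boils down to (ignoring padding/truncation details, LoRA scaling and clip_skip):

prompt = "A dog"
text_inputs = pipe.tokenizer(prompt, padding="max_length",
                             max_length=pipe.tokenizer.model_max_length,
                             return_tensors="pt")
prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to(pipe.device))[0]
print(prompt_embeds.shape)  # torch.Size([1, 77, 768]) for SD 1.5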

    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

This sets up the timesteps. There are many scheduler choices here; for example, in the tflcg example code:

pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

you can see that it uses EulerDiscreteScheduler, whatever that is. SD 1.5 defaults to PNDMScheduler. I still do not know what that is either.

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

This determines the initial noise and returns a tensor of shape [batch_size, 4, 64, 64] (for 512×512 images). It is wrapped quite nicely; generator can even be a list, one per sample.
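
Stripped of the edge cases, prepare_latents boils down to something like this (a sketch; the real function also accepts a user-supplied latents tensor and a list of generators, and the randn_tensor import path is where it lives in recent diffusers versions):

from diffusers.utils.torch_utils import randn_tensor

batch_size, height, width = 1, 512, 512
shape = (batch_size, pipe.unet.config.in_channels,  # 4 latent channels
         height // pipe.vae_scale_factor,           # 512 // 8 = 64
         width // pipe.vae_scale_factor)
latents = randn_tensor(shape, device=pipe.device, dtype=pipe.unet.dtype)
latents = latents * pipe.scheduler.init_noise_sigma  # scale to the scheduler's initial noise level
print(latents.shape)                                  # torch.Size([1, 4, 64, 64])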

The whole loop that follows is the classic DDPM/DDIM sampling loop. Since I have not read up on classifier-free guidance, I did not go through it carefully and skipped ahead.

        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
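
        # classifier-free guidance (skipped in the walkthrough above): the batch holds the
        # unconditional and text-conditioned halves, which are split and recombined here --
        # roughly what the real loop does at this point
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)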

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

function: scale_model_input

What does scheduler.scale_model_input do? I dug through the source; scheduling_euler_discrete.py, for example, looks like this:

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)

        self.is_scale_input_called = True
        return sample

My guess was that it preprocesses the input differently depending on the scheduler, and that seems to be exactly right, because I looked at scheduling_ddim.py and found something pretty funny:

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

The processing is: no processing at all.

All of these update rules are nicely encapsulated inside the scheduler. Which is honestly very cool.

LayoutGuidanceStableDiffusionPipeline

You will find that its __call__ is basically SD's __call__ with some modifications; it just additionally takes token_indices and bboxes. Let's look only at the parts that differ.

First of all, it inherits from StableDiffusionAttendAndExcitePipeline, which means it is actually built on top of AAE (Attend-and-Excite), even though the paper never mentions this.

Worth noting: I am reading the latest diffusers, while tflcg is built against 0.15.0, so some details, and possibly the supported features, differ.

Also worth noting: AAE has since been removed from diffusers and become a relic of its era (rough, it only lasted about a year).

Here is the key backward-guidance part:

     with torch.enable_grad():
         latents = latents.clone().detach().requires_grad_(True) # clone and detach first, so the original computation graph is untouched
         for guidance_iter in range(max_guidance_iter_per_step): # number of backward passes per step; the paper says 5 by default
             if loss.item() / scale_factor < 0.2:
                 break
             latent_model_input = self.scheduler.scale_model_input(
                 latents, t
             )
             self.unet(                                          # one forward pass through the UNet
                 latent_model_input,
                 t,
                 encoder_hidden_states=cond_prompt_embeds,
                 cross_attention_kwargs=cross_attention_kwargs,
             )
             self.unet.zero_grad()
             loss = (                                            # compute the loss described in the paper
                 self._compute_loss(token_indices, bboxes, device)
                 * scale_factor
             )
             grad_cond = torch.autograd.grad(                    # gradient of the loss w.r.t. the latents
                 loss.requires_grad_(True),
                 [latents],
                 retain_graph=True,
             )[0]
             latents = (
                 latents - grad_cond * self.scheduler.sigmas[i] ** 2
             )
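
In equation form, each guidance iteration nudges the latent as follows (my reading of the last lines above; $\sigma_t$ is the scheduler's noise scale at the current step):

$$z_t \leftarrow z_t - \sigma_t^{2}\,\nabla_{z_t}\mathcal{L}_{\text{layout}}$$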

That wraps up the main routine. But how does the loss function get hold of the attention maps in the first place?

Attention Map

        self.attention_store = AttentionStore() # a place to store the attention maps
        self.register_attention_control()

The main routine first creates an attention_store with these two lines and then "registers" it with a function. That function lives in the AAE pipeline, and after it runs, the attention maps can be read out directly while the UNet is working. How is that done? Let's take a look.

This relies on two Python behaviours: if a list is passed as a function argument, mutating it inside the function mutates the caller's list (much like passing a pointer in C); and an assignment like a = b for a list only copies the object reference, so changes made through a are visible through b. With these two facts, the "live updating" here is easy to understand.
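
A tiny illustration of both behaviours:

def append_one(xs):  # xs is a reference to the caller's list
    xs.append(1)

a = []
b = a                # b and a now refer to the same list object
append_one(a)
print(b)             # [1] -- the mutation is visible through every reference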

register_attention_control is a method of the AAE pipeline, inherited by the layout pipeline.

    def register_attention_control(self):
        attn_procs = {}
        cross_att_count = 0
        for name in self.unet.attn_processors.keys():
            if name.startswith("mid_block"):
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                place_in_unet = "down"
            else:
                continue

            cross_att_count += 1
            attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)

        self.unet.set_attn_processor(attn_procs)
        self.attention_store.num_att_layers = cross_att_count

It accesses unet.attn_processors, so let's look at that first:

function: attn_processors

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

The @property decorator lets a method be accessed like an attribute, which is why the caller above does not use parentheses.
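
A tiny illustration (hypothetical class):

class Circle:
    def __init__(self, r):
        self.r = r

    @property
    def area(self):
        return 3.14159 * self.r ** 2

print(Circle(2).area)  # accessed like an attribute, no parentheses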

named_children is a method of torch.nn.Module; it yields each direct submodule together with its name as a (name, module) pair. Here is the example GPT gave me:

import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
model = SimpleModel()
for name, child in model.named_children():
    print(name, child)
'''
Output:
fc1 Linear(in_features=10, out_features=5, bias=True)
fc2 Linear(in_features=5, out_features=1, bias=True)
'''

Note that a module here is any instance built on nn.Module, so if you define custom modules, named_children will return those too. With that, the code above makes sense: the UNet contains many submodules, which themselves contain submodules (a downsampling block contains resnets, a transformer2d block, and so on), so the function walks them all recursively. Whenever a module has a set_processor attribute, it is an attention module and gets added to a dictionary of things to read from. So now we only need to find which submodules have set_processor to know where the attention lives.

But I could not find it for the life of me. So I fired up some remote compute and just printed the thing out.

'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor object at 0x7f0b03228790>

They live inside a down_blocks list, which contains an attentions list, which contains a transformer_blocks list, and those blocks have attn built in:

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

So what this function does is dig every attention-carrying part out of the UNet and hand them to you in a dictionary.

Attention and its processor

Let's track down where Attention lives. It is inside the BasicTransformerBlock class, which instantiates Attention. This Attention comes from the library's attention_processor module.

        self.attn1 = Attention(
            query_dim=dim,                             # channel dim of Q
            heads=num_attention_heads,                 # number of heads in multi-head attention
            dim_head=attention_head_dim,               # channels per head, i.e. the per-head dim of Q and K
            dropout=dropout,                           # dropout probability, 0 by default
            bias=attention_bias,                       # whether the linear projections use a bias
            cross_attention_dim=cross_attention_dim if only_cross_attention else None, # channel dim of each token produced by the encoder (CLIP)
            upcast_attention=upcast_attention,
        )

At each forward pass, with probability dropout, some elements of the input tensor are set to zero and the remaining elements are scaled up accordingly. Randomly "dropping" some neuron outputs during training reduces overfitting and improves the model's generalization.
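
For example (plain PyTorch, nothing diffusers-specific):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)
print(drop(x))  # training mode: roughly half the entries become 0, the rest are scaled to 2.0
drop.eval()
print(drop(x))  # eval mode: dropout is a no-op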

Then I skimmed through it and saw plenty of things I had definitely seen before (not really). It delegates the forward computation to an object called processor; its forward function just calls it directly…

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

Here, to_q is an nn.Linear(query_dim, dim_head * heads), and to_k, to_v are nn.Linear(cross_attention_dim, dim_head * heads).

The most basic processor is AttnProcessor, which looks like this:

class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,              # the encoded image features, i.e. what becomes Q
        encoder_hidden_states=None, # the encoded prompt, i.e. what becomes K and V
        attention_mask=None,
        temb=None,
    ):
        # standard multi-head attention happens here (omitted)

Now the important function:

    def set_processor(self, processor: "AttnProcessor"):
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

This effectively provides an interface for swapping out the processor. (Neat!)

function: set_attn_processor

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

So this replaces the processor of every attention layer in the UNet with the one you pass in.

function: register_attention_control

Back to the register function from the beginning:

    def register_attention_control(self):
        attn_procs = {}
        cross_att_count = 0
        for name in self.unet.attn_processors.keys(): # pull out everything that has attention
            if name.startswith("mid_block"):
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                place_in_unet = "down"
            else:
                continue

            cross_att_count += 1
            attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet) # swap in the custom processor

        self.unet.set_attn_processor(attn_procs)                    # apply the replacement recursively
        self.attention_store.num_att_layers = cross_att_count

class: AttendExciteAttnProcessor

As you can see, this is just a thin layer of wrapping: create an AttendExciteAttnProcessor for each attention layer and replace them all. The implementation of AttendExciteAttnProcessor itself is simple; it even drops some of the special-case checks from the diffusers source.

It is also worth going over head_to_batch_dim and batch_to_head_dim here.

# These two functions belong to multi-head attention; they are pulled out here to make the rest easier to follow.

    def head_to_batch_dim(self, tensor, out_dim=3): # after to_*, the last dim packs all `heads` heads together
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) # so split the head dimension out
        tensor = tensor.permute(0, 2, 1, 3) # and move it in front of the sequence dim
        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) # fold the heads into the batch dimension
        return tensor

    def batch_to_head_dim(self, tensor): 
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) # the heads were folded into batch_size; pull them back out
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) # and fold them into the last dim
        return tensor

class AttendExciteAttnProcessor:
    def __init__(self, attnstore, place_in_unet): # keep a handle to the attnstore and remember where in the unet this processor sits
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states # if not cross-attention, attend to itself (self-attention)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        if attention_probs.requires_grad: # store the attention map
            self.attnstore(attention_probs, is_cross, self.place_in_unet)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states) # output linear projection
        hidden_states = attn.to_out[1](hidden_states) # dropout
        return hidden_states
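
To make the shapes concrete: attention_probs here is [batch·heads, image positions, prompt tokens]. A toy reproduction of the score shape at the resolution the loss later cares about (illustrative numbers only; 256 = 16·16 latent positions, 77 CLIP text tokens, 8 heads):

import torch

heads, dim_head = 8, 40
query = torch.randn(2 * heads, 256, dim_head)  # hidden_states after to_q and head_to_batch_dim
key = torch.randn(2 * heads, 77, dim_head)     # prompt tokens after to_k and head_to_batch_dim
attn = torch.softmax(query @ key.transpose(1, 2) / dim_head**0.5, dim=-1)
print(attn.shape)  # torch.Size([16, 256, 77]) -- [batch*heads, image positions, prompt tokens]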

The attention map ends up in self.attnstore, i.e. the pipeline's self.attention_store. So let's continue with attention_store. This is a class defined outside the library; we have finally reached the outermost layer!

AttentionStore

class AttentionStore:
    @staticmethod # decorator: the method takes no self argument
    def get_empty_store():
        return {"down": [], "mid": [], "up": []}

    def __call__(self, attn, is_cross: bool, place_in_unet: str): # the processor passes in the attention map, whether it is cross-attention, and where in the unet it came from
        if is_cross:
            if attn.shape[1] in self.attn_res:
                self.step_store[place_in_unet].append(attn) # stash it in the dict

        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers: # once every layer has reported, commit the step and reset (feels more elaborate than necessary)
            self.cur_att_layer = 0
            self.between_steps()

    def between_steps(self):
        self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def maps(self, block_type: str):
        return self.attention_store[block_type] # read out the attention maps

    def reset(self):
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, attn_res=[256, 64]):
        """
        Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
        process
        """
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store() # step_store holds the step in progress; attention_store holds the last completed step
        self.attention_store = {}
        self.curr_step_index = 0
        self.attn_res = attn_res

So this is just a wrapper around a cache. Now let's look at the loss!

Loss

    def _compute_loss(self, token_indices, bboxes, device) -> torch.Tensor:
        loss = 0
        object_number = len(bboxes)
        total_maps = 0
        for location in ["mid", "up"]: # 论文里说的,只用考虑 mid 层和 up 层。这个接口其实是测试用的。
            for attn_map_integrated in self.attention_store.maps(location):
                attn_map = attn_map_integrated.chunk(2)[1]

                b, i, j = attn_map.shape
                H = W = int(np.sqrt(i))

                total_maps += 1
                for obj_idx in range(object_number): # loop over boxes
                    obj_loss = 0
                    obj_box = bboxes[obj_idx]

                    x_min, y_min, x_max, y_max = ( # boxes are given in normalized coordinates; convert them to map pixels
                        obj_box[0] * W,
                        obj_box[1] * H,
                        obj_box[2] * W,
                        obj_box[3] * H,
                    )
                    mask = torch.zeros((H, W), device=device)
                    mask[round(y_min) : round(y_max), round(x_min) : round(x_max)] = 1

                    for obj_position in token_indices[obj_idx]: # loop over the token indices belonging to this box
                        ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W) # attention map dims: [batch_size*heads, image positions (Q), prompt tokens (K/V)]
                        activation_value = (ca_map_obj * mask).reshape(b, -1).sum(
                            dim=-1
                        ) / ca_map_obj.reshape(b, -1).sum(dim=-1) # fraction of the attention mass that falls inside the box

                        obj_loss += torch.mean((1 - activation_value) ** 2)

                    loss += obj_loss / len(token_indices[obj_idx])

        loss /= object_number * total_maps
        return loss
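
In other words (my reading of the code above), for every attention map and every box $b$ with token set $T_b$, the loss pushes each token's cross-attention mass into its box:

$$\mathcal{L} = \frac{1}{N_{\text{maps}}\,N_{\text{boxes}}} \sum_{\text{maps}} \sum_{b} \frac{1}{|T_b|} \sum_{k \in T_b} \left(1 - \frac{\sum_{(i,j) \in b} A^{(k)}_{ij}}{\sum_{i,j} A^{(k)}_{ij}}\right)^{2}$$

where $A^{(k)}$ is the cross-attention map for token $k$ (the code additionally averages the squared term over the batch·heads dimension).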

And with that, the whole logical skeleton is clear.