Stable Diffusion EDM Sampler详细解释

[复制链接]
发表于 2024-9-13 15:12:25 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
EDM (Euler Discretization with Momentum): EDM代表了欧拉离散化并带有动量的方法,它通常是对连续时间扩散过程进行数值积分的一种变体,通过引入动量项来改进收敛性和稳定性。
    
   在 EDMSampler 中,重要原理可以概括为以下几点:
   

  • 扩散过程:

    • 在训练阶段,扩散模子学习怎样将带有不同水平噪声的数据徐徐还原至无噪声的状态。
    • 这个过程可以视为一系列连续的概率转换,每个步调对应一个噪声水平(即 sigma)。

  • Euler方法:

    • 该类中的 sampler_step 方法使用了Euler方法进行离散化处置惩罚,这是一种数值积分技能,用于近似解决微分方程。
    • 对于扩散模子来说,这个离散化的一步就是根据当前噪声水平和下一步的目标噪声水平,盘算并应用相应的更新到采样状态上。

  • 动态调整:

    • EDMSampler 包含了一些自定义参数如 s_churn, s_tmin, s_tmax,它们影响着每一步中加入额外随机性的程度(通过 gamma 参数控制)。
    • 当噪声水平在特定范围内时(由 s_tmin 和 s_tmax 定义),会引入额外的随机扰动,这有助于进步采样的多样性和避免陷入局部最优解。

  • 去噪与步进:

    • 每一步迭代包括:盘算修正后的噪声水平、给当前状态注入噪声、使用预训练的扩散模子进行去噪,并执行欧拉步进更新状态。
    • 可能还包含一个校正步调来进一步优化采样结果的质量。

  • 整体流程:

    • __call__ 方法被计划成可以直接调用启动采样过程的形式,它按照预定的噪声水平序列逐级降低噪声强度,直至最终生成高质量的新样本。

   决定能否看懂代码的重点!!!
   euler_step = self.euler_step(x, d, dt) 这段代码是在求解微分方程
   
   

  • x 是当前状态变量,即带有某个噪声级别的样本。
  • d 通常表示在这个噪声级别下对样本进行去噪操作后的结果与原噪声样本之间的差异
  • dt 是时间步长或者说噪声水平的变革量,即 next_sigma - sigma_hat。
   

   说得简单点就是,花个坐标轴x,y; 现在x轴上某个点的值已知(当前带有噪声的样本),现在x轴方向变革了dt也已知,导致了y轴的变革量d也已知,求斜率(斜率即微分方程的解)
  
  1. class EDMSampler(SingleStepDiffusionSampler):
  2.     def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
  3.         super().__init__(*args, **kwargs)
  4.         self.s_churn = s_churn
  5.         self.s_tmin = s_tmin
  6.         self.s_tmax = s_tmax
  7.         self.s_noise = s_noise
  8.     def sampler_step(self, sigma, next_sigma, model, x, cond, uc=None, gamma=0.0, **kwargs):
  9.         sigma_hat = sigma * (gamma + 1.0)
  10.         if gamma > 0:
  11.             eps = Tensor(np.random.randn(*x.shape), x.dtype) * self.s_noise
  12.             x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
  13.         denoised = self.denoise(x, model, sigma_hat, cond, uc, **kwargs)
  14.         d = to_d(x, sigma_hat, denoised)
  15.         dt = append_dims(next_sigma - sigma_hat, x.ndim)
  16.         euler_step = self.euler_step(x, d, dt) #核心,在解微分方程
  17.         x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, model, cond, uc)
  18.         return x
  19.     def __call__(self, model, x, cond, uc=None, num_steps=None, **kwargs):
  20.         x = ops.cast(x, ms.float32)
  21.         x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
  22.         for i in self.get_sigma_gen(num_sigmas):
  23.             gamma = (
  24.                 min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
  25.             )
  26.             x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], model, x, cond, uc, gamma, **kwargs)
  27.         return x
  28. #只想搞懂原理的话,下面的依赖可以不看
  29. class SingleStepDiffusionSampler(BaseDiffusionSampler):
  30.     def sampler_step(self, sigma, next_sigma, model, x, cond, uc=None, gamma=0.0, **kwargs):
  31.         sigma_hat = sigma * (gamma + 1.0)
  32.         if gamma > 0:
  33.             eps = Tensor(np.random.randn(*x.shape), x.dtype) * self.s_noise
  34.             x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
  35.         denoised = self.denoise(x, model, sigma_hat, cond, uc, **kwargs)
  36.         d = to_d(x, sigma_hat, denoised)
  37.         dt = append_dims(next_sigma - sigma_hat, x.ndim)
  38.         euler_step = self.euler_step(x, d, dt)
  39.         x = euler_step
  40.         return x
  41.     def euler_step(self, x, d, dt):
  42.         return x + dt * d
  43. class BaseDiffusionSampler:
  44.     def __init__(
  45.         self,
  46.         discretization_config: Union[Dict, ListConfig, OmegaConf],
  47.         num_steps: Union[int, None] = None,
  48.         guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
  49.         verbose: bool = False,
  50.     ):
  51.         self.num_steps = num_steps
  52.         self.discretization = instantiate_from_config(discretization_config)
  53.         self.guider = instantiate_from_config(
  54.             default(
  55.                 guider_config,
  56.                 DEFAULT_GUIDER,
  57.             )
  58.         )
  59.         self.verbose = verbose
  60.     def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
  61.         sigmas = self.discretization(self.num_steps if num_steps is None else num_steps)
  62.         uc = default(uc, cond)
  63.         x *= Tensor(np.sqrt(1.0 + sigmas[0] ** 2.0), x.dtype)
  64.         num_sigmas = len(sigmas)
  65.         s_in = ops.ones((x.shape[0],), x.dtype)
  66.         return x, s_in, sigmas, num_sigmas, cond, uc
  67.     def denoise(self, x, model, sigma, cond, uc, **kwargs):
  68.         noised_input, sigmas, cond = self.guider.prepare_inputs(x, sigma, cond, uc)
  69.         cond = model.openai_input_warpper(cond)
  70.         c_skip, c_out, c_in, c_noise = model.denoiser(sigmas, noised_input.ndim)
  71.         model_output = model.model(noised_input * c_in, c_noise, **cond, **kwargs)
  72.         model_output = model_output.astype(ms.float32)
  73.         denoised = model_output * c_out + noised_input * c_skip
  74.         denoised = self.guider(denoised, sigma)
  75.         return denoised
  76.     def get_sigma_gen(self, num_sigmas):
  77.         sigma_generator = range(num_sigmas - 1)
  78.         if self.verbose:
  79.             print("#" * 30, " Sampling setting ", "#" * 30)
  80.             print(f"Sampler: {self.__class__.__name__}")
  81.             print(f"Discretization: {self.discretization.__class__.__name__}")
  82.             print(f"Guider: {self.guider.__class__.__name__}")
  83.             sigma_generator = tqdm(
  84.                 sigma_generator,
  85.                 total=(num_sigmas - 1),
  86.                 desc=f"Sampling with {self.__class__.__name__} for {(num_sigmas - 1)} steps",
  87.             )
  88.         return sigma_generator
复制代码
   
   
    

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表