Source code for ksd_metric.stein.stein

from abc import ABC, abstractmethod
from functools import partial

import jax
from jax import numpy as jnp
from jaxtyping import Array, Float

from ..kernel import KernelInterface, KernelJax
from ..target import TargetDistributionInterface, TargetDistributionJax
from ..typing import ArrayLike


[docs] class KernelSteinDiscrepancyInterface(ABC):
[docs] def __init__( self, target: TargetDistributionInterface, kernel: KernelInterface, ) -> None: """ Initializes the KernelSteinDiscrepancyInterface with a target distribution. Args: target (TargetDistributionInterface): The target distribution interface. kernel (KernelInterface): The kernel interface used for computing the Stein kernel. """ self.target = target self.kernel = kernel
[docs] @abstractmethod def stein_kernel( self, x: ArrayLike, y: ArrayLike, ) -> ArrayLike: """ Computes the Stein kernel using the base kernel function and the gradient of the log target PDF. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The value of the Stein kernel at (x, y). """ raise NotImplementedError("stein_kernel must be implemented.")
[docs] @abstractmethod def kernel_stein_discrepancy( self, samples: ArrayLike, ) -> float: """ Computes the kernel Stein discrepancy for the given samples. Args: samples (ArrayLike): A collection of samples from the target distribution. Returns: float: The value of the kernel Stein discrepancy for the given samples. """ raise NotImplementedError("kernel_stein_discrepancy must be implemented.")
[docs] class KernelSteinDiscrepancyJax(KernelSteinDiscrepancyInterface):
[docs] def __init__( self, target: TargetDistributionJax, kernel: KernelJax, ) -> None: """ Initializes the KernelSteinDiscrepancyJax with a target distribution and a kernel. Args: target (TargetDistributionJax): The target distribution. kernel (KernelJax): The kernel function. """ super().__init__(target=target, kernel=kernel)
[docs] @partial(jax.jit, static_argnums=(0,)) def stein_kernel( self, x: Float[Array, "num"], y: Float[Array, "num"], ) -> Float[Array, "num"]: """ Computes the Stein kernel using the base kernel function and the gradient of the log target PDF. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The value of the Stein kernel at (x, y). """ dx_k = self.kernel.partial_derivative_x_kernel_function(x, y) dy_k = self.kernel.partial_derivative_y_kernel_function(x, y) dxdy_k = self.kernel.cross_partial_derivative_kernel_function(x, y) grad_log_p_x = self.target.grad_log_target_pdf(x) grad_log_p_y = self.target.grad_log_target_pdf(y) return ( jnp.trace(dxdy_k) + dx_k @ grad_log_p_y + dy_k @ grad_log_p_x + self.kernel.base_kernel_function(x, y) * grad_log_p_x @ grad_log_p_y )
@partial(jax.jit, static_argnums=(0,)) def _vectorised_kernel_stein_discrepancy( self, samples: Float[Array, "num dim"], ) -> Float[Array, "float"]: """ Computes the kernel Stein discrepancy for a collection of samples. Args: samples (Float[Array, "num dim"]): A collection of samples from the target distribution. Returns: Float[Array, "float"]: The value of the kernel Stein discrepancy for the given samples """ num = samples.shape[0] k = jax.vmap(jax.vmap(self.stein_kernel, in_axes=(None, 0)), in_axes=(0, None))( samples, samples ) ksd = (1.0 / num) * jnp.sqrt(jnp.sum(k)) return ksd
[docs] def kernel_stein_discrepancy(self, samples: Float[Array, "num dim"]) -> float: """ Computes the kernel Stein discrepancy for the given samples. Args: samples (Float[Array, "num dim"]): A collection of samples from the target distribution. Returns: float: The value of the kernel Stein discrepancy for the given samples. """ return self._vectorised_kernel_stein_discrepancy(samples).item()