Source code for ksd_metric.target.target

from abc import ABC, abstractmethod

from jax import jacfwd, jit
from jaxtyping import Array, Float

from ..typing import ArrayLike, DifferentiableFunction


[docs] class TargetDistributionInterface(ABC): """ Interface for target distributions used in Kernel Stein Discrepancy (KSD). """
[docs] @abstractmethod def log_target_pdf( self, x: ArrayLike, ) -> ArrayLike: """ Logarithmic Probability Density Function. Args: x (ArrayLike): Input data point. Returns: (ArrayLike): Logarithmic probability density at x. """ raise NotImplementedError("log_target_pdf method must be implemented.")
[docs] @abstractmethod def grad_log_target_pdf( self, x: ArrayLike, ) -> ArrayLike: """ Gradient of Logarithmic Probability Density Function. Args: x (ArrayLike): Input data point. Returns: (ArrayLike): Gradient of logarithmic probability density at x. """ raise NotImplementedError("grad_log_target_pdf method must be implemented.")
[docs] class TargetDistributionJax(TargetDistributionInterface): """ Represents a target distribution for the KSD. """
[docs] def __init__( self, log_target_pdf: DifferentiableFunction, ) -> None: """ Initializes the target distribution with a logarithmic probability density function. Args: log_target_pdf (DifferentiableFunction): A function that computes the log of the target PDF. """ self._log_target_pdf_func = log_target_pdf
[docs] def log_target_pdf( self, x: Float[Array, "num dim"], ) -> Float[Array, "num dim"]: """ Computes the log probability density at the given point. Args: x (Float[Array, "num dim"]): Input data point. Returns: Float[Array, "num dim"]: Logarithmic probability density at x. """ return self._log_target_pdf_func(x)
[docs] def grad_log_target_pdf( self, x: Float[Array, "num dim"], ) -> Float[Array, "num dim"]: """ Gradient of Logarithmic Probability Density Function. Args: x (Float[Array, "num dim"]): Input data point. Returns: Float[Array, "num dim"]: Gradient of logarithmic probability density at x. """ return jit(jacfwd(self._log_target_pdf_func))(x)