Source code for ksd_metric.kernel.kernel

from abc import ABC, abstractmethod
from typing import Callable

from jax import jacfwd, jit
from jaxtyping import Array, Float

from ..typing import ArrayLike, DifferentiableFunction


[docs] class KernelInterface(ABC):
[docs] def __init__( self, base_kernel_function: Callable[[ArrayLike, ArrayLike], ArrayLike], ) -> None: """ Initializes the KernelInterface with a base kernel function. Args: base_kernel_function (Callable[[ArrayLike, ArrayLike], ArrayLike]): A function that computes the kernel between two points. """ self._base_kernel_function = base_kernel_function
[docs] @abstractmethod def base_kernel_function( self, x: ArrayLike, y: ArrayLike, ) -> ArrayLike: """ Base kernel function to be implemented by subclasses. This function should return the base kernel function used in KSD. Args: x (ArrayLike): Input data point. y (ArrayLike): Input data point. Returns: ArrayLike: The value of the base kernel function at (x, y). """ raise NotImplementedError("base_kernel_function must be implemented.")
[docs] @abstractmethod def partial_derivative_x_kernel_function( self, x: ArrayLike, y: ArrayLike, ) -> ArrayLike: """ Computes the partial derivative of the kernel function with respect to x. Args: x (ArrayLike): Input data point. y (ArrayLike): Input data point. Returns: ArrayLike: The partial derivative of the kernel function with respect to x. """ raise NotImplementedError( "partial_derivative_x_kernel_function must be implemented." )
[docs] @abstractmethod def partial_derivative_y_kernel_function( self, x: ArrayLike, y: ArrayLike, ) -> ArrayLike: """ Computes the partial derivative of the kernel function with respect to y. Args: x (ArrayLike): Input data point. y (ArrayLike): Input data point. Returns: ArrayLike: The partial derivative of the kernel function with respect to y. """ raise NotImplementedError( "partial_derivative_y_kernel_function must be implemented." )
[docs] @abstractmethod def cross_partial_derivative_kernel_function( self, x: ArrayLike, y: ArrayLike, ) -> ArrayLike: """ Computes the cross partial derivative of the kernel function with respect to x and y. Args: x (ArrayLike): Input data point. y (ArrayLike): Input data point. Returns: ArrayLike: The cross partial derivative of the kernel function with respect to x and y. """ raise NotImplementedError( "cross_partial_derivative_kernel_function must be implemented." )
[docs] class KernelJax(KernelInterface): """ Represents a kernel function for the KSD. """
[docs] def __init__( self, base_kernel_function: DifferentiableFunction, ) -> None: """ Initializes the kernel with a base kernel function. Args: base_kernel_function (DifferentiableFunction): A function that computes the kernel between two points. """ super().__init__(base_kernel_function=base_kernel_function)
[docs] def base_kernel_function( self, x: Float[Array, "num"], y: Float[Array, "num"], ) -> Float[Array, "num"]: """ Base kernel function to be implemented by subclasses. This function should return the base kernel function used in KSD. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The value of the base kernel function at (x, y). """ return self._base_kernel_function(x, y)
[docs] def partial_derivative_x_kernel_function( self, x: Float[Array, "num"], y: Float[Array, "num"], ) -> Float[Array, "num"]: """ Computes the partial derivative of the kernel function with respect to x. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The partial derivative of the kernel function with respect to x. """ return jit(jacfwd(self.base_kernel_function, argnums=0))(x, y)
[docs] def partial_derivative_y_kernel_function( self, x: Float[Array, "num"], y: Float[Array, "num"], ) -> Float[Array, "num"]: """ Computes the partial derivative of the kernel function with respect to y. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The partial derivative of the kernel function with respect to y. """ return jit(jacfwd(self.base_kernel_function, argnums=1))(x, y)
[docs] def cross_partial_derivative_kernel_function( self, x: Float[Array, "num"], y: Float[Array, "num"], ) -> Float[Array, "num"]: """ Computes the cross partial derivative of the kernel function with respect to x and y. Args: x (Float[Array, "num"]): Input data point. y (Float[Array, "num"]): Input data point. Returns: Float[Array, "num"]: The cross partial derivative of the kernel function with respect to x and y. """ return jit(jacfwd(self.partial_derivative_y_kernel_function, argnums=0))(x, y)