Source code for ksd_metric.utils.utils

from jax import numpy as jnp
from jaxtyping import Array, Float


class JaxKernelFunction:
    """
    Collection of kernel functions implemented with ``jax.numpy``.

    The class holds no state: every kernel is a ``@staticmethod`` that takes
    two 1-D input points (plus kernel-specific hyper-parameters) and returns
    the kernel value as a scalar (0-d) JAX array.
    """

    @staticmethod
    def rbf(
        x: Float[Array, "num"],
        y: Float[Array, "num"],
        sigma: float = 1.0,
    ) -> Float[Array, ""]:
        """
        Radial Basis Function (RBF) kernel, also known as Gaussian kernel.

        Computes ``exp(-||x - y||^2 / (2 * sigma^2))``.

        Args:
            x (Float[Array, "num"]): Input data point.
            y (Float[Array, "num"]): Input data point.
            sigma (float): The bandwidth parameter for the RBF kernel.

        Returns:
            Float[Array, ""]: Scalar value of the RBF kernel at (x, y).
        """
        diff = x - y
        # dot(diff, diff) is the squared Euclidean distance between x and y.
        return jnp.exp(-jnp.dot(diff, diff) / (2 * sigma**2))

    @staticmethod
    def linear(
        x: Float[Array, "num"],
        y: Float[Array, "num"],
    ) -> Float[Array, ""]:
        """
        Linear kernel function.

        Computes the inner product ``x . y``.

        Args:
            x (Float[Array, "num"]): Input data point.
            y (Float[Array, "num"]): Input data point.

        Returns:
            Float[Array, ""]: Scalar value of the linear kernel at (x, y).
        """
        return jnp.dot(x, y)

    @staticmethod
    def imq(
        x: Float[Array, "num"],
        y: Float[Array, "num"],
        linv: Float[Array, "num num"],
        beta: float = 0.5,
    ) -> Float[Array, ""]:
        """
        Inverse Multiquadric kernel function.

        Computes ``(1 + (x - y)^T @ linv @ (x - y)) ** (-beta)``.

        Args:
            x (Float[Array, "num"]): Input data point.
            y (Float[Array, "num"]): Input data point.
            linv (Float[Array, "num num"]): Inverse of the length scale
                matrix.
            beta (float): The shape parameter for the inverse multiquadric
                kernel.

        Returns:
            Float[Array, ""]: Scalar value of the inverse multiquadric
            kernel at (x, y).
        """
        diff = x - y
        # Quadratic form of the difference vector under linv.
        return (1.0 + (diff @ linv @ diff)) ** (-beta)