utils
- class ksd_metric.utils.JaxKernelFunction
Bases: object
A collection of JAX kernel functions. Each kernel is exposed as a static method that takes two input points plus any kernel-specific parameters.
- static rbf(x, y, sigma=1.0)
Radial basis function (RBF) kernel, also known as the Gaussian kernel.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
sigma (float) – The bandwidth parameter for the RBF kernel.
- Returns:
The value of the RBF kernel at (x, y).
- Return type:
Float[Array, "num"]
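A minimal standalone sketch of an RBF kernel in JAX, for orientation only. It assumes the common convention k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); the library may scale the bandwidth differently. The name `rbf_kernel` is hypothetical and is not the `JaxKernelFunction.rbf` implementation. For a single pair of points this sketch returns a scalar.

```python
import jax.numpy as jnp


def rbf_kernel(x, y, sigma=1.0):
    """Hypothetical RBF sketch: exp(-||x - y||^2 / (2 * sigma^2)) (assumed convention)."""
    # Squared Euclidean distance between the two input points.
    sq_dist = jnp.sum((x - y) ** 2)
    # Larger sigma makes the kernel decay more slowly with distance.
    return jnp.exp(-sq_dist / (2.0 * sigma ** 2))


x = jnp.array([0.0, 1.0])
y = jnp.array([1.0, 1.0])
print(rbf_kernel(x, x))             # identical points give 1.0
print(rbf_kernel(x, y, sigma=2.0))  # exp(-1 / 8)
```

The kernel equals 1 when x == y and decays toward 0 as the points move apart, with sigma controlling the decay rate.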
- static linear(x, y)
Linear kernel function.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
- Returns:
The value of the linear kernel at (x, y).
- Return type:
Float[Array, "num"]
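A minimal sketch, assuming the linear kernel is the plain inner product k(x, y) = x^T y. The helper names `linear_kernel` and `gram_matrix` are hypothetical. The sketch also shows one way to turn a pairwise kernel into a Gram matrix with nested `jax.vmap`; for the linear kernel the result should match `X @ Y.T`.

```python
import jax
import jax.numpy as jnp


def linear_kernel(x, y):
    """Hypothetical linear kernel sketch: the inner product <x, y>."""
    return jnp.dot(x, y)


def gram_matrix(kernel, X, Y):
    """Evaluate kernel(X[i], Y[j]) for all pairs via nested vmap (hypothetical helper)."""
    # Inner vmap maps over rows of Y; outer vmap maps over rows of X.
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(Y))(X)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
print(linear_kernel(x, y))  # 1*4 + 2*5 + 3*6 = 32.0

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Y = jnp.array([[2.0, 3.0], [4.0, 5.0]])
print(gram_matrix(linear_kernel, X, Y))  # shape (3, 2), equal to X @ Y.T
```

The same `gram_matrix` helper works with any pairwise kernel that maps two points to a scalar.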
- static imq(x, y, linv, beta=0.5)
Inverse multiquadric (IMQ) kernel function.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
linv (Float[Array, "num num"]) – Inverse of the length scale matrix.
beta (float) – The shape parameter for the inverse multiquadric kernel.
- Returns:
The value of the inverse multiquadric kernel at (x, y).
- Return type:
Float[Array, "num"]
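A minimal sketch of an IMQ kernel with a matrix length scale, assuming the form k(x, y) = (1 + (x - y)^T linv (x - y))^(-beta), i.e. `linv` enters as a quadratic form and a positive `beta` is used as a negative exponent. Both are assumptions about the library's convention (another common choice applies `linv` to the difference and takes the squared norm). The name `imq_kernel` is hypothetical.

```python
import jax.numpy as jnp


def imq_kernel(x, y, linv, beta=0.5):
    """Hypothetical IMQ sketch: (1 + (x - y)^T linv (x - y)) ** (-beta) (assumed form)."""
    diff = x - y
    # Quadratic form of the difference under the inverse length-scale matrix.
    quad = diff @ linv @ diff
    # beta > 0 gives a kernel that decays polynomially (heavier tails than RBF).
    return (1.0 + quad) ** (-beta)


x = jnp.array([0.0, 0.0])
y = jnp.array([1.0, 1.0])
linv = jnp.eye(2)
print(imq_kernel(x, x, linv))  # identical points give 1.0
print(imq_kernel(x, y, linv))  # (1 + 2) ** -0.5
```

With `linv` set to the identity the quadratic form reduces to the squared Euclidean distance; a diagonal `linv` would weight each dimension separately.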