utils

class ksd_metric.utils.JaxKernelFunction

Bases: object

A container for kernel functions written in JAX. Each kernel is exposed as a static method that takes two input points, plus any kernel-specific parameters, and returns the kernel value.
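Each kernel method takes a pair of single points. A common way to use a pairwise kernel like this is to evaluate it on every pair drawn from two point sets, which produces a kernel (Gram) matrix. The sketch below does this with jax.vmap. The helper gram_matrix and the stand-in kernel linear are hypothetical illustrations, not part of ksd_metric, and the sketch assumes a kernel returns a scalar for two points of shape (d,).

```python
import jax
import jax.numpy as jnp


def linear(x, y):
    # Hypothetical stand-in pairwise kernel: inner product of two points.
    return jnp.dot(x, y)


def gram_matrix(kernel, xs, ys):
    """Evaluate a pairwise kernel on every pair (xs[i], ys[j]).

    Assumes kernel(x, y) takes two points of shape (d,) and returns a scalar.
    xs has shape (n, d), ys has shape (m, d); the result has shape (n, m).
    """
    # The inner vmap maps over ys with x held fixed; the outer vmap maps over xs.
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(ys))(xs)


xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([[2.0, 3.0], [1.0, -1.0]])
K = gram_matrix(linear, xs, ys)
print(K.shape)  # (3, 2)
```

Kernels with extra parameters can have them bound first with functools.partial, for example partial(JaxKernelFunction.rbf, sigma=0.5), before being passed to a helper like this.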

static rbf(x, y, sigma=1.0)

Radial basis function (RBF) kernel, also known as the Gaussian kernel.

Parameters:
  • x (Float[Array, "num"]) – Input data point.

  • y (Float[Array, "num"]) – Input data point.

  • sigma (float) – The bandwidth parameter for the RBF kernel.

Returns:

The value of the RBF kernel at (x, y).

Return type:

Float[Array, "num"]
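A minimal sketch of an RBF kernel for two points, assuming the common convention k(x, y) = exp(-||x - y||^2 / (2 sigma^2)). The exact scaling used by ksd_metric is not stated here (some definitions omit the factor of 2), so treat this as an illustration of the technique rather than the library's implementation.

```python
import jax.numpy as jnp


def rbf(x, y, sigma=1.0):
    # ASSUMED convention: k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)).
    # The library's exact scaling may differ.
    sq_dist = jnp.sum((x - y) ** 2)  # squared Euclidean distance
    return jnp.exp(-sq_dist / (2.0 * sigma**2))


x = jnp.array([0.0, 0.0])
y = jnp.array([1.0, 1.0])
print(rbf(x, x))             # identical points: exp(0) = 1
print(rbf(x, y, sigma=1.0))  # squared distance 2, so exp(-2 / 2) = exp(-1)
```

A larger sigma widens the kernel, so distant points keep a higher similarity value.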

static linear(x, y)

Linear kernel function.

Parameters:
  • x (Float[Array, "num"]) – Input data point.

  • y (Float[Array, "num"]) – Input data point.

Returns:

The value of the linear kernel at (x, y).

Return type:

Float[Array, "num"]
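A minimal sketch of a linear kernel, assuming it is the plain inner product k(x, y) = <x, y> with no constant offset. The source does not give the formula, so the library's definition may differ (for example by adding a constant term).

```python
import jax.numpy as jnp


def linear(x, y):
    # ASSUMED definition: k(x, y) = <x, y>, the plain inner product
    # with no offset term.
    return jnp.dot(x, y)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
print(linear(x, y))  # 1*4 + 2*5 + 3*6 = 32
```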

static imq(x, y, linv, beta=0.5)

Inverse multiquadric (IMQ) kernel function.

Parameters:
  • x (Float[Array, "num"]) – Input data point.

  • y (Float[Array, "num"]) – Input data point.

  • linv (Float[Array, "num num"]) – Inverse of the length scale matrix.

  • beta (float) – The shape parameter for the inverse multiquadric kernel.

Returns:

The value of the inverse multiquadric kernel at (x, y).

Return type:

Float[Array, "num"]
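A minimal sketch of an IMQ kernel, assuming the form k(x, y) = (1 + (x - y)^T linv (x - y))^(-beta). Here the squared distance is measured through linv and a constant of 1 is used. The source does not specify the constant, the sign convention for the exponent, or exactly how linv enters, so all three are assumptions.

```python
import jax.numpy as jnp


def imq(x, y, linv, beta=0.5):
    # ASSUMED form: k(x, y) = (1 + (x - y)^T linv (x - y))^(-beta).
    # The constant 1, the negative exponent, and the quadratic-form use of
    # linv are illustrative choices; the library may define them differently.
    diff = x - y
    quad = diff @ linv @ diff  # scalar quadratic form
    return (1.0 + quad) ** (-beta)


x = jnp.array([0.0, 0.0])
y = jnp.array([1.0, 0.0])
linv = jnp.eye(2)  # identity: reduces to the plain squared Euclidean distance
print(imq(x, y, linv))  # (1 + 1)^(-0.5) = 1 / sqrt(2)
```

The IMQ kernel decays polynomially with distance rather than exponentially like the RBF kernel. That slower decay is one reason IMQ kernels are often used for kernel Stein discrepancy.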