kernel

class ksd_metric.kernel.KernelInterface(base_kernel_function)[source]

Bases: ABC

__init__(base_kernel_function)[source]

Initializes the KernelInterface with a base kernel function.

Parameters:

base_kernel_function (Callable[[ArrayLike, ArrayLike], ArrayLike]) – A function that computes the kernel between two points.

abstract base_kernel_function(x, y)[source]

Evaluates the base kernel function used in KSD at (x, y). Subclasses must implement this method.

Parameters:
  • x (ArrayLike) – First input data point.

  • y (ArrayLike) – Second input data point.

Returns:

The value of the base kernel function at (x, y).

Return type:

ArrayLike

abstract partial_derivative_x_kernel_function(x, y)[source]

Computes the partial derivative of the kernel function with respect to x.

Parameters:
  • x (ArrayLike) – First input data point.

  • y (ArrayLike) – Second input data point.

Returns:

The partial derivative of the kernel function with respect to x.

Return type:

ArrayLike

abstract partial_derivative_y_kernel_function(x, y)[source]

Computes the partial derivative of the kernel function with respect to y.

Parameters:
  • x (ArrayLike) – First input data point.

  • y (ArrayLike) – Second input data point.

Returns:

The partial derivative of the kernel function with respect to y.

Return type:

ArrayLike

abstract cross_partial_derivative_kernel_function(x, y)[source]

Computes the cross partial derivative of the kernel function with respect to x and y.

Parameters:
  • x (ArrayLike) – First input data point.

  • y (ArrayLike) – Second input data point.

Returns:

The cross partial derivative of the kernel function with respect to x and y.

Return type:

ArrayLike
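
As an illustration of the four abstract methods above, the following is a minimal sketch of a concrete kernel with closed-form derivatives. It makes several assumptions: ScalarKernelSketch is a hypothetical stand-in that only mirrors the method names of KernelInterface (it does not import ksd_metric and does not reproduce the constructor argument), the kernel is a Gaussian kernel with a hypothetical bandwidth parameter, and the inputs are plain scalars so that the cross partial derivative is a single number.

```python
import math
from abc import ABC, abstractmethod


class ScalarKernelSketch(ABC):
    """Hypothetical stand-in mirroring the four abstract methods above."""

    @abstractmethod
    def base_kernel_function(self, x, y):
        """Value of the kernel at (x, y)."""

    @abstractmethod
    def partial_derivative_x_kernel_function(self, x, y):
        """Derivative of the kernel with respect to x."""

    @abstractmethod
    def partial_derivative_y_kernel_function(self, x, y):
        """Derivative of the kernel with respect to y."""

    @abstractmethod
    def cross_partial_derivative_kernel_function(self, x, y):
        """Mixed second derivative of the kernel with respect to x and y."""


class GaussianKernelSketch(ScalarKernelSketch):
    """Gaussian kernel k(x, y) = exp(-(x - y)^2 / (2 h^2)) for scalar inputs."""

    def __init__(self, bandwidth=1.0):
        # `bandwidth` (h) is an assumed parameter, not part of the documented API.
        self.bandwidth = bandwidth

    def base_kernel_function(self, x, y):
        return math.exp(-((x - y) ** 2) / (2.0 * self.bandwidth**2))

    def partial_derivative_x_kernel_function(self, x, y):
        # dk/dx = -(x - y) / h^2 * k(x, y)
        return -(x - y) / self.bandwidth**2 * self.base_kernel_function(x, y)

    def partial_derivative_y_kernel_function(self, x, y):
        # dk/dy = (x - y) / h^2 * k(x, y)
        return (x - y) / self.bandwidth**2 * self.base_kernel_function(x, y)

    def cross_partial_derivative_kernel_function(self, x, y):
        # d^2 k / (dx dy) = (1 / h^2 - (x - y)^2 / h^4) * k(x, y)
        h2 = self.bandwidth**2
        return (1.0 / h2 - (x - y) ** 2 / h2**2) * self.base_kernel_function(x, y)


kernel = GaussianKernelSketch(bandwidth=1.5)
value = kernel.base_kernel_function(0.3, -0.7)
```

Because every method is marked abstract, instantiating the stand-in base class directly raises TypeError; only a subclass that supplies all four methods can be constructed.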

class ksd_metric.kernel.KernelJax(base_kernel_function)[source]

Bases: KernelInterface

Represents a kernel function for the KSD.

__init__(base_kernel_function)[source]

Initializes the kernel with a base kernel function.

Parameters:

base_kernel_function (DifferentiableFunction) – A function that computes the kernel between two points.

base_kernel_function(x, y)[source]

Evaluates the base kernel function used in KSD at (x, y).

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The value of the base kernel function at (x, y).

Return type:

Float[Array, "num"]

partial_derivative_x_kernel_function(x, y)[source]

Computes the partial derivative of the kernel function with respect to x.

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The partial derivative of the kernel function with respect to x.

Return type:

Float[Array, "num"]

partial_derivative_y_kernel_function(x, y)[source]

Computes the partial derivative of the kernel function with respect to y.

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The partial derivative of the kernel function with respect to y.

Return type:

Float[Array, "num"]

cross_partial_derivative_kernel_function(x, y)[source]

Computes the cross partial derivative of the kernel function with respect to x and y.

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The cross partial derivative of the kernel function with respect to x and y.

Return type:

Float[Array, "num"]
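
The derivatives listed for this class can, in general, be obtained from a differentiable base kernel by automatic differentiation. The sketch below shows one way to do that with jax.grad and jax.jacfwd. It does not import ksd_metric, and this page does not confirm that KernelJax is implemented this way. The rbf function, its bandwidth parameter, and the names dk_dx, dk_dy and d2k_dxdy are hypothetical. The sketch returns the full matrix of mixed second derivatives; whether the library returns that matrix, its trace, or something else is not stated here.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, bandwidth=1.0):
    """Hypothetical Gaussian base kernel; returns a scalar for 1-D inputs."""
    diff = x - y
    return jnp.exp(-jnp.dot(diff, diff) / (2.0 * bandwidth**2))


# Gradient of the kernel with respect to its first and second argument.
dk_dx = jax.grad(rbf, argnums=0)
dk_dy = jax.grad(rbf, argnums=1)

# Matrix of mixed second derivatives: entry [j, i] is d^2 k / (dy_j dx_i).
d2k_dxdy = jax.jacfwd(jax.grad(rbf, argnums=1), argnums=0)

x = jnp.array([0.5, -1.0])
y = jnp.array([1.5, 0.25])

grad_x = dk_dx(x, y)           # shape (2,)
grad_y = dk_dy(x, y)           # shape (2,)
mixed = d2k_dxdy(x, y)         # shape (2, 2)
trace_term = jnp.trace(mixed)  # sum over i of d^2 k / (dx_i dy_i)
```

For a kernel that depends only on x - y, such as this one, the two gradients differ only in sign, which gives a quick sanity check on any implementation.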