kernel
- class ksd_metric.kernel.KernelInterface(base_kernel_function)
Bases: ABC
- __init__(base_kernel_function)
Initializes the KernelInterface with a base kernel function.
- Parameters:
base_kernel_function (Callable[[ArrayLike, ArrayLike], ArrayLike]) – A function that computes the kernel between two points.
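Why four methods: in the standard (Langevin) kernel Stein discrepancy, the base kernel k and its derivatives are combined with the score s_p(x) = ∇_x log p(x) of the target density p into a Stein kernel u_p, and the squared KSD between a sample distribution q and p is its expectation over independent x, y drawn from q. The textbook form is given here only as a reference point; this page does not show how ksd_metric assembles these terms. Note that the cross-derivative term enters as a trace.

$$
u_p(x, y) = s_p(x)^\top s_p(y)\, k(x, y) + s_p(x)^\top \nabla_y k(x, y) + s_p(y)^\top \nabla_x k(x, y) + \operatorname{tr}\big(\nabla_x \nabla_y^\top k(x, y)\big),
\qquad
\mathrm{KSD}^2(q, p) = \mathbb{E}_{x, y \sim q}\big[u_p(x, y)\big]
$$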
- abstract base_kernel_function(x, y)
Base kernel function k(x, y), to be implemented by subclasses. It should return the value of the base kernel used in the kernel Stein discrepancy (KSD).
- Parameters:
x (ArrayLike) – Input data point.
y (ArrayLike) – Input data point.
- Returns:
The value of the base kernel function at (x, y).
- Return type:
ArrayLike
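A minimal sketch of a base kernel that could serve as base_kernel_function, assuming the inputs are 1-D vectors and the output is a scalar. The Gaussian (RBF) kernel, the function name rbf_kernel, and the lengthscale parameter are illustrative choices, not something this page prescribes.

```python
import numpy as np


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian (RBF) base kernel k(x, y) = exp(-||x - y||^2 / (2 * lengthscale^2))."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sq_dist = np.sum((x - y) ** 2)
    return np.exp(-sq_dist / (2.0 * lengthscale**2))


print(rbf_kernel([0.0, 0.0], [0.0, 0.0]))  # 1.0
print(rbf_kernel([1.0, 0.0], [0.0, 0.0]))  # approximately 0.6065 (= exp(-0.5))
```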
- abstract partial_derivative_x_kernel_function(x, y)
Computes the partial derivative of the kernel function with respect to x.
- Parameters:
x (ArrayLike) – Input data point.
y (ArrayLike) – Input data point.
- Returns:
The partial derivative of the kernel function with respect to x.
- Return type:
ArrayLike
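For a Gaussian kernel with lengthscale ℓ, the gradient with respect to x is ∇_x k(x, y) = -(x - y) k(x, y) / ℓ². The sketch below (NumPy; rbf_kernel, rbf_grad_x and finite_diff_grad_x are hypothetical names) implements this closed form and checks it against a central finite difference, which is a cheap way to validate a hand-written partial_derivative_x_kernel_function.

```python
import numpy as np


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return np.exp(-np.dot(diff, diff) / (2.0 * lengthscale**2))


def rbf_grad_x(x, y, lengthscale=1.0):
    """Closed-form gradient w.r.t. x: -(x - y) / l^2 * k(x, y)."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return -diff / lengthscale**2 * rbf_kernel(x, y, lengthscale)


def finite_diff_grad_x(k, x, y, eps=1e-6):
    """Central-difference approximation of the gradient of k w.r.t. x."""
    x = np.asarray(x, float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        grad[i] = (k(x + e, y) - k(x - e, y)) / (2 * eps)
    return grad


x = np.array([0.3, -1.2])
y = np.array([1.0, 0.5])
print(np.allclose(rbf_grad_x(x, y), finite_diff_grad_x(rbf_kernel, x, y), atol=1e-8))  # True
```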
- abstract partial_derivative_y_kernel_function(x, y)
Computes the partial derivative of the kernel function with respect to y.
- Parameters:
x (ArrayLike) – Input data point.
y (ArrayLike) – Input data point.
- Returns:
The partial derivative of the kernel function with respect to y.
- Return type:
ArrayLike
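For the same Gaussian kernel, ∇_y k(x, y) = (x - y) k(x, y) / ℓ². For any translation-invariant kernel k(x, y) = φ(x - y) this equals -∇_x k(x, y), which the sketch checks; the identity does not hold for kernels that are not translation invariant, so it should not be assumed in general. All names are hypothetical.

```python
import numpy as np


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return np.exp(-np.dot(diff, diff) / (2.0 * lengthscale**2))


def rbf_grad_x(x, y, lengthscale=1.0):
    """Closed-form gradient w.r.t. x: -(x - y) / l^2 * k(x, y)."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return -diff / lengthscale**2 * rbf_kernel(x, y, lengthscale)


def rbf_grad_y(x, y, lengthscale=1.0):
    """Closed-form gradient w.r.t. y: (x - y) / l^2 * k(x, y)."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return diff / lengthscale**2 * rbf_kernel(x, y, lengthscale)


x = np.array([0.3, -1.2])
y = np.array([1.0, 0.5])
# Translation invariance: grad_y k = -grad_x k.
print(np.allclose(rbf_grad_y(x, y), -rbf_grad_x(x, y)))  # True
```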
- abstract cross_partial_derivative_kernel_function(x, y)
Computes the cross partial derivative of the kernel function with respect to x and y.
- Parameters:
x (ArrayLike) – Input data point.
y (ArrayLike) – Input data point.
- Returns:
The cross partial derivative of the kernel function with respect to x and y.
- Return type:
ArrayLike
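The cross partial derivative can be read as the d × d matrix of mixed second derivatives ∂²k/∂x_i∂y_j. This page does not state whether the interface expects that matrix or only its trace, so the sketch below (hypothetical names, Gaussian kernel with lengthscale ℓ) returns the matrix and then takes the trace, which is the quantity in the standard Stein kernel. For the Gaussian kernel, ∂²k/∂x_i∂y_j = k(x, y) [δ_ij / ℓ² - (x_i - y_i)(x_j - y_j) / ℓ⁴].

```python
import numpy as np


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    return np.exp(-np.dot(diff, diff) / (2.0 * lengthscale**2))


def rbf_cross_partial(x, y, lengthscale=1.0):
    """Matrix H[i, j] = d^2 k / (dx_i dy_j) for the Gaussian kernel."""
    diff = np.asarray(x, float) - np.asarray(y, float)
    l2 = lengthscale**2
    k = rbf_kernel(x, y, lengthscale)
    # k * (delta_ij / l^2 - diff_i * diff_j / l^4)
    return k * (np.eye(diff.size) / l2 - np.outer(diff, diff) / l2**2)


x = np.array([0.3, -1.2])
y = np.array([1.0, 0.5])
H = rbf_cross_partial(x, y)

# The trace sum_i d^2 k / (dx_i dy_i); with lengthscale 1 it equals k * (d - ||x - y||^2).
diff = x - y
print(H.shape)  # (2, 2)
print(np.isclose(np.trace(H), rbf_kernel(x, y) * (x.size - diff @ diff)))  # True
```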
- class ksd_metric.kernel.KernelJax(base_kernel_function)
Bases: KernelInterface
Represents a kernel function for the KSD.
- __init__(base_kernel_function)
Initializes the kernel with a base kernel function.
- Parameters:
base_kernel_function (DifferentiableFunction) – A function that computes the kernel between two points.
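This page does not say how KernelJax obtains the derivatives. Since it receives only the base kernel, one plausible reading is that it relies on JAX automatic differentiation; the sketch below assumes that, and assumes a DifferentiableFunction is a pure jax.numpy function mapping two vectors to a scalar. The inverse multiquadric (IMQ) kernel, a common choice for KSD, and the names imq_kernel, grad_x and grad_y are hypothetical.

```python
import jax
import jax.numpy as jnp


def imq_kernel(x, y, c=1.0, beta=-0.5):
    """Hypothetical IMQ base kernel k(x, y) = (c^2 + ||x - y||^2)^beta, written with jax.numpy only."""
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta


# JAX can trace imq_kernel, so its partial derivatives come from autodiff.
grad_x = jax.grad(imq_kernel, argnums=0)  # d k / d x
grad_y = jax.grad(imq_kernel, argnums=1)  # d k / d y

x = jnp.array([0.5, -1.0])
y = jnp.array([0.0, 1.0])
print(imq_kernel(x, y))
print(grad_x(x, y), grad_y(x, y))
```

Writing the kernel with jax.numpy rather than NumPy is what lets JAX trace and differentiate it.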
- base_kernel_function(x, y)
Evaluates the base kernel function supplied at initialization and returns its value at (x, y), the base kernel used in KSD.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
- Returns:
The value of the base kernel function at (x, y).
- Return type:
Float[Array, "num"]
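KSD estimators evaluate the kernel over all pairs of samples. Assuming the base kernel maps two vectors of length num to a scalar, nested jax.vmap lifts it to a Gram matrix; gram_matrix and rbf_kernel are hypothetical helpers, not part of ksd_metric.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel on two vectors."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


def gram_matrix(kernel, xs, ys):
    """K[i, j] = kernel(xs[i], ys[j]) for xs of shape (n, d) and ys of shape (m, d)."""
    # Inner vmap maps over ys for a fixed x; outer vmap maps over xs.
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(ys))(xs)


xs = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gram_matrix(rbf_kernel, xs, xs)
print(K.shape)  # (3, 3)
```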
- partial_derivative_x_kernel_function(x, y)
Computes the partial derivative of the kernel function with respect to x.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
- Returns:
The partial derivative of the kernel function with respect to x.
- Return type:
Float[Array, "num"]
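With JAX, the x-derivative of a scalar kernel is jax.grad with argnums=0. The sketch (hypothetical rbf_kernel, lengthscale fixed at 1) compares it to the closed form -(x - y) k(x, y); it illustrates the technique and is not taken from the KernelJax source.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel exp(-||x - y||^2 / (2 l^2))."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


# Partial derivative w.r.t. the first argument, x.
rbf_grad_x = jax.grad(rbf_kernel, argnums=0)

x = jnp.array([0.3, -1.2])
y = jnp.array([1.0, 0.5])
analytic = -(x - y) * rbf_kernel(x, y)  # closed form for lengthscale = 1
print(jnp.allclose(rbf_grad_x(x, y), analytic))  # True
```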
- partial_derivative_y_kernel_function(x, y)
Computes the partial derivative of the kernel function with respect to y.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
- Returns:
The partial derivative of the kernel function with respect to y.
- Return type:
Float[Array, "num"]
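Likewise, the y-derivative is jax.grad with argnums=1. The sketch also uses jax.vmap with in_axes=(None, 0) to evaluate it for a batch of y values against a fixed x; the last row of ys equals x, so its gradient is zero. The IMQ kernel and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def imq_kernel(x, y, c=1.0, beta=-0.5):
    """Hypothetical IMQ base kernel (c^2 + ||x - y||^2)^beta."""
    return (c**2 + jnp.sum((x - y) ** 2)) ** beta


imq_grad_x = jax.grad(imq_kernel, argnums=0)  # derivative w.r.t. x
imq_grad_y = jax.grad(imq_kernel, argnums=1)  # derivative w.r.t. y

x = jnp.array([0.5, -1.0])
ys = jnp.array([[0.0, 1.0], [2.0, 0.0], [0.5, -1.0]])

# Gradient w.r.t. y for every row of ys; in_axes=None keeps x fixed (not batched).
grads_y = jax.vmap(imq_grad_y, in_axes=(None, 0))(x, ys)
print(grads_y.shape)  # (3, 2)
```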
- cross_partial_derivative_kernel_function(x, y)
Computes the cross partial derivative of the kernel function with respect to x and y.
- Parameters:
x (Float[Array, "num"]) – Input data point.
y (Float[Array, "num"]) – Input data point.
- Returns:
The cross partial derivative of the kernel function with respect to x and y.
- Return type:
Float[Array, "num"]
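For the cross partial derivative, composing jax.jacfwd over jax.grad gives the d × d matrix of mixed second derivatives, and its trace is the term used in the standard Stein kernel. This page does not state whether KernelJax returns the matrix or the trace, so treat this as a sketch under that uncertainty (hypothetical rbf_kernel, lengthscale 1).

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Hypothetical Gaussian base kernel."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


# jax.grad(..., argnums=1) gives the y-gradient (shape (d,)); jacfwd of that
# w.r.t. x gives H[i, j] = d^2 k / (dy_i dx_j), a (d, d) matrix.
cross_partial = jax.jacfwd(jax.grad(rbf_kernel, argnums=1), argnums=0)

x = jnp.array([0.3, -1.2])
y = jnp.array([1.0, 0.5])
H = cross_partial(x, y)
trace_term = jnp.trace(H)  # sum_i d^2 k / (dx_i dy_i)

# Closed form of the trace for the Gaussian kernel with lengthscale 1:
# k(x, y) * (d - ||x - y||^2)
diff = x - y
analytic_trace = rbf_kernel(x, y) * (x.size - jnp.sum(diff**2))
print(H.shape)  # (2, 2)
```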