stein

class ksd_metric.stein.KernelSteinDiscrepancyInterface(target, kernel)[source]

Bases: ABC

__init__(target, kernel)[source]

Initializes the KernelSteinDiscrepancyInterface with a target distribution and a kernel.

Parameters:
  • target – The target distribution.

  • kernel – The base kernel function.

abstract stein_kernel(x, y)[source]

Computes the Stein kernel using the base kernel function and the gradient of the log target PDF.

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The value of the Stein kernel at (x, y).

Return type:

Float[Array, "num"]

abstract kernel_stein_discrepancy(samples)[source]

Computes the kernel Stein discrepancy for the given samples.

Parameters:

samples (ArrayLike) – A collection of samples to be assessed against the target distribution.

Returns:

The value of the kernel Stein discrepancy for the given samples.

Return type:

float
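
The abstract-base-class pattern described above can be sketched as follows. This is a minimal illustration, not the library's source: the class names `DiscrepancyInterfaceSketch` and `ToyDiscrepancy`, the stored attribute names, and the toy method bodies are all assumptions made for the example. It shows only that the base class cannot be instantiated and that a subclass must implement both abstract methods.

```python
from abc import ABC, abstractmethod


class DiscrepancyInterfaceSketch(ABC):
    """Hypothetical stand-in for the interface; attribute names are assumptions."""

    def __init__(self, target, kernel):
        self.target = target  # assumed: object describing the target distribution
        self.kernel = kernel  # assumed: base kernel callable k(x, y)

    @abstractmethod
    def stein_kernel(self, x, y):
        """Return the Stein kernel evaluated at (x, y)."""

    @abstractmethod
    def kernel_stein_discrepancy(self, samples):
        """Return the discrepancy for a collection of samples."""


class ToyDiscrepancy(DiscrepancyInterfaceSketch):
    """Toy subclass: reuses the base kernel in place of a real Stein kernel."""

    def stein_kernel(self, x, y):
        return self.kernel(x, y)

    def kernel_stein_discrepancy(self, samples):
        # Square root of the average kernel value over all sample pairs.
        n = len(samples)
        total = sum(self.stein_kernel(a, b) for a in samples for b in samples)
        return (total / n**2) ** 0.5


# The abstract base refuses instantiation because its methods are unimplemented.
try:
    DiscrepancyInterfaceSketch(target=None, kernel=None)
    base_is_instantiable = True
except TypeError:
    base_is_instantiable = False

# A concrete subclass that implements both methods can be constructed and used.
toy = ToyDiscrepancy(target=None, kernel=lambda a, b: 1.0)
toy_value = toy.kernel_stein_discrepancy([0.0, 1.0, 2.0])
```

Keeping the two methods abstract lets alternative backends share one call signature while supplying their own numerics.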

class ksd_metric.stein.KernelSteinDiscrepancyJax(target, kernel)[source]

Bases: KernelSteinDiscrepancyInterface

__init__(target, kernel)[source]

Initializes the KernelSteinDiscrepancyJax with a target distribution and a kernel.

Parameters:
  • target – The target distribution.

  • kernel – The base kernel function.

stein_kernel(x, y)[source]

Computes the Stein kernel using the base kernel function and the gradient of the log target PDF.

Parameters:
  • x (Float[Array, "num"]) – First input data point.

  • y (Float[Array, "num"]) – Second input data point.

Returns:

The value of the Stein kernel at (x, y).

Return type:

Float[Array, "num"]
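
One common way to combine a base kernel with the gradient of the log target PDF is the Langevin Stein kernel, sketched below with JAX automatic differentiation. Whether this class uses exactly this form is not stated here, so treat it as an assumption. The function names `rbf_kernel`, `log_standard_normal`, and `langevin_stein_kernel` are hypothetical, the Gaussian kernel and standard normal target are chosen only for illustration, and this sketch returns a scalar.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) base kernel; an illustrative choice, not the library's."""
    diff = x - y
    return jnp.exp(-jnp.dot(diff, diff) / (2.0 * bandwidth**2))


def log_standard_normal(x):
    """Unnormalized log density of a standard normal target (illustrative)."""
    return -0.5 * jnp.dot(x, x)


def langevin_stein_kernel(x, y, log_pdf=log_standard_normal, kernel=rbf_kernel):
    """Langevin Stein kernel built from a base kernel and the target's score."""
    # Score function: gradient of the log target PDF at each point.
    score_x = jax.grad(log_pdf)(x)
    score_y = jax.grad(log_pdf)(y)
    # First derivatives of the base kernel with respect to each argument.
    grad_x_k = jax.grad(kernel, argnums=0)(x, y)
    grad_y_k = jax.grad(kernel, argnums=1)(x, y)
    # Matrix of mixed second derivatives d^2 k / (dy_i dx_j); its trace is used.
    mixed = jax.jacfwd(jax.grad(kernel, argnums=1), argnums=0)(x, y)
    return (
        jnp.trace(mixed)
        + jnp.dot(grad_x_k, score_y)
        + jnp.dot(grad_y_k, score_x)
        + kernel(x, y) * jnp.dot(score_x, score_y)
    )


point = jnp.array([1.0, 2.0])
value_on_diagonal = langevin_stein_kernel(point, point)
```

The normalizing constant of the target never appears, because only the gradient of the log density is used.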

kernel_stein_discrepancy(samples)[source]

Computes the kernel Stein discrepancy for the given samples.

Parameters:

samples (Float[Array, "num dim"]) – A collection of samples to be assessed against the target distribution.

Returns:

The value of the kernel Stein discrepancy for the given samples.

Return type:

float
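
A discrepancy of this kind is often estimated as a V-statistic: the square root of the mean of the Stein kernel over all sample pairs. The sketch below assumes that estimator, which may differ from the one this method uses (a U-statistic is another common choice). The names `stein_kernel_gaussian_rbf` and `ksd_v_statistic` are hypothetical, and the closed-form Stein kernel applies only to the illustrative case of a standard normal target with a unit-bandwidth Gaussian base kernel.

```python
import jax
import jax.numpy as jnp


def stein_kernel_gaussian_rbf(x, y):
    """Closed-form Langevin Stein kernel for a standard normal target and a
    unit-bandwidth RBF base kernel (an illustrative special case)."""
    diff = x - y
    sq_dist = jnp.dot(diff, diff)
    base = jnp.exp(-0.5 * sq_dist)
    return base * (x.shape[0] - 2.0 * sq_dist + jnp.dot(x, y))


def ksd_v_statistic(samples, stein_kernel_fn=stein_kernel_gaussian_rbf):
    """Square root of the mean of the (num, num) Stein kernel Gram matrix."""
    # Inner vmap runs over y, outer vmap over x, giving every pairwise value.
    gram = jax.vmap(
        lambda x: jax.vmap(lambda y: stein_kernel_fn(x, y))(samples)
    )(samples)
    return jnp.sqrt(jnp.mean(gram))


key = jax.random.PRNGKey(0)
on_target = jax.random.normal(key, (200, 2))  # drawn from the standard normal target
off_target = on_target + 3.0                  # deliberately shifted away from it

ksd_on_target = float(ksd_v_statistic(on_target))
ksd_off_target = float(ksd_v_statistic(off_target))
```

In this setup the shifted samples should yield the larger value, reflecting a worse match to the target.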