stein
- class ksd_metric.stein.KernelSteinDiscrepancyInterface(target, kernel)[source]
Bases:
ABC
- __init__(target, kernel)[source]
Initializes the KernelSteinDiscrepancyInterface with a target distribution and a kernel.
- Parameters:
target (TargetDistributionInterface) – The target distribution interface.
kernel (KernelInterface) – The kernel interface used for computing the Stein kernel.
- abstract stein_kernel(x, y)[source]
Computes the Stein kernel using the base kernel function and the gradient of the log target PDF.
- Parameters:
x (Float[Array, "num"]) – First input data point.
y (Float[Array, "num"]) – Second input data point.
- Returns:
The value of the Stein kernel at (x, y).
- Return type:
Float[Array, "num"]
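For orientation, one widely used definition of a Stein kernel is the Langevin form below. It combines a base kernel k with the score function of the target density p, that is, the gradient of the log target PDF. Whether ksd_metric uses exactly this form is an assumption; the symbols k_p and s_p are introduced here for illustration only.

```latex
s_p(x) = \nabla_x \log p(x)

k_p(x, y) = \nabla_x \cdot \nabla_y k(x, y)
          + \nabla_x k(x, y)^\top s_p(y)
          + \nabla_y k(x, y)^\top s_p(x)
          + k(x, y) \, s_p(x)^\top s_p(y)
```

The first term is the trace of the mixed second derivative of k. Only the score s_p appears, so p needs to be known only up to its normalising constant.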
- class ksd_metric.stein.KernelSteinDiscrepancyJax(target, kernel)[source]
Bases:
KernelSteinDiscrepancyInterface
- __init__(target, kernel)[source]
Initializes the KernelSteinDiscrepancyJax with a target distribution and a kernel.
- Parameters:
target (TargetDistributionJax) – The target distribution.
kernel (KernelJax) – The kernel function.
- stein_kernel(x, y)[source]
Computes the Stein kernel using the base kernel function and the gradient of the log target PDF.
- Parameters:
x (Float[Array, "num"]) – First input data point.
y (Float[Array, "num"]) – Second input data point.
- Returns:
The value of the Stein kernel at (x, y).
- Return type:
Float[Array, "num"]
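The following is a minimal, self-contained sketch of how a Stein kernel can be assembled in JAX from a base kernel and the gradient of a log target PDF. It is not the ksd_metric implementation: the function names (rbf_kernel, log_pdf_standard_normal, langevin_stein_kernel, ksd_squared_v_statistic), the Gaussian base kernel, the standard normal target, and the Langevin form of the Stein kernel are all assumptions chosen for illustration. The sketch returns a scalar per pair of points, which may differ from the return shape documented above.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Gaussian (RBF) base kernel; a hypothetical stand-in for a KernelJax."""
    diff = x - y
    return jnp.exp(-0.5 * jnp.dot(diff, diff) / lengthscale**2)


def log_pdf_standard_normal(x):
    """Unnormalised log density of N(0, I); a hypothetical stand-in target."""
    return -0.5 * jnp.dot(x, x)


def langevin_stein_kernel(x, y, kernel=rbf_kernel, log_pdf=log_pdf_standard_normal):
    """Langevin Stein kernel k_p(x, y) for 1-D arrays x and y of equal length.

    k_p = tr(d^2 k / dx dy) + grad_x k . s(y) + grad_y k . s(x) + k * s(x) . s(y),
    where s is the gradient of the log target PDF (the score).
    """
    score = jax.grad(log_pdf)
    s_x, s_y = score(x), score(y)

    grad_x_k = jax.grad(kernel, argnums=0)(x, y)
    grad_y_k = jax.grad(kernel, argnums=1)(x, y)

    # Jacobian of grad_y k with respect to x; its trace is sum_i d^2 k / dx_i dy_i.
    mixed = jax.jacfwd(jax.grad(kernel, argnums=1), argnums=0)(x, y)

    return (
        jnp.trace(mixed)
        + jnp.dot(grad_x_k, s_y)
        + jnp.dot(grad_y_k, s_x)
        + kernel(x, y) * jnp.dot(s_x, s_y)
    )


def ksd_squared_v_statistic(samples):
    """Mean of the Stein kernel over all sample pairs (V-statistic estimate)."""
    row = jax.vmap(langevin_stein_kernel, in_axes=(None, 0))
    gram = jax.vmap(row, in_axes=(0, None))(samples, samples)
    return jnp.mean(gram)


samples = jnp.array([[1.0, 0.0], [0.0, 0.0]])
print(langevin_stein_kernel(samples[0], samples[0]))
print(ksd_squared_v_statistic(samples))
```

Automatic differentiation supplies every derivative here, so swapping in another differentiable kernel or log PDF requires no hand-derived gradients. The nested jax.vmap evaluates the kernel over all sample pairs; averaging that matrix gives one common estimate of the squared kernel Stein discrepancy.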