Quick Start
Here is a complete code example demonstrating how to use `KernelSteinDiscrepancyJax`:
import jax
from jax import numpy as jnp
from jax.scipy.stats import multivariate_normal

from ksd_metric.kernel import KernelJax
from ksd_metric.stein import KernelSteinDiscrepancyJax
from ksd_metric.target import TargetDistributionJax
from ksd_metric.utils import JaxKernelFunction

# Example usage of KernelSteinDiscrepancyJax.
# Draw N samples from a 2-D Gaussian N(mean, cov). The target defined below is
# the same Gaussian, so the resulting discrepancy is expected to be small.
key = jax.random.PRNGKey(42)
dim = 2
mean = jnp.zeros(dim)
cov = jnp.eye(dim)
N = 10_000
x = jax.random.multivariate_normal(key, mean, cov, shape=(N,))


# Define the target distribution (through its log-density) and the kernel.
def log_target_pdf(z):
    """Return the log-density of N(mean, cov) evaluated at ``z``."""
    return multivariate_normal.logpdf(z, mean=mean, cov=cov)


target = TargetDistributionJax(log_target_pdf=log_target_pdf)
# IMQ kernel (presumably inverse multiquadric); the identity matrix is passed as
# its third argument — likely a scaling/preconditioning matrix, see
# ksd_metric.utils.JaxKernelFunction.imq to confirm.
kernel = KernelJax(lambda x, y: JaxKernelFunction.imq(x, y, jnp.eye(dim)))
ksd = KernelSteinDiscrepancyJax(target=target, kernel=kernel)

# Calculate the kernel Stein discrepancy of the samples against the target.
res = ksd.kernel_stein_discrepancy(x)
print(res)