Customization Using the Interfaces
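Here is a complete code example demonstrating how to use a custom target distribution and a custom base kernel function. The target is a standard Gaussian (log-density −½‖x‖² up to an additive constant, with score −x), and the kernel is the inverse multiquadric k(x, y) = (1 + ‖x − y‖²)^(−1/2).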
1import jax
2from jax import numpy as jnp
3from jaxtyping import Array, Float
4
5from ksd_metric.kernel import KernelJax
6from ksd_metric.stein import KernelSteinDiscrepancyJax
7from ksd_metric.target import TargetDistributionInterface
8
9key = jax.random.PRNGKey(42)
10dim = 2
11mean = jnp.zeros(dim)
12cov = jnp.eye(dim)
13N = 10_000
14x = jax.random.multivariate_normal(key, mean, cov, shape=(N,))
15
16
class CustomTargetDistribution(TargetDistributionInterface):
    """Standard multivariate Gaussian target N(0, I), in any dimension.

    Only the unnormalized log-density and its gradient (the score) are
    provided. The normalizing constant ``-dim / 2 * log(2 * pi)`` is omitted
    because it does not change the score returned by ``grad_log_target_pdf``.
    """

    def log_target_pdf(self, x: Float[Array, "*batch dim"]) -> Float[Array, "*batch"]:
        """Return the unnormalized log-density ``-0.5 * ||x||^2``.

        The squared norm is taken over the last axis, so a single point of
        shape ``(dim,)`` gives a scalar and a batch of shape ``(num, dim)``
        gives one value per row.
        """
        # Reduce over the last axis rather than writing ``x @ x``: for a 1-D
        # point the value is the same, but on a (num, dim) batch ``x @ x`` is
        # a matrix product (shape error, or a wrong matrix when num == dim).
        return -0.5 * jnp.sum(x * x, axis=-1)

    def grad_log_target_pdf(
        self, x: Float[Array, "*batch dim"]
    ) -> Float[Array, "*batch dim"]:
        """Return the score ``grad_x log p(x) = -x``, with the same shape as ``x``."""
        return -x
25
26
def custom_kernel(
    x: Float[Array, "*batch dim"],
    y: Float[Array, "*batch dim"],
) -> Float[Array, "*batch"]:
    """Inverse multiquadric (IMQ) base kernel ``k(x, y) = (1 + ||x - y||^2) ** -0.5``.

    Args:
        x: A point of shape ``(dim,)``; leading batch axes broadcast against ``y``.
        y: A point of shape ``(dim,)``; leading batch axes broadcast against ``x``.

    Returns:
        The kernel value: a scalar for two single points, otherwise one value
        per broadcast batch element.
    """
    diff = x - y
    # Squared Euclidean distance over the last axis. This equals
    # ``diff @ jnp.eye(dim) @ diff`` without materializing a dim x dim
    # identity matrix, and it does not rely on ``len(x)`` being the dimension.
    sq_dist = jnp.sum(diff * diff, axis=-1)
    return (1.0 + sq_dist) ** (-0.5)
34
35
# Assemble the estimator from the custom pieces: the base kernel function is
# wrapped in KernelJax, and the Gaussian target supplies the score.
kernel = KernelJax(custom_kernel)
target = CustomTargetDistribution()
ksd = KernelSteinDiscrepancyJax(kernel=kernel, target=target)

# Evaluate the kernel Stein discrepancy of the samples `x` against the target.
res = ksd.kernel_stein_discrepancy(x)
print(res)