Customization Using Interfaces

Here is a complete code example demonstrating how to use a custom target distribution and a custom base kernel function.

 1import jax
 2from jax import numpy as jnp
 3from jaxtyping import Array, Float
 4
 5from ksd_metric.kernel import KernelJax
 6from ksd_metric.stein import KernelSteinDiscrepancyJax
 7from ksd_metric.target import TargetDistributionInterface
 8
 9key = jax.random.PRNGKey(42)
10dim = 2
11mean = jnp.zeros(dim)
12cov = jnp.eye(dim)
13N = 10_000
14x = jax.random.multivariate_normal(key, mean, cov, shape=(N,))
15
16
class CustomTargetDistribution(TargetDistributionInterface):
    """Standard multivariate normal target N(0, I), in any dimension.

    The log-density is left unnormalized: the additive constant
    ``-dim / 2 * log(2 * pi)`` is omitted, which does not change its gradient.
    """

    def log_target_pdf(
        self, x: Float[Array, "*num dim"]
    ) -> Float[Array, "*num"]:
        """Return the unnormalized log-density ``-0.5 * ||x||^2``.

        Args:
            x: A single point of shape ``(dim,)`` or a batch ``(num, dim)``.

        Returns:
            A scalar for a single point, or one value per row for a batch.
        """
        # Reduce over the last axis only: ``x @ x`` would be a matrix product
        # on a batch, failing unless num == dim and wrong when it succeeds.
        return -0.5 * jnp.sum(x * x, axis=-1)

    def grad_log_target_pdf(
        self, x: Float[Array, "*num dim"]
    ) -> Float[Array, "*num dim"]:
        """Return the score ``grad_x log p(x) = -x`` (same shape as ``x``)."""
        return -x
25
26
def custom_kernel(
    x: Float[Array, "dim"],
    y: Float[Array, "dim"],
) -> Float[Array, ""]:
    """Inverse multiquadric (IMQ) kernel ``k(x, y) = (1 + ||x - y||^2)^(-1/2)``.

    Args:
        x: A single point (1-D vector).
        y: A single point with the same length as ``x``.

    Returns:
        The scalar kernel value, which lies in ``(0, 1]`` and equals 1 when
        ``x == y``.
    """
    diff = x - y
    # ``diff @ diff`` is the squared Euclidean distance directly; there is no
    # need to materialize a dim x dim identity matrix for it on every call.
    return (1.0 + diff @ diff) ** (-0.5)
34
35
# Instantiate the user-defined target and wrap the plain kernel function in
# ``KernelJax``, then combine both into the Stein discrepancy estimator.
target, kernel = CustomTargetDistribution(), KernelJax(custom_kernel)
ksd = KernelSteinDiscrepancyJax(kernel=kernel, target=target)

# Evaluate the kernel Stein discrepancy of the samples ``x`` and report it.
res = ksd.kernel_stein_discrepancy(x)
print(res)