target

class ksd_metric.target.TargetDistributionInterface

Bases: ABC

Interface for target distributions used in Kernel Stein Discrepancy (KSD).
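This is why the interface exposes grad_log_target_pdf. In the standard (Langevin) formulation of KSD, the target p enters only through its score ∇ log p. The score is combined with a base kernel k into the Stein kernel k_p. The squared discrepancy between a sample distribution q and p is the expectation of k_p under q. Because only the gradient of log p appears, an unnormalized log-density is sufficient. This is the textbook form; this page does not state which kernel or estimator ksd_metric itself uses.

```latex
k_p(x, y) = \nabla_x \log p(x)^\top \nabla_y \log p(y)\, k(x, y)
          + \nabla_x \log p(x)^\top \nabla_y k(x, y)
          + \nabla_y \log p(y)^\top \nabla_x k(x, y)
          + \operatorname{tr}\!\left(\nabla_x \nabla_y^\top k(x, y)\right)

\mathrm{KSD}^2(q, p) = \mathbb{E}_{x, y \sim q}\left[ k_p(x, y) \right]
```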

abstract log_target_pdf(x)

Log probability density function of the target, evaluated at x.

Parameters:

x (ArrayLike) – Input data point.

Returns:

Logarithmic probability density at x.

Return type:

ArrayLike

abstract grad_log_target_pdf(x)

Gradient of the log probability density function with respect to x (the score function).

Parameters:

x (ArrayLike) – Input data point.

Returns:

Gradient of logarithmic probability density at x.

Return type:

ArrayLike
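A concrete target only has to supply the two abstract methods above. The sketch below is a hypothetical StandardNormalTarget for N(0, I), whose score has the closed form -x. It assumes inputs are batches of shape (num, dim) and returns one log-density per row. The fallback class is a stand-in that mirrors the documented methods. It is used only when ksd_metric is not installed, and it is not the package's own code.

```python
import jax.numpy as jnp

try:
    from ksd_metric.target import TargetDistributionInterface
except ImportError:
    # Stand-in mirroring the two documented abstract methods, used only so
    # this sketch runs without ksd_metric installed.
    from abc import ABC, abstractmethod

    class TargetDistributionInterface(ABC):
        @abstractmethod
        def log_target_pdf(self, x): ...

        @abstractmethod
        def grad_log_target_pdf(self, x): ...


class StandardNormalTarget(TargetDistributionInterface):
    """Hypothetical example target: standard normal N(0, I) with a closed-form score."""

    def log_target_pdf(self, x):
        # x has shape (num, dim); returns one normalized log-density per row, shape (num,)
        x = jnp.asarray(x)
        dim = x.shape[-1]
        return -0.5 * jnp.sum(x**2, axis=-1) - 0.5 * dim * jnp.log(2 * jnp.pi)

    def grad_log_target_pdf(self, x):
        # Score of N(0, I) is -x, returned with the same shape as x
        return -jnp.asarray(x)


target = StandardNormalTarget()
pts = jnp.array([[0.0, 0.0], [1.0, -2.0]])
print(target.log_target_pdf(pts))
print(target.grad_log_target_pdf(pts))
```

Writing the score by hand like this avoids automatic differentiation entirely, which is only practical when the gradient has a simple closed form.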

class ksd_metric.target.TargetDistributionJax(log_target_pdf)

Bases: TargetDistributionInterface

Target distribution for the KSD, defined by a user-supplied log probability density function that JAX can differentiate.

__init__(log_target_pdf)

Initializes the target distribution from a log probability density function.

Parameters:

log_target_pdf (DifferentiableFunction) – A function that computes the log of the target PDF.
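The argument is expected to be a function that JAX can differentiate, which in practice means it is written with jax.numpy operations. The sketch below defines such a function for an isotropic Gaussian with a hypothetical mean MU. It checks that jax.grad recovers the analytic score -(x - MU). It assumes the function acts on a single point of shape (dim,); this page does not say whether TargetDistributionJax expects a single point or a batch. The normalizing constant is dropped because it does not affect the gradient.

```python
import jax
import jax.numpy as jnp

MU = jnp.array([1.0, -1.0])  # hypothetical target mean


def log_pdf(x):
    """Unnormalized log-density of N(MU, I) at a single point x of shape (dim,).

    The constant -dim/2 * log(2*pi) is omitted because it does not change the gradient.
    """
    return -0.5 * jnp.sum((x - MU) ** 2)


# jax.grad succeeds because log_pdf uses only differentiable jax.numpy operations
score = jax.grad(log_pdf)

x = jnp.array([0.5, 0.5])
print(score(x))  # analytic score -(x - MU) equals [0.5, -1.5]

# Hypothetical usage, assuming ksd_metric is installed:
# from ksd_metric.target import TargetDistributionJax
# target = TargetDistributionJax(log_pdf)
```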

log_target_pdf(x)

Computes the log probability density at the given point.

Parameters:

x (Float[Array, "num dim"]) – Input data points, an array of shape (num, dim).

Returns:

Logarithmic probability density at x.

Return type:

Float[Array, "num dim"]
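The input is annotated as Float[Array, "num dim"], which is a batch of num points in dim dimensions. If a log-density is written for a single point of shape (dim,), jax.vmap is the standard JAX way to lift it to such a batch. The sketch below assumes that convention; it says nothing about how TargetDistributionJax handles batching internally.

```python
import jax
import jax.numpy as jnp


def log_pdf(x):
    # Unnormalized standard-normal log-density at one point of shape (dim,)
    return -0.5 * jnp.sum(x**2)


# vmap maps log_pdf over the leading "num" axis: (num, dim) -> (num,)
batched_log_pdf = jax.vmap(log_pdf)

xs = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 0.0]])
print(batched_log_pdf(xs))
```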

grad_log_target_pdf(x)

Computes the gradient of the log probability density (the score) at the given points.

Parameters:

x (Float[Array, "num dim"]) – Input data points, an array of shape (num, dim).

Returns:

Gradient of logarithmic probability density at x.

Return type:

Float[Array, "num dim"]
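The score returned here is what a KSD estimator actually consumes. As an illustration only, the sketch below computes batch scores with jax.vmap(jax.grad(...)), which is one plausible approach and not necessarily what TargetDistributionJax does. It then plugs them into the textbook V-statistic with an RBF base kernel. The helper ksd_vstat_rbf and its bandwidth h are hypothetical names, not part of ksd_metric, and the package may use a different kernel or estimator. Samples drawn from the target should give a value near zero, while shifted samples should give a clearly larger one.

```python
import jax
import jax.numpy as jnp


def log_pdf(x):
    # Unnormalized standard-normal log-density at one point of shape (dim,)
    return -0.5 * jnp.sum(x**2)


# grad handles one point; vmap maps it over the "num" axis: (num, dim) -> (num, dim)
batched_score = jax.vmap(jax.grad(log_pdf))


def ksd_vstat_rbf(xs, scores, h=1.0):
    """Hypothetical helper: squared-KSD V-statistic with RBF kernel exp(-|x - y|^2 / (2 h^2))."""
    dim = xs.shape[1]
    diff = xs[:, None, :] - xs[None, :, :]   # (num, num, dim), entries x_i - x_j
    sq = jnp.sum(diff**2, axis=-1)           # (num, num), squared distances
    k = jnp.exp(-sq / (2 * h**2))            # base kernel values
    s_dot = scores @ scores.T                # s(x_i) . s(x_j)
    # Both cross terms combine into (s(x_i) - s(x_j)) . (x_i - x_j) / h^2
    cross = jnp.sum((scores[:, None, :] - scores[None, :, :]) * diff, axis=-1) / h**2
    trace = dim / h**2 - sq / h**4           # tr(grad_x grad_y k) divided by k
    stein_kernel = k * (s_dot + cross + trace)
    return jnp.mean(stein_kernel)            # average over all (i, j) pairs


key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, (500, 2))   # drawn from the target N(0, I)
shifted = samples + 1.0                       # deliberately off-target
ksd_on = ksd_vstat_rbf(samples, batched_score(samples))
ksd_shift = ksd_vstat_rbf(shifted, batched_score(shifted))
print(ksd_on, ksd_shift)
```

Because only gradients of log_pdf are used, dropping its normalizing constant leaves both values unchanged.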