# Source code for opengnc.kalman_filters.enkf

"""
Ensemble Kalman Filter (EnKF) using Monte Carlo samples for covariance representation.
"""

from collections.abc import Callable
from typing import Any, cast

from collections.abc import Callable
from typing import Any

import numpy as np


[docs] class EnKF: """ Ensemble Kalman Filter (EnKF). Uses an ensemble of states to represent the error covariance matrix. Highly efficient for high-dimensional systems (e.g., weather/climate models) where the full covariance matrix is too large to compute. Parameters ---------- dim_x : int Dimension of the state vector. dim_z : int Dimension of the measurement vector. ensemble_size : int, optional Number of ensemble members (N). Default is 50. """ def __init__(self, dim_x: int, dim_z: int, ensemble_size: int = 50) -> None: self.dim_x = dim_x self.dim_z = dim_z self.num_ensemble = ensemble_size # Ensemble of states: shape (dim_x, N) self.X = np.zeros((dim_x, self.num_ensemble)) self.Q = np.eye(dim_x) self.R = np.eye(dim_z)
[docs] def initialize_ensemble(self, x_mean: np.ndarray, p_cov: np.ndarray) -> None: """ Initialize the ensemble using a multivariate normal distribution. Parameters ---------- x_mean : np.ndarray Mean initial state (dim_x,). p_cov : np.ndarray Initial state error covariance (dim_x, dim_x). """ self.X = cast( np.ndarray, np.random.multivariate_normal(x_mean, p_cov, self.num_ensemble).T )
[docs] def predict( self, dt: float, fx_func: Callable, q_mat: np.ndarray | None = None, **kwargs: Any, ) -> None: r""" Predict step (Propagates each ensemble member). Parameters ---------- dt : float Time step (s). fx_func : Callable Nonlinear state transition function $f(x, dt, **kwargs) \to x_{new}$. q_mat : np.ndarray, optional Process noise covariance (dim_x, dim_x). If None, uses `self.Q`. **kwargs : Any Additional arguments passed to transition function. """ q_curr = q_mat if q_mat is not None else self.Q # Propagate each ensemble member for i in range(self.num_ensemble): # Propagate through nonlinear model self.X[:, i] = cast(np.ndarray, fx_func(self.X[:, i], dt, **kwargs)) # Add process noise to each member noise = np.random.multivariate_normal(np.zeros(self.dim_x), q_curr) self.X[:, i] += noise
[docs] def update( self, z: np.ndarray, hx_func: Callable, r_mat: np.ndarray | None = None, **kwargs: Any, ) -> None: r""" Update step (Ensemble transformation). Parameters ---------- z : np.ndarray Measurement vector (dim_z,). hx_func : Callable Nonlinear measurement function $h(x, **kwargs) \to z_{pred}$. r_mat : np.ndarray, optional Measurement noise covariance (dim_z, dim_z). If None, uses `self.R`. **kwargs : Any Additional arguments passed to measurement function. """ r_curr = r_mat if r_mat is not None else self.R # Transform ensemble to measurement space z_ensemble = np.zeros((self.dim_z, self.num_ensemble)) for i in range(self.num_ensemble): z_ensemble[:, i] = cast(np.ndarray, hx_func(self.X[:, i], **kwargs)) # Sample mean of measurement ensemble z_mean = np.mean(z_ensemble, axis=1, keepdims=True) # Calculate ensemble anomalies (perturbations) x_mean_vec = np.mean(self.X, axis=1, keepdims=True) anomalies_x = self.X - x_mean_vec # State anomalies anomalies_z = z_ensemble - z_mean # Measurement anomalies # Perturbed measurements (adding noise for each ensemble member) z_perturbed = np.zeros((self.dim_z, self.num_ensemble)) for i in range(self.num_ensemble): noise = np.random.multivariate_normal(np.zeros(self.dim_z), r_curr) z_perturbed[:, i] = z + noise # Innovation innov_ensemble = z_perturbed - z_ensemble # Innovation covariance S = (1/(N-1)) * B * B.T + R s_mat = (1.0 / (self.num_ensemble - 1)) * (anomalies_z @ anomalies_z.T) + r_curr # Cross-covariance Pxz = (1/(N-1)) * A * B.T pxz = (1.0 / (self.num_ensemble - 1)) * (anomalies_x @ anomalies_z.T) # Kalman Gain K = Pxz * inv(S) k_gain = pxz @ np.linalg.inv(s_mat) # Correct each ensemble member self.X += k_gain @ innov_ensemble
@property def x(self) -> np.ndarray: """Ensemble mean state vector.""" return cast(np.ndarray, np.mean(self.X, axis=1)) @property def P(self) -> np.ndarray: """Ensemble covariance matrix.""" anomalies_x = self.X - np.mean(self.X, axis=1, keepdims=True) return cast(np.ndarray, (1.0 / (self.num_ensemble - 1)) * (anomalies_x @ anomalies_x.T))