dagma.linear.DagmaLinear.fit

dagma.linear.DagmaLinear.fit(X: numpy.ndarray, lambda1: float = 0.03, w_threshold: float = 0.3, T: int = 5, mu_init: float = 1.0, mu_factor: float = 0.1, s: list[float] | float = [1.0, 0.9, 0.8, 0.7, 0.6], warm_iter: int = 30000.0, max_iter: int = 60000.0, lr: float = 0.0003, checkpoint: int = 1000, beta_1: float = 0.99, beta_2: float = 0.999, exclude_edges: list[tuple[int, int]] | None = None, include_edges: list[tuple[int, int]] | None = None) -> numpy.ndarray

Runs the DAGMA algorithm and returns a weighted adjacency matrix.

Parameters:
X : np.ndarray

Dataset of shape \((n, d)\), with \(n\) samples and \(d\) variables.

lambda1 : float, optional

Coefficient of the L1 penalty. Defaults to 0.03.

w_threshold : float, optional

Edges whose weight has absolute value below this threshold are removed (set to zero). Defaults to 0.3.

T : int, optional

Number of outer DAGMA iterations, indexed by \(t = 1, \dots, T\). Defaults to 5.

mu_init : float, optional

Initial value of \(\mu\). Defaults to 1.0.

mu_factor : float, optional

Decay factor for \(\mu\): after each outer iteration, \(\mu\) is multiplied by this factor. Defaults to 0.1.

s : Union[List[float], float], optional

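Controls the domain of M-matrices, i.e., the set of \(W\) for which \(sI - W \circ W\) is an M-matrix. A list supplies one value per outer iteration; a single float is used at every iteration. Defaults to [1.0, 0.9, 0.8, 0.7, 0.6].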

warm_iter : int, optional

Number of iterations for minimize() for \(t < T\). Defaults to 3e4.

max_iter : int, optional

Number of iterations for minimize() for \(t = T\). Defaults to 6e4.

lr : float, optional

Learning rate of the Adam optimizer. Defaults to 0.0003.

checkpoint : int, optional

If verbose is True, progress is printed to stdout every checkpoint iterations. Defaults to 1000.

beta_1 : float, optional

Adam hyperparameter \(\beta_1\), the decay rate of the running first-moment (gradient mean) estimate. Defaults to 0.99.

beta_2 : float, optional

Adam hyperparameter \(\beta_2\), the decay rate of the running second-moment (squared gradient) estimate. Defaults to 0.999.

exclude_edges : Optional[List[Tuple[int, int]]], optional

List of edges \((i, j)\) that should be excluded from the DAG solution, e.g., [(1, 3), (2, 4), (5, 1)]. Defaults to None.

include_edges : Optional[List[Tuple[int, int]]], optional

List of edges \((i, j)\) that should be included in the DAG solution, e.g., [(1, 3), (2, 4), (5, 1)]. Defaults to None.
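The effect of lambda1 can be seen in the penalized data-fitting objective. The sketch below assumes a least-squares loss for the linear model \(X \approx XW\), that is \(\frac{1}{2n}\lVert X - XW\rVert_F^2 + \lambda_1 \lVert W\rVert_1\). The function name l1_penalized_score is hypothetical and not part of the dagma API, and the library may compute its loss differently.

```python
import numpy as np

def l1_penalized_score(W: np.ndarray, X: np.ndarray, lambda1: float) -> float:
    """Hypothetical helper: least-squares loss of X ~ X @ W plus lambda1 * ||W||_1.

    Assumes a squared loss; this is an illustration, not dagma's implementation.
    """
    n = X.shape[0]
    residual = X - X @ W                     # (n, d) reconstruction error of the linear SEM
    loss = 0.5 / n * np.sum(residual ** 2)   # squared Frobenius norm scaled by 1/(2n)
    penalty = lambda1 * np.abs(W).sum()      # L1 norm over all entries of W
    return float(loss + penalty)
```

Larger lambda1 values push more entries of W toward zero, which yields sparser graphs before thresholding.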
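As we understand the DAGMA paper, s is the scale in the log-determinant acyclicity function \(h^s(W) = -\log\det(sI - W \circ W) + d \log s\). This function is defined on the set of \(W\) for which \(sI - W \circ W\) is an M-matrix, equivalently where the spectral radius of \(W \circ W\) is below \(s\). On that set, \(h^s(W) = 0\) exactly when \(W\) is a DAG. The NumPy sketch below evaluates this function. The name logdet_acyclicity is hypothetical and is not the library's API.

```python
import numpy as np

def logdet_acyclicity(W: np.ndarray, s: float = 1.0) -> float:
    """Hypothetical helper computing h^s(W) = -log det(s I - W∘W) + d log s."""
    d = W.shape[0]
    A = W * W                                   # elementwise square, so A >= 0
    # Outside the M-matrix domain the value is not meaningful, so refuse it.
    if np.max(np.abs(np.linalg.eigvals(A))) >= s:
        raise ValueError("W is outside the M-matrix domain for this s")
    _, logabsdet = np.linalg.slogdet(s * np.eye(d) - A)
    return float(-logabsdet + d * np.log(s))
```

For a weighted DAG, \(W \circ W\) is nilpotent, so \(\det(sI - W \circ W) = s^d\) and the function returns 0 for any admissible s. Any cycle makes it strictly positive.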
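T, mu_init, mu_factor, s, warm_iter, and max_iter can be pictured together as a per-iteration schedule. The helper outer_schedule below is hypothetical, not a dagma function. It assumes three behaviors: \(\mu\) is multiplied by mu_factor after each outer iteration, a float s is reused at every iteration, and a list shorter than T is padded with its last value.

```python
def outer_schedule(T=5, mu_init=1.0, mu_factor=0.1, s=(1.0, 0.9, 0.8, 0.7, 0.6),
                   warm_iter=30000, max_iter=60000):
    """Hypothetical helper listing (t, mu, s_t, inner_iters) for t = 1..T."""
    if isinstance(s, (int, float)):
        s_values = [float(s)] * T                  # same s at every outer iteration
    else:
        s_values = list(s)
        if len(s_values) < T:                      # assumption: pad with the last value
            s_values += [s_values[-1]] * (T - len(s_values))
    schedule = []
    mu = mu_init
    for t in range(1, T + 1):
        # warm_iter for t < T, max_iter for the final t = T; int() accepts 3e4-style floats
        inner_iters = int(max_iter if t == T else warm_iter)
        schedule.append((t, mu, s_values[t - 1], inner_iters))
        mu *= mu_factor                            # geometric decay of mu
    return schedule
```

With the defaults, \(\mu\) goes from 1.0 at \(t = 1\) down to about \(10^{-4}\) at \(t = 5\), while s shrinks from 1.0 to 0.6.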
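lr, beta_1, and beta_2 are the usual Adam hyperparameters. For reference, the textbook Adam update (Kingma and Ba) is sketched below in NumPy. adam_step is a hypothetical name. Details such as the epsilon value or the bias correction used inside dagma's own optimizer are not confirmed here.

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=3e-4, beta_1=0.99, beta_2=0.999, eps=1e-8):
    """One textbook Adam update; t is the 1-based step counter."""
    m = beta_1 * m + (1 - beta_1) * grad          # running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)                 # bias correction for zero initialization
    v_hat = v / (1 - beta_2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# Usage: a few steps on f(w) = ||w||^2, whose gradient is 2w.
w = np.array([1.0, -2.0])
m, v = np.zeros_like(w), np.zeros_like(w)
for t in range(1, 4):
    w, m, v = adam_step(w, 2 * w, m, v, t)
```

Because of the bias correction, the first step moves each coordinate by roughly lr in the direction opposite its gradient, whatever the gradient's scale.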

Returns:

Weighted adjacency matrix of the estimated DAG, of shape \((d, d)\).

Return type:

np.ndarray
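The w_threshold post-processing can be sketched as follows, assuming that magnitudes are compared so that large negative weights survive. threshold_weights is a hypothetical helper and not part of the dagma API.

```python
import numpy as np

def threshold_weights(W: np.ndarray, w_threshold: float = 0.3) -> np.ndarray:
    """Hypothetical helper: copy of W with small-magnitude entries set to zero."""
    W_out = W.copy()                              # leave the caller's array untouched
    W_out[np.abs(W_out) < w_threshold] = 0.0      # compare |w|, so -0.5 is kept but -0.1 is not
    return W_out

W = np.array([[0.0, 0.8, -0.1],
              [0.0, 0.0, -0.5],
              [0.25, 0.0, 0.0]])
W_pruned = threshold_weights(W)
```

Here the entries -0.1 and 0.25 are zeroed and 0.8 and -0.5 are kept.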

Important

If the output of fit() is not a DAG, try larger values of T (e.g., 6, 7, or 8) before raising an issue on GitHub.
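To decide whether a larger T is needed, the returned matrix can be checked for cycles. The sketch below applies Kahn's topological-sort algorithm to the nonzero pattern of the matrix. It assumes that a nonzero W[i, j] denotes an edge from node i to node j, and is_dag here is a hypothetical helper written for illustration.

```python
import numpy as np

def is_dag(W: np.ndarray) -> bool:
    """Hypothetical helper: True if the nonzero pattern of W has no directed cycle."""
    adj = W != 0                                  # adj[i, j] is True for an edge i -> j
    in_degree = adj.sum(axis=0)                   # number of parents of each node
    queue = [j for j in range(adj.shape[0]) if in_degree[j] == 0]
    visited = 0
    while queue:
        i = queue.pop()
        visited += 1
        for j in np.flatnonzero(adj[i]):          # drop i's outgoing edges
            in_degree[j] -= 1
            if in_degree[j] == 0:
                queue.append(j)
    return visited == adj.shape[0]                # all nodes removed <=> no cycle
```

A self-loop (nonzero diagonal entry) counts as a cycle here, because that node's in-degree never reaches zero.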

Warning

While DAGMA ensures that the edges given in exclude_edges are absent from the result, the current implementation does not guarantee that all edges in include_edges will be part of the final DAG.
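The asymmetry between the two options can be illustrated with a hypothetical sketch. Forbidding an edge is easy to enforce exactly by masking, for example by zeroing the corresponding entries of W after every gradient step. Requiring an edge, by contrast, competes with the L1 penalty and thresholding, so the practical remedy is to check the result. The helpers below are illustrative only and are not part of the dagma API. They assume that \((i, j)\) indexes W[i, j], an edge from node i to node j. We do not claim that dagma enforces exclusion this way.

```python
import numpy as np

def exclusion_mask(d: int, exclude_edges=None) -> np.ndarray:
    """Hypothetical helper: boolean (d, d) mask that is False at each excluded edge."""
    mask = np.ones((d, d), dtype=bool)
    for i, j in exclude_edges or ():
        mask[i, j] = False                       # edge i -> j is forbidden
    return mask

def masked_gradient_step(W, grad, lr, mask):
    """Gradient step that keeps forbidden entries pinned at zero."""
    return (W - lr * grad) * mask

def missing_included_edges(W, include_edges):
    """Requested edges (i, j) whose weight in the final W is exactly zero."""
    return [(i, j) for i, j in include_edges if W[i, j] == 0]
```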


Last update: Jan 14, 2024