dagma.nonlinear.DagmaNonlinear.fit

dagma.nonlinear.DagmaNonlinear.fit(X: torch.Tensor | numpy.ndarray, lambda1: float = 0.02, lambda2: float = 0.005, T: int = 4, mu_init: float = 0.1, mu_factor: float = 0.1, s: float = 1.0, warm_iter: int = 50000.0, max_iter: int = 80000.0, lr: float = 0.0002, w_threshold: float = 0.3, checkpoint: int = 1000) → numpy.ndarray

Runs the DAGMA algorithm and fits the model to the dataset.
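For context, the DAGMA paper characterizes acyclicity with the log-determinant function \(h^s(W) = -\log\det(sI - W \circ W) + d\log s\), defined where \(sI - W \circ W\) is an M-matrix; on that domain it is zero exactly when W is a DAG and positive otherwise. This is the role of the parameter s listed under Parameters. The pure-Python sketch below evaluates the function on two small matrices. It is an illustration only: det and logdet_h are hypothetical helpers, not part of the dagma package, and the domain check only tests that the determinant is positive, which is necessary but not sufficient for the M-matrix condition.

```python
import math


def det(M):
    """Determinant by Gaussian elimination with partial pivoting (pure Python)."""
    A = [list(map(float, row)) for row in M]  # work on a copy
    n = len(A)
    result = 1.0
    for c in range(n):
        # Pick the row with the largest pivot in column c for numerical stability.
        p = max(range(c, n), key=lambda r: abs(A[r][c]))
        if A[p][c] == 0.0:
            return 0.0
        if p != c:
            A[c], A[p] = A[p], A[c]
            result = -result  # a row swap flips the sign
        result *= A[c][c]
        for r in range(c + 1, n):
            f = A[r][c] / A[c][c]
            for k in range(c, n):
                A[r][k] -= f * A[c][k]
    return result


def logdet_h(W, s=1.0):
    """Hypothetical helper: h^s(W) = -log det(sI - W∘W) + d log s.

    Only checks det(sI - W∘W) > 0, which is necessary but not sufficient
    for sI - W∘W to be an M-matrix.
    """
    d = len(W)
    M = [[(s if i == j else 0.0) - W[i][j] ** 2 for j in range(d)] for i in range(d)]
    D = det(M)
    if D <= 0.0:
        raise ValueError("W is outside the log-det domain for this s")
    return -math.log(D) + d * math.log(s)


dag = [[0.0, 0.8, 0.0],
       [0.0, 0.0, 0.5],
       [0.0, 0.0, 0.0]]   # toy graph 0 -> 1 -> 2, acyclic
cycle = [[0.0, 0.5],
         [0.5, 0.0]]      # toy graph 0 <-> 1, a 2-cycle
print(logdet_h(dag))      # 0.0 for a DAG
print(logdet_h(cycle))    # positive for a cyclic graph
```

Larger weights push \(sI - W \circ W\) toward the boundary of the domain; a larger s enlarges the domain, which is what "controls the domain of M-matrices" refers to.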

Parameters:
X : Union[torch.Tensor, np.ndarray]

\((n,d)\) dataset, with \(n\) samples and \(d\) variables.

lambda1 : float, optional

Coefficient of the L1 penalty, by default .02.

lambda2 : float, optional

Coefficient of the L2 penalty, by default .005.

T : int, optional

Number of DAGMA iterations, by default 4.

mu_init : float, optional

Initial value of \(\mu\), by default 0.1.

mu_factor : float, optional

Factor by which \(\mu\) is multiplied after each DAGMA iteration, by default .1.

s : float, optional

Controls the domain of M-matrices, by default 1.0.

warm_iter : int, optional

Number of minimize() iterations in each DAGMA iteration \(t < T\), by default 5e4.

max_iter : int, optional

Number of minimize() iterations in the final DAGMA iteration \(t = T\), by default 8e4.

lr : float, optional

Learning rate, by default .0002.

w_threshold : float, optional

Removes edges whose absolute weight is less than the given threshold, by default 0.3.

checkpoint : int, optional

If verbose is True, prints progress to stdout once every checkpoint iterations, by default 1000.
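Together, T, mu_init, mu_factor, warm_iter and max_iter define an outer loop over \(\mu\). In the DAGMA scheme, each outer iteration minimizes \(\mu\) times the penalized score plus the acyclicity term, so shrinking \(\mu\) shifts weight toward acyclicity. The sketch below uses a hypothetical helper, dagma_schedule (not part of the package), to list the \(\mu\) value and the number of minimize() iterations for each outer iteration \(t = 1, \dots, T\). It assumes \(\mu\) starts at mu_init and is multiplied by mu_factor after every outer iteration.

```python
def dagma_schedule(T=4, mu_init=0.1, mu_factor=0.1, warm_iter=5e4, max_iter=8e4):
    """Hypothetical helper: (t, mu, inner_iters) for each outer DAGMA iteration.

    Assumes mu starts at mu_init and is multiplied by mu_factor after every
    outer iteration; warm_iter applies to t < T and max_iter to t = T.
    """
    n_outer = int(T)
    schedule = []
    mu = mu_init
    for t in range(1, n_outer + 1):
        # The signature's defaults are floats (5e4, 8e4), so cast to int here.
        inner_iters = int(max_iter) if t == n_outer else int(warm_iter)
        schedule.append((t, mu, inner_iters))
        mu *= mu_factor  # decay mu before the next outer iteration
    return schedule


for t, mu, iters in dagma_schedule():
    print(f"t={t}: mu={mu:.0e}, minimize() iterations={iters}")
```

Under this reading, raising T from 4 to 8 adds four more warm iterations and drives the final \(\mu\) four orders of magnitude lower, which puts more emphasis on acyclicity in the last solve.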

Returns:

Weighted adjacency matrix of the estimated DAG, of shape \((d,d)\).

Return type:

np.ndarray
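To read the returned matrix as a list of edges, the sketch below (edges_from_adjacency is a hypothetical helper) collects every nonzero entry. It assumes the convention common to NOTEARS-style methods that a nonzero W[i][j] encodes an edge i -> j; check this against the package before relying on it. The input here is a toy matrix, not real output of fit().

```python
def edges_from_adjacency(W):
    """Hypothetical helper: list (source, target, weight) for every nonzero W[i][j].

    Assumes a nonzero W[i][j] encodes an edge i -> j. Works on nested lists
    or on a 2-D numpy array such as the one fit() returns.
    """
    edges = []
    for i, row in enumerate(W):
        for j, w in enumerate(row):
            if w != 0:  # fit() has already removed entries below w_threshold
                edges.append((i, j, float(w)))
    return edges


W_est = [[0.0, 0.7, 0.0],
         [0.0, 0.0, 0.0],
         [0.4, 0.0, 0.0]]  # toy matrix for illustration
print(edges_from_adjacency(W_est))  # [(0, 1, 0.7), (2, 0, 0.4)]
```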

Important

If the output of fit() is not a DAG, try larger values of T (e.g., 6, 7, or 8) before opening an issue on GitHub.
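One way to check whether the returned matrix is a DAG is a topological sort over its nonzero entries. The pure-Python sketch below uses Kahn's algorithm; is_dag is a hypothetical helper, not part of this API. Its answer does not depend on whether rows or columns are read as edge sources, because reversing every edge preserves acyclicity.

```python
from collections import deque


def is_dag(W):
    """Hypothetical helper: True if the graph with an edge for each nonzero W[i][j] is acyclic.

    Kahn's algorithm: repeatedly remove nodes with no incoming edges;
    the graph is acyclic exactly when every node ends up removed.
    """
    d = len(W)
    indegree = [0] * d
    children = [[] for _ in range(d)]
    for i in range(d):
        for j in range(d):
            if W[i][j] != 0:  # a nonzero diagonal entry is a self-loop, hence a cycle
                children[i].append(j)
                indegree[j] += 1
    queue = deque(k for k in range(d) if indegree[k] == 0)
    removed = 0
    while queue:
        k = queue.popleft()
        removed += 1
        for j in children[k]:
            indegree[j] -= 1
            if indegree[j] == 0:
                queue.append(j)
    return removed == d


print(is_dag([[0.0, 0.9], [0.0, 0.0]]))  # True
print(is_dag([[0.0, 0.9], [0.4, 0.0]]))  # False
```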


Last update: Jan 14, 2024