dagma.nonlinear.DagmaNonlinear.fit
dagma.nonlinear.DagmaNonlinear.fit(X: torch.Tensor | numpy.ndarray, lambda1: float = 0.02, lambda2: float = 0.005, T: int = 4, mu_init: float = 0.1, mu_factor: float = 0.1, s: float = 1.0, warm_iter: int = 50000.0, max_iter: int = 80000.0, lr: float = 0.0002, w_threshold: float = 0.3, checkpoint: int = 1000) → numpy.ndarray
Runs the DAGMA algorithm and fits the model to the dataset.
- Parameters:
- X : Union[torch.Tensor, np.ndarray]
\((n,d)\) dataset.
- lambda1 : float, optional
Coefficient of the L1 penalty, by default .02.
- lambda2 : float, optional
Coefficient of the L2 penalty, by default .005.
- T : int, optional
Number of DAGMA iterations, by default 4.
- mu_init : float, optional
Initial value of \(\mu\), by default 0.1.
- mu_factor : float, optional
Decay factor for \(\mu\), by default .1.
- s : float, optional
Controls the domain of M-matrices, by default 1.0.
- warm_iter : int, optional
Number of iterations for minimize() for \(t < T\), by default 5e4.
- max_iter : int, optional
Number of iterations for minimize() for \(t = T\), by default 8e4.
- lr : float, optional
Learning rate, by default .0002.
- w_threshold : float, optional
Removes edges whose weight is less than the given threshold, by default 0.3.
- checkpoint : int, optional
If verbose is True, then prints to stdout every checkpoint iterations, by default 1000.
- Returns:
Estimated DAG from data.
- Return type:
np.ndarray
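The interplay of T, mu_init, mu_factor, warm_iter, and max_iter can be easier to see as a schedule. The following is a minimal sketch, not the library's implementation: the helper name dagma_schedule is hypothetical, and it assumes that \(\mu\) is multiplied by mu_factor after every outer iteration and that iterations are numbered \(t = 1, \dots, T\).

```python
def dagma_schedule(T=4, mu_init=0.1, mu_factor=0.1,
                   warm_iter=50000, max_iter=80000):
    """Return one (t, mu, n_iter) tuple per outer DAGMA iteration.

    Hypothetical helper for illustration only. Assumes geometric decay
    of mu and 1-based iteration numbering.
    """
    schedule = []
    mu = mu_init
    for t in range(1, T + 1):
        # The last outer iteration (t = T) uses max_iter inner steps,
        # every earlier one (t < T) uses warm_iter.
        n_iter = int(max_iter) if t == T else int(warm_iter)
        schedule.append((t, mu, n_iter))
        mu *= mu_factor  # assumed decay rule
    return schedule


for t, mu, n_iter in dagma_schedule():
    print(f"t={t}  mu={mu:.0e}  inner iterations={n_iter}")
```

Under these assumptions, raising T adds more warm-up rounds with progressively smaller \(\mu\) before the final round.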
Important
If the output of fit() is not a DAG, try larger values of T (e.g., 6, 7, or 8) before raising an issue on GitHub.
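One way to check whether a returned matrix is a DAG is a topological-sort test. The sketch below is an illustration in plain Python, not part of the documented API: the function name is_dag is chosen here, and it assumes that a nonzero entry W[i][j] denotes a directed edge between nodes i and j (acyclicity does not depend on which direction is meant).

```python
from collections import deque


def is_dag(W):
    """Return True if the weighted adjacency matrix W has no directed cycle.

    Works for nested lists and 2-D NumPy arrays. A nonzero W[i][j] is
    treated as an edge i -> j (assumed convention).
    """
    d = len(W)
    indegree = [0] * d
    for i in range(d):
        for j in range(d):
            if W[i][j] != 0:
                indegree[j] += 1

    # Kahn's algorithm: repeatedly remove nodes with no incoming edges.
    queue = deque(i for i in range(d) if indegree[i] == 0)
    visited = 0
    while queue:
        i = queue.popleft()
        visited += 1
        for j in range(d):
            if W[i][j] != 0:
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)

    # Every node is removed exactly when the graph is acyclic.
    return visited == d


chain = [[0, 1.2, 0], [0, 0, 0.7], [0, 0, 0]]  # 0 -> 1 -> 2
cycle = [[0, 1.0, 0], [0, 0, 1.0], [0.5, 0, 0]]  # 0 -> 1 -> 2 -> 0
print(is_dag(chain), is_dag(cycle))  # True False
```

If such a check fails on the output of fit(), rerunning with a larger T as suggested above is the first thing to try.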