dagma.linear.DagmaLinear.fit
dagma.linear.DagmaLinear.fit(
    X: numpy.ndarray,
    lambda1: float = 0.03,
    w_threshold: float = 0.3,
    T: int = 5,
    mu_init: float = 1.0,
    mu_factor: float = 0.1,
    s: list[float] | float = [1.0, 0.9, 0.8, 0.7, 0.6],
    warm_iter: int = 30000.0,
    max_iter: int = 60000.0,
    lr: float = 0.0003,
    checkpoint: int = 1000,
    beta_1: float = 0.99,
    beta_2: float = 0.999,
    exclude_edges: list[tuple[int, int]] | None = None,
    include_edges: list[tuple[int, int]] | None = None,
) -> numpy.ndarray

Runs the DAGMA algorithm and returns a weighted adjacency matrix.
- Parameters:
- X : np.ndarray
Dataset of shape \((n, d)\), with \(n\) samples of \(d\) variables.
- lambda1 : float, optional
Coefficient of the L1 penalty. Defaults to 0.03.
- w_threshold : float, optional
Removes edges whose absolute weight is less than the given threshold. Defaults to 0.3.
- T : int, optional
Number of outer DAGMA iterations. Defaults to 5.
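- mu_init : float, optional
Initial value of \(\mu\), the weight on the score (loss plus L1 penalty) relative to the acyclicity term. Defaults to 1.0.
- mu_factor : float, optional
Decay factor for \(\mu\): after each outer iteration, \(\mu\) is multiplied by mu_factor. Defaults to 0.1.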
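- s : Union[List[float], float], optional
Controls the domain of M-matrices, i.e., the values of \(s\) for which \(sI - W \circ W\) is required to be an M-matrix. Either one value per outer iteration or a single value. Defaults to [1.0, .9, .8, .7, .6].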
- warm_iter : int, optional
Number of iterations for minimize() for \(t < T\). Defaults to 3e4.
- max_iter : int, optional
Number of iterations for minimize() for \(t = T\). Defaults to 6e4.
- lr : float, optional
Learning rate. Defaults to 0.0003.
- checkpoint : int, optional
If verbose is True, progress is printed to stdout every checkpoint iterations. Defaults to 1000.
- beta_1 : float, optional
Adam hyperparameter (decay rate of the first-moment estimate). Defaults to 0.99.
- beta_2 : float, optional
Adam hyperparameter (decay rate of the second-moment estimate). Defaults to 0.999.
- exclude_edges : Optional[List[Tuple[int, int]]], optional
Tuple of edges that should be excluded from the DAG solution, e.g., ((1,3), (2,4), (5,1)). Defaults to None.
- include_edges : Optional[List[Tuple[int, int]]], optional
Tuple of edges that should be included in the DAG solution, e.g., ((1,3), (2,4), (5,1)). Defaults to None.
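The parameters T, mu_init, mu_factor, s, warm_iter and max_iter together define a schedule over the outer iterations. The sketch below spells out one reading of the descriptions above: \(\mu\) starts at mu_init and is multiplied by mu_factor after each outer iteration, iteration \(t\) uses the \(t\)-th entry of s, and the final iteration (\(t = T\)) runs max_iter inner steps while earlier ones run warm_iter. The helper dagma_schedule is hypothetical, and the handling of a scalar s (reused every iteration) or of a list shorter than T (padded with its last value) is an assumption, not something this page documents.

```python
def dagma_schedule(T=5, mu_init=1.0, mu_factor=0.1,
                   s=(1.0, 0.9, 0.8, 0.7, 0.6),
                   warm_iter=30000, max_iter=60000):
    """Return one (mu, s_t, n_inner) triple per outer iteration.

    Hypothetical helper illustrating the schedule; not part of dagma.
    """
    if isinstance(s, (int, float)):
        s_list = [float(s)] * T            # assumption: a scalar s is reused at every iteration
    else:
        s_list = list(s)
        if len(s_list) < T:                # assumption: a short list is padded with its last value
            s_list += [s_list[-1]] * (T - len(s_list))
    schedule = []
    mu = mu_init
    for t in range(1, T + 1):
        n_inner = max_iter if t == T else warm_iter   # the final iteration (t = T) runs longer
        schedule.append((mu, s_list[t - 1], int(n_inner)))
        mu *= mu_factor                    # mu decays geometrically between outer iterations
    return schedule


for t, (mu, s_t, n_inner) in enumerate(dagma_schedule(), start=1):
    print(f"t={t}: mu={mu:g}, s={s_t}, inner iterations={n_inner}")
```

With the defaults, \(\mu\) shrinks from 1.0 to 1e-4 over five iterations, so the acyclicity term carries progressively more weight relative to the score.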
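The s parameter refers to the log-determinant acyclicity function that gives DAGMA its name. Assuming the form from the DAGMA paper, \(h^s(W) = -\log\det(sI - W \circ W) + d \log s\), defined for \(W\) with \(s > \rho(W \circ W)\) (the spectral radius), the sketch below evaluates it with NumPy. Within that domain the value is zero for a DAG and positive when \(W\) contains a cycle. The name h_logdet is our own, not a dagma function.

```python
import numpy as np

def h_logdet(W: np.ndarray, s: float = 1.0) -> float:
    """Log-det acyclicity value h^s(W) = -log det(sI - W*W) + d log s (illustrative)."""
    d = W.shape[0]
    W_sq = W * W                                   # elementwise (Hadamard) square
    # Domain check: s must exceed the spectral radius of W*W so that sI - W*W is an M-matrix.
    if np.max(np.abs(np.linalg.eigvals(W_sq))) >= s:
        raise ValueError("W lies outside the domain for this value of s")
    _, logabsdet = np.linalg.slogdet(s * np.eye(d) - W_sq)
    return float(-logabsdet + d * np.log(s))


W_dag = np.array([[0.0, 1.2, 0.0],
                  [0.0, 0.0, -0.8],
                  [0.0, 0.0, 0.0]])   # acyclic: 0 -> 1 -> 2
W_cyc = np.array([[0.0, 0.5],
                  [0.5, 0.0]])        # 2-cycle: 0 -> 1 -> 0
print(h_logdet(W_dag), h_logdet(W_cyc))
```

A smaller s shrinks the domain, which is why a decreasing sequence of s values, as in the default list, restricts \(W\) more tightly in later iterations.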
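lr, beta_1 and beta_2 are described only as Adam hyperparameters. For reference, a textbook Adam update is sketched below; the eps constant and the bias-correction terms are standard Adam details assumed here rather than stated on this page, and adam_step is a hypothetical name. The default beta_1 = 0.99 is larger than the common 0.9, so the first-moment estimate averages over roughly 100 recent gradients instead of about 10.

```python
import numpy as np

def adam_step(W, grad, m, v, t, lr=3e-4, beta_1=0.99, beta_2=0.999, eps=1e-8):
    """One textbook Adam update; t is the 1-based step count. Hypothetical helper."""
    m = beta_1 * m + (1 - beta_1) * grad           # running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2      # running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)                  # bias correction for zero initialization
    v_hat = v / (1 - beta_2 ** t)
    W = W - lr * m_hat / (np.sqrt(v_hat) + eps)
    return W, m, v


W = np.zeros((2, 2))
grad = np.array([[2.0, -0.5], [0.0, 1.0]])
m, v = np.zeros_like(W), np.zeros_like(W)
W, m, v = adam_step(W, grad, m, v, t=1)
print(W)   # each entry with a nonzero gradient moves by about lr, against the gradient's sign
```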
- Returns:
Estimated DAG from data, as a \((d, d)\) weighted adjacency matrix.
- Return type:
np.ndarray
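The returned weights have been filtered with w_threshold. That post-processing step can be sketched as follows, assuming the threshold is applied to absolute values so that large negative weights survive; apply_threshold is an illustrative helper, not part of dagma. A binary adjacency matrix can then be read off the nonzero pattern.

```python
import numpy as np

def apply_threshold(W: np.ndarray, w_threshold: float = 0.3) -> np.ndarray:
    """Return a copy of W with entries of absolute value below w_threshold set to 0."""
    W_thr = W.copy()
    W_thr[np.abs(W_thr) < w_threshold] = 0.0
    return W_thr


W_raw = np.array([[0.0, 0.25, -0.9],
                  [0.31, 0.0, -0.1],
                  [0.0, 0.0, 0.0]])
W_thr = apply_threshold(W_raw)          # keeps -0.9 and 0.31, drops 0.25 and -0.1
B = (W_thr != 0).astype(int)            # binary adjacency matrix
print(B)
```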
Important
If the output of fit() is not a DAG, try larger values of T (e.g., 6, 7, or 8) before raising an issue on GitHub.
Warning
DAGMA guarantees that the edges given in exclude_edges are absent from the solution, but the current implementation does not guarantee that all edges in include_edges will be part of the final DAG.
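The note above recommends increasing T when the output of fit() is not a DAG. One way to check this is to test the nonzero pattern of the returned matrix for directed cycles; the sketch below does so with Kahn's topological-sort algorithm in NumPy. The helper name is_dag is our own choice for this illustration.

```python
import numpy as np

def is_dag(W) -> bool:
    """Return True if the nonzero pattern of W has no directed cycle (Kahn's algorithm)."""
    A = np.asarray(W) != 0                  # edge i -> j wherever W[i, j] is nonzero
    d = A.shape[0]
    indegree = A.sum(axis=0)                # incoming edges per node (column sums)
    queue = [j for j in range(d) if indegree[j] == 0]
    visited = 0
    while queue:
        i = queue.pop()
        visited += 1
        for j in np.flatnonzero(A[i]):      # remove the outgoing edges of node i
            indegree[j] -= 1
            if indegree[j] == 0:
                queue.append(j)
    return visited == d                     # nodes on a cycle (or self-loop) are never visited


print(is_dag(np.array([[0.0, 1.2], [0.0, 0.0]])))   # True
```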
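The warning above separates excluded edges, which are guaranteed to be absent, from included edges, which are not guaranteed to be present. One simple way an optimizer can enforce exclusion exactly is to reset the excluded entries to zero after every update, as in the illustrative gradient step below. This is a sketch of the idea, not a description of how dagma implements exclude_edges; masked_gradient_step is a hypothetical name, and treating edge indices as 0-based (row, column) positions in W is an assumption.

```python
import numpy as np

def masked_gradient_step(W, grad, lr, exclude_edges=None):
    """Plain gradient step that keeps every excluded entry W[i, j] at exactly zero.

    Hypothetical helper for illustration; indices are 0-based (row i, column j).
    """
    W_new = W - lr * grad
    if exclude_edges:
        rows, cols = zip(*exclude_edges)
        W_new[list(rows), list(cols)] = 0.0      # excluded edges can never become nonzero
    return W_new


W = np.zeros((3, 3))
grad = -(np.ones((3, 3)) - np.eye(3))            # pushes every off-diagonal entry upward
W = masked_gradient_step(W, grad, lr=0.1, exclude_edges=((0, 1), (2, 0)))
print(W)
```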