dagma.nonlinear.DagmaMLP
class dagma.nonlinear.DagmaMLP(dims: list[int], bias: bool = True, dtype: torch.dtype = torch.double)
Bases: torch.nn.Module
Class that models the structural equations for the causal graph using MLPs.
- Parameters:
  - dims (list[int])
  - bias (bool, default True)
  - dtype (torch.dtype, default torch.double)
Methods
forward (→ torch.Tensor): Applies the current states of the structural equations to the dataset X.
h_func (→ torch.Tensor): Constrains the squared 2-norm of the fc1 weights, taken along the m1 dimension, to be a DAG.
fc1_l1_reg (→ torch.Tensor): Takes the L1 norm of the weights in the first fully-connected layer.
fc1_to_adj (→ numpy.ndarray): Computes the induced weighted adjacency matrix W from the weights of the first fully-connected layer.
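The three weight-based methods above can be illustrated with a small NumPy sketch. This is not the library's code: the function names (`fc1_to_adj_sketch`, `h_func_sketch`, `fc1_l1_reg_sketch`), the parameter `s`, and the assumed weight layout of shape `(d * m1, d)` (the `m1` hidden units of node `j` stacked in consecutive rows, one column per input variable `i`) are all hypothetical. The log-determinant form of the acyclicity penalty is likewise an assumption, chosen to match the "constrain to be a DAG" description.

```python
import numpy as np


def _squared_norms(fc1_weight: np.ndarray, d: int) -> np.ndarray:
    """A[i, j] = squared 2-norm, over the m1 hidden units of node j, of the weights on input i."""
    w = fc1_weight.reshape(d, -1, d)      # assumed layout: (node j, hidden unit, input i)
    return np.sum(w ** 2, axis=1).T       # transpose so that rows index the parent i


def fc1_to_adj_sketch(fc1_weight: np.ndarray, d: int) -> np.ndarray:
    """Induced weighted adjacency matrix: W[i, j] is the 2-norm of the i -> j weights."""
    return np.sqrt(_squared_norms(fc1_weight, d))


def h_func_sketch(fc1_weight: np.ndarray, d: int, s: float = 1.0) -> float:
    """Assumed log-det acyclicity penalty: zero for a DAG, positive once a cycle appears."""
    A = _squared_norms(fc1_weight, d)
    _, logabsdet = np.linalg.slogdet(s * np.eye(d) - A)
    return float(-logabsdet + d * np.log(s))


def fc1_l1_reg_sketch(fc1_weight: np.ndarray) -> float:
    """L1 norm of all first-layer weights."""
    return float(np.abs(fc1_weight).sum())


d, m1 = 3, 2

# Chain 0 -> 1 -> 2: node 1 reads input 0, node 2 reads input 1.
w_dag = np.zeros((d, m1, d))
w_dag[1, :, 0] = 0.5
w_dag[2, :, 1] = 0.3
fc1_dag = w_dag.reshape(d * m1, d)

# Same weights plus an edge 2 -> 0, which closes a cycle.
w_cyc = w_dag.copy()
w_cyc[0, :, 2] = 0.4
fc1_cyc = w_cyc.reshape(d * m1, d)

W = fc1_to_adj_sketch(fc1_dag, d)
h_dag = h_func_sketch(fc1_dag, d)
h_cyc = h_func_sketch(fc1_cyc, d)
l1 = fc1_l1_reg_sketch(fc1_dag)

print(W)
print(h_dag, h_cyc, l1)
```

In this sketch the penalty vanishes for the chain, because the matrix of squared norms is strictly upper triangular, and becomes positive as soon as the extra edge closes a cycle.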
Last update: Jan 14, 2024