dagma.nonlinear.DagmaMLP

class dagma.nonlinear.DagmaMLP(dims: list[int], bias: bool = True, dtype: torch.dtype = torch.double)

Bases: torch.nn.Module

Class that models the structural equations for the causal graph using MLPs.

Parameters:
dims : List[int]

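Layer sizes of the MLP that represents each structural equation. The first entry is the number of variables d, the intermediate entries are the hidden-layer widths, and the last entry is the output size 1 (for example, [d, 10, 1]).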

bias : bool, optional

Whether the layers include bias terms, by default True.

dtype : torch.dtype, optional

Floating-point precision of the model parameters, by default torch.double.
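The dims convention is easiest to see through the parameter shapes it implies. The sketch below uses a hypothetical helper, param_shapes, which is not part of dagma. It assumes a shared first layer fc1 that maps the d inputs to d * dims[1] units (stored in PyTorch's (out, in) layout) and one locally connected layer per remaining pair of widths, where each variable has its own small weight matrix. The fc2 names are illustrative assumptions.

```python
def param_shapes(dims, bias=True):
    """Hypothetical helper: parameter shapes implied by dims = [d, m1, ..., 1].

    Assumed layout (not taken from the dagma source): a shared first layer
    followed by locally connected layers, one weight slice per variable.
    """
    assert len(dims) >= 2 and dims[-1] == 1, "expected dims = [d, m1, ..., 1]"
    d = dims[0]
    # Shared first layer: d inputs -> d * m1 units, (out, in) layout.
    shapes = {"fc1.weight": (d * dims[1], d)}
    if bias:
        shapes["fc1.bias"] = (d * dims[1],)
    # Locally connected layers: variable j maps its m_in units to m_out units.
    for l in range(len(dims) - 2):
        shapes[f"fc2.{l}.weight"] = (d, dims[l + 1], dims[l + 2])
        if bias:
            shapes[f"fc2.{l}.bias"] = (d, dims[l + 2])
    return shapes


# Usage: 5 variables, one hidden layer of width 10.
for name, shape in param_shapes([5, 10, 1]).items():
    print(name, shape)
```

Under this layout, every input variable enters only through fc1. That is why the acyclicity and sparsity methods below need to look only at the fc1 weights.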

Methods

forward(→ torch.Tensor)

Applies the current structural equations to the dataset X and returns one prediction per variable.

h_func(→ torch.Tensor)

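Acyclicity penalty: the log-determinant function h applied to the squared 2-norms of the fc1 weights along the m1 (first hidden layer) dimension. h is zero exactly when the induced graph is a DAG.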

fc1_l1_reg(→ torch.Tensor)

Returns the L1 norm of the weights in the first fully connected layer (fc1).

fc1_to_adj(→ numpy.ndarray)

Computes the induced weighted adjacency matrix W from the first FC weights.
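The following is a minimal NumPy sketch of what forward computes, not the library's code. It assumes the architecture outlined in the parameter description: a shared first layer whose d * m1 outputs are grouped by variable, followed by a sigmoid and a locally connected layer for each remaining width. The sigmoid activation and the helper names mlp_forward and sigmoid are assumptions for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mlp_forward(X, fc1_w, fc1_b, local_ws, local_bs):
    """Sketch of a DagmaMLP-style forward pass (assumed architecture).

    X        : (n, d) data matrix
    fc1_w    : (d * m1, d) first-layer weights in (out, in) layout
    fc1_b    : (d * m1,) first-layer bias
    local_ws : list of (d, m_in, m_out) weights, one per locally connected layer
    local_bs : list of (d, m_out) biases
    """
    n, d = X.shape
    h = X @ fc1_w.T + fc1_b          # (n, d * m1): shared first layer
    h = h.reshape(n, d, -1)          # (n, d, m1): hidden units of each variable's MLP
    for W, b in zip(local_ws, local_bs):
        h = sigmoid(h)
        # Locally connected layer: variable j uses only its own weights W[j], bias b[j].
        h = np.einsum("njm,jmk->njk", h, W) + b
    return h.squeeze(axis=2)         # (n, d): one prediction per variable


# Usage: d = 3 variables, one hidden layer of width 4 (dims = [3, 4, 1]).
rng = np.random.default_rng(0)
n, d, m1 = 5, 3, 4
X = rng.normal(size=(n, d))
fc1_w = rng.normal(size=(d * m1, d))
fc1_b = np.zeros(d * m1)
local_ws = [rng.normal(size=(d, m1, 1))]
local_bs = [np.zeros((d, 1))]
X_hat = mlp_forward(X, fc1_w, fc1_b, local_ws, local_bs)
print(X_hat.shape)  # (5, 3)
```

If every fc1 weight that leaves input i is zero, the predictions no longer depend on X[:, i]. This is how the fc1 weights encode the edges of the graph.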
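The acyclicity term can be sketched in NumPy as follows. A[i, j] is the sum of squared fc1 weights from input i into the MLP for variable j, and h^s(A) = -log det(sI - A) + d log s. This is an illustration under assumptions, not the library's code. It assumes the (out, in) weight layout with outputs grouped by variable, and it assumes a scale parameter s that defaults to 1.0 here. The value is only meaningful when sI - A lies in the function's valid domain (an M-matrix).

```python
import numpy as np


def h_func(fc1_w, d, s=1.0):
    """Sketch of a log-det acyclicity function on first-layer weights.

    fc1_w : (d * m1, d) first-layer weights in (out, in) layout
    s     : assumed scale parameter; sI - A should be an M-matrix
    """
    W = fc1_w.reshape(d, -1, d)      # [j, m, i]: MLP j, hidden unit m, input i
    A = np.sum(W ** 2, axis=1).T     # A[i, j]: squared 2-norm of weights from input i into MLP j
    _, logabsdet = np.linalg.slogdet(s * np.eye(d) - A)
    return -logabsdet + d * np.log(s)


# Usage: two variables, hidden width 2.
d, m1 = 2, 2
W = np.zeros((d, m1, d))
W[1, :, 0] = [0.3, 0.4]              # edge 0 -> 1 only
dag_w = W.reshape(d * m1, d).copy()
print(h_func(dag_w, d))              # approximately 0: the graph 0 -> 1 is acyclic
W[0, :, 1] = [0.3, 0.4]              # add edge 1 -> 0, creating a cycle
cyc_w = W.reshape(d * m1, d).copy()
print(h_func(cyc_w, d))              # positive, equal to -log(1 - 0.25**2)
```

Because h is differentiable in the weights, it can serve as a penalty or constraint that pushes the learned fc1 weights toward an acyclic graph.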
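The L1 regularizer and the induced adjacency matrix can be sketched together, since both read only the fc1 weights. This NumPy version is an illustration, not the library's code. It assumes the same (out, in) layout with outputs grouped by variable, and it assumes that W[i, j] is the 2-norm (the square root of the squared norm used by h_func) of the fc1 weights that run from input i into the MLP for variable j. Biases are excluded from the L1 term, matching "the weights" in the method description.

```python
import numpy as np


def fc1_l1_reg(fc1_w):
    """L1 norm of the first-layer weights (biases not included)."""
    return np.sum(np.abs(fc1_w))


def fc1_to_adj(fc1_w, d):
    """Weighted adjacency: W[i, j] = 2-norm of the fc1 weights from input i into MLP j."""
    W = fc1_w.reshape(d, -1, d)      # [j, m, i]: MLP j, hidden unit m, input i
    A = np.sum(W ** 2, axis=1).T     # squared norms, A[i, j]
    return np.sqrt(A)


# Usage: only input 0 feeds the MLP for variable 1.
d, m1 = 2, 2
W = np.zeros((d, m1, d))
W[1, :, 0] = [0.3, -0.4]
fc1_w = W.reshape(d * m1, d)
print(fc1_l1_reg(fc1_w))             # |0.3| + |-0.4|
print(fc1_to_adj(fc1_w, d))          # nonzero only at [0, 1], approximately 0.5
```

An entry of the adjacency matrix is zero only when every weight from input i into MLP j is zero. The L1 term therefore encourages sparse graphs, and thresholding small entries of W yields the estimated edge set.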


Last update: Jan 14, 2024