dagma.nonlinear.DagmaNonlinear

class dagma.nonlinear.DagmaNonlinear(model: torch.nn.Module, verbose: bool = False, dtype: torch.dtype = torch.double)

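Class that implements the DAGMA algorithm for learning a directed acyclic graph (DAG) whose structural equations are modeled by a neural network.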

Parameters:
model : nn.Module

Neural net that models the structural equations.

verbose : bool, optional

If True, the loss/score and h values are printed to stdout every checkpoint iterations, where checkpoint is set in fit(). Defaults to False.

dtype : torch.dtype, optional

Floating-point precision. Defaults to torch.double.

Methods

log_mse_loss(→ torch.Tensor)

Computes the logarithm of the MSE loss.
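As a minimal sketch, assume the log-MSE loss has the form (d / 2) * log((1 / n) * sum of squared residuals), where n is the number of samples and d the number of variables. The function below is a hypothetical plain-Python stand-in that operates on nested lists. It is not the library's torch implementation, and the exact scaling used by log_mse_loss is an assumption.

```python
import math


def log_mse_loss(output, target):
    """Hypothetical sketch of a log-MSE score.

    Assumes the form (d / 2) * log(SSE / n), where `output` and `target`
    are n x d nested lists (n samples, d variables) and SSE is the sum of
    squared residuals over all entries.
    """
    n, d = len(target), len(target[0])
    # Sum of squared residuals over all samples and all variables.
    sse = sum(
        (o - t) ** 2
        for out_row, tgt_row in zip(output, target)
        for o, t in zip(out_row, tgt_row)
    )
    return 0.5 * d * math.log(sse / n)


target = [[1.0, 2.0], [3.0, 4.0]]
output = [[1.5, 2.0], [3.0, 3.5]]
# SSE = 0.25 + 0.25 = 0.5, n = 2, d = 2, so the loss is log(0.25).
print(log_mse_loss(output, target))
```

Taking the logarithm means that rescaling all residuals by a constant factor shifts the loss by a constant rather than multiplying it.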

minimize(→ bool)

Solves the optimization problem.
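The h value reported during optimization refers to DAGMA's acyclicity measure. In the DAGMA paper this measure is the log-determinant function h^s(W) = -log det(sI - W∘W) + d log s, where ∘ is the elementwise product. It equals zero exactly when W is a DAG and is positive for cyclic W, as long as sI - W∘W remains an M-matrix. Each inner problem then minimizes mu * Q + h^s(W) for a score Q and a weight mu. The NumPy sketch below follows that paper-level definition. The helper names are hypothetical, and whether minimize uses exactly this objective, for example in how regularization enters Q, is an assumption.

```python
import numpy as np


def h_logdet(W, s=1.0):
    """Log-det acyclicity measure h^s(W) = -log det(s*I - W*W) + d*log(s).

    W*W is the elementwise square. The value is zero for a DAG and positive
    for a cyclic graph while s*I - W*W stays in the M-matrix domain.
    """
    d = W.shape[0]
    M = s * np.eye(d) - W * W
    sign, logabsdet = np.linalg.slogdet(M)
    # Outside the valid domain the determinant can be zero or negative.
    if sign <= 0:
        raise ValueError("s*I - W*W is outside the M-matrix domain")
    return -logabsdet + d * np.log(s)


def dagma_objective(score, W, mu, s=1.0):
    """Hypothetical inner objective mu * Q + h^s(W) for a precomputed score Q."""
    return mu * score + h_logdet(W, s)


W_dag = np.array([[0.0, 0.5], [0.0, 0.0]])  # edge 0 -> 1 only
W_cyc = np.array([[0.0, 0.5], [0.5, 0.0]])  # 2-cycle 0 <-> 1
print(h_logdet(W_dag), h_logdet(W_cyc))
```

Because h is differentiable, a gradient-based optimizer can drive it toward zero while mu controls how strongly the score is weighted against acyclicity.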

fit(→ numpy.ndarray)

Runs the DAGMA algorithm and fits the model to the dataset.
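fit returns a numpy.ndarray. A natural reading is that this array is the estimated weighted adjacency matrix of the learned graph. For a neural-network model, one way to obtain such a matrix is to assume a separate sub-network per variable. Under that assumption, W[i, j] is the L2 norm of the first-layer weights that connect input i to the sub-network predicting variable j, and small entries are then zeroed out. The sketch below illustrates that idea. induced_adjacency, threshold, and w_threshold are hypothetical names, and the exact construction inside fit is an assumption.

```python
import numpy as np


def induced_adjacency(fc1_weights):
    """Hypothetical helper: weighted adjacency induced by first-layer weights.

    fc1_weights[j] is the (hidden x d) first-layer weight matrix of the
    sub-network that predicts variable j. W[i, j] is the L2 norm of all
    weights leaving input i in that sub-network, so W[i, j] == 0 means
    variable j does not depend on variable i.
    """
    d = len(fc1_weights)
    W = np.zeros((d, d))
    for j, A in enumerate(fc1_weights):
        # Column i of A holds every weight attached to input i.
        W[:, j] = np.sqrt(np.sum(A ** 2, axis=0))
    return W


def threshold(W, w_threshold=0.3):
    """Hypothetical post-processing: drop edges with |W[i, j]| < w_threshold."""
    W = W.copy()
    W[np.abs(W) < w_threshold] = 0.0
    return W


fc1_weights = [
    np.array([[0.0, 0.1], [0.0, 0.0], [0.0, 0.0]]),  # variable 0: weak input from 1
    np.array([[3.0, 0.0], [4.0, 0.0], [0.0, 0.0]]),  # variable 1: strong input from 0
]
W = induced_adjacency(fc1_weights)
print(threshold(W))
```

Here the weak 1 -> 0 edge (weight 0.1) is removed and only the 0 -> 1 edge (weight 5.0) survives.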


Last update: Jan 14, 2024