Source code for skdag.dag._builder

from collections.abc import Mapping, Sequence

import networkx as nx
from skdag.dag._dag import DAG, DAGStep
from skdag.exceptions import DAGError

__all__ = ["DAGBuilder"]


[docs]class DAGBuilder: """ Helper utility for creating a :class:`skdag.DAG`. ``DAGBuilder`` allows a graph to be defined incrementally by specifying one node (step) at a time. Graph edges are defined by providing optional dependency lists that reference each step by name. Note that steps must be defined before they are used as dependencies. Parameters ---------- infer_dataframe : bool, default = False If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` is called, if ``dataframe_columns`` is set to ``None``. This effectively makes the resulting DAG always try to coerce output into pandas DataFrames wherever possible. See Also -------- :class:`skdag.DAG` : The estimator DAG created by this utility. Examples -------- >>> from skdag import DAGBuilder >>> from sklearn.decomposition import PCA >>> from sklearn.impute import SimpleImputer >>> from sklearn.linear_model import LogisticRegression >>> dag = ( ... DAGBuilder() ... .add_step("impute", SimpleImputer()) ... .add_step("vitals", "passthrough", deps={"impute": slice(0, 4)}) ... .add_step("blood", PCA(n_components=2, random_state=0), deps={"impute": slice(4, 10)}) ... .add_step("lr", LogisticRegression(random_state=0), deps=["blood", "vitals"]) ... .make_dag() ... ) >>> print(dag.draw().strip()) o impute |\\ o o blood,vitals |/ o lr """ def __init__(self, infer_dataframe=False): self.graph = nx.DiGraph() self.infer_dataframe = infer_dataframe def from_pipeline(self, steps, **kwargs): """ Construct a DAG from a simple linear sequence of steps. The resulting DAG will be equivalent to a :class:`~sklearn.pipeline.Pipeline`. Parameters ---------- steps : sequence of (str, estimator) An ordered sequence of pipeline steps. A step is simply a pair of ``(name, estimator)``, just like a scikit-learn Pipeline. infer_dataframe : bool, default = False If True, assume ``dataframe_columns="infer"`` every time :meth:`.add_step` is called, if ``dataframe_columns`` is set to ``None``. This effectively makes the resulting DAG always try to coerce output into pandas DataFrames wherever possible. kwargs : kwargs Any other hyperparameters that are accepted by :class:`~skdag.dag.DAG`'s contructor. """ if hasattr(steps, "steps"): pipe = steps steps = pipe.steps if hasattr(pipe, "get_params"): kwargs = { **{ k: v for k, v in pipe.get_params().items() if k in ("memory", "verbose") }, **kwargs, } dfcols = "infer" if self.infer_dataframe else None for i in range(len(steps)): name, estimator = steps[i] self._validate_name(name) deps = {} if i > 0: dep = steps[i - 1][0] deps[dep] = None self._validate_deps(deps) step = DAGStep(name, estimator, deps, dataframe_columns=dfcols) self.graph.add_node(name, step=step) if deps: self.graph.add_edge(dep, name) self._validate_graph() return self def add_step(self, name, est, deps=None, dataframe_columns=None): self._validate_name(name) if isinstance(deps, Sequence): deps = {dep: None for dep in deps} if deps is not None: self._validate_deps(deps) else: deps = {} if dataframe_columns is None and self.infer_dataframe: dfcols = "infer" else: dfcols = dataframe_columns step = DAGStep(name, est, deps=deps, dataframe_columns=dfcols) self.graph.add_node(name, step=step) for dep in deps: # Since node is new, edges will never form a cycle. self.graph.add_edge(dep, name) self._validate_graph() return self def _validate_name(self, name): if not isinstance(name, str): raise KeyError(f"step names must be strings, got '{type(name)}'") if name in self.graph.nodes: raise KeyError(f"step with name '{name}' already exists") def _validate_deps(self, deps): if not isinstance(deps, Mapping) or not all( [isinstance(dep, str) for dep in deps] ): raise ValueError( "deps parameter must be a map of labels to indices, " f"got '{type(deps)}'." ) missing = [dep for dep in deps if dep not in self.graph] if missing: raise ValueError(f"unresolvable dependencies: {', '.join(sorted(missing))}") def _validate_graph(self): if not nx.algorithms.dag.is_directed_acyclic_graph(self.graph): raise DAGError("Workflow is not a DAG.") def make_dag(self, **kwargs): self._validate_graph() # Give the DAG a read-only view of the graph. return DAG(graph=self.graph.copy(as_view=True), **kwargs) def _repr_html_(self): return self.make_dag()._repr_html_()