Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions hamilton/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,54 @@
# under the License.

"""This module contains base constructs for executing a hamilton graph.
It should only import hamilton.node, numpy, pandas.
It should keep imports light and defer pandas/numpy until result builders need them.
It cannot import hamilton.graph, or hamilton.driver.
"""

from __future__ import annotations

import abc
import collections
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
import numpy as np
import pandas as pd

import numpy as np
import pandas as pd
from pandas.core.indexes import extension as pd_extension
try:
from . import node
except ImportError:
import node

from hamilton.lifecycle import api as lifecycle_api

try:
from . import htypes, node
from . import htypes
except ImportError:
import node
import htypes

logger = logging.getLogger(__name__)


def _get_pandas():
import pandas as pd

return pd


def _get_pandas_extension():
from pandas.core.indexes import extension as pd_extension

return pd_extension


def _get_numpy():
import numpy as np

return np


class ResultMixin(lifecycle_api.LegacyResultMixin):
"""Legacy result builder -- see lifecycle methods for more information."""

Expand Down Expand Up @@ -123,6 +148,8 @@ def pandas_index_types(
all_index_types = collections.defaultdict(list)
time_indexes = collections.defaultdict(list)
no_indexes = collections.defaultdict(list)
pd = _get_pandas()
pd_extension = _get_pandas_extension()

def index_key_name(pd_object: pd.DataFrame | pd.Series) -> str:
"""Creates a string helping identify the index and it's type.
Expand Down Expand Up @@ -221,6 +248,7 @@ def build_result(**outputs: dict[str, Any]) -> pd.DataFrame:
:param outputs: the outputs to build a dataframe from.
"""
# TODO check inputs are pd.Series, arrays, or scalars -- else error
pd = _get_pandas()
output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs)
# this next line just log warnings
# we don't actually care about the result since this is the current default behavior.
Expand Down Expand Up @@ -255,6 +283,7 @@ def build_dataframe_with_dataframes(outputs: dict[str, Any]) -> pd.DataFrame:
:param outputs: The outputs to build the dataframe from.
:return: A dataframe with the outputs.
"""
pd = _get_pandas()

def get_output_name(output_name: str, column_name: str) -> str:
"""Add function prefix to columns.
Expand Down Expand Up @@ -300,7 +329,7 @@ def input_types(self) -> list[type[type]]:
return [Any]

def output_type(self) -> type:
return pd.DataFrame
return _get_pandas().DataFrame


class StrictIndexTypePandasDataFrameResult(PandasDataFrameResult):
Expand Down Expand Up @@ -366,6 +395,7 @@ def build_result(**outputs: dict[str, Any]) -> np.matrix:
:return: numpy matrix
"""
# TODO check inputs are all numpy arrays/array like things -- else error
np = _get_numpy()
num_rows = -1
columns_with_lengths = collections.OrderedDict()
for col, val in outputs.items(): # assumption is fixed order
Expand Down Expand Up @@ -402,7 +432,7 @@ def input_types(self) -> list[type[type]]:
return [Any] # Typing

def output_type(self) -> type:
return pd.DataFrame
return _get_pandas().DataFrame


class HamiltonGraphAdapter(lifecycle_api.GraphAdapter, abc.ABC):
Expand Down
109 changes: 89 additions & 20 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import abc
import importlib
import importlib.util
import json
import logging
import operator
import pathlib
import sys

# required if we want to run this code stand alone.
Expand All @@ -30,26 +31,30 @@
import warnings
from collections.abc import Callable, Collection, Sequence
from datetime import datetime
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
)

import pandas as pd
from typing_extensions import Self

from hamilton import common, graph_types, htypes
from hamilton.caching.adapter import HamiltonCacheAdapter
from hamilton.caching.stores.base import MetadataStore, ResultStore
from hamilton.dev_utils import deprecation
from hamilton.execution import executors, graph_functions, grouping, state
from hamilton.graph_types import HamiltonNode
from hamilton.io import materialization
from hamilton.io.materialization import ExtractorFactory, MaterializerFactory
from hamilton.lifecycle import base as lifecycle_base

if TYPE_CHECKING:
import pathlib
from types import ModuleType

from hamilton.caching.adapter import HamiltonCacheAdapter
from hamilton.caching.stores.base import MetadataStore, ResultStore
from hamilton.io import materialization
from hamilton.io.materialization import ExtractorFactory, MaterializerFactory

SLACK_ERROR_MESSAGE = (
"-------------------------------------------------------------------\n"
"Oh no an error! Need help with Hamilton?\n"
Expand All @@ -59,14 +64,66 @@

if __name__ == "__main__":
import base
import graph
import node
else:
from . import base, graph, node
from . import base, node

logger = logging.getLogger(__name__)


class _LazyModule:
def __init__(self, module_name: str, global_name: str):
self.module_name = module_name
self.global_name = global_name

def _load(self) -> ModuleType:
module = importlib.import_module(self.module_name)
globals()[self.global_name] = module
return module

def __getattr__(self, name: str) -> Any:
return getattr(self._load(), name)


graph = _LazyModule("hamilton.graph", "graph")


def _get_hamilton_cache_adapter():
from hamilton.caching.adapter import HamiltonCacheAdapter

return HamiltonCacheAdapter


def _get_materialization():
return importlib.import_module("hamilton.io.materialization")


def _get_materializer_types():
materialization = _get_materialization()
return materialization.MaterializerFactory, materialization.ExtractorFactory


_LAZY_IMPORTS = {
"HamiltonCacheAdapter": ("hamilton.caching.adapter", "HamiltonCacheAdapter"),
"MetadataStore": ("hamilton.caching.stores.base", "MetadataStore"),
"ResultStore": ("hamilton.caching.stores.base", "ResultStore"),
"ExtractorFactory": ("hamilton.io.materialization", "ExtractorFactory"),
"MaterializerFactory": ("hamilton.io.materialization", "MaterializerFactory"),
"materialization": ("hamilton.io.materialization", None),
}


def __getattr__(name: str) -> Any:
"""Lazily expose names that used to be imported at module load."""
if name in _LAZY_IMPORTS:
module_name, attr_name = _LAZY_IMPORTS[name]
module = importlib.import_module(module_name)
value = module if attr_name is None else getattr(module, attr_name)
globals()[name] = value
return value
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def capture_function_usage(call_fn: Callable) -> Callable:
"""No-op decorator kept for backwards compatibility.

Expand Down Expand Up @@ -400,7 +457,7 @@ def __init__(
| list[lifecycle_base.LifecycleAdapter]
| None = None,
allow_module_overrides: bool = False,
_materializers: typing.Sequence[ExtractorFactory | MaterializerFactory] = None,
_materializers: typing.Sequence["ExtractorFactory | MaterializerFactory"] | None = None,
_graph_executor: GraphExecutor = None,
_use_legacy_adapter: bool = True,
):
Expand Down Expand Up @@ -440,6 +497,7 @@ def __init__(
materializer_factories, extractor_factories = self._process_materializers(
_materializers
)
materialization = _get_materialization()
self.graph = materialization.modify_graph(
self.graph, materializer_factories, extractor_factories
)
Expand Down Expand Up @@ -1347,15 +1405,16 @@ def visualize_path_between(
logger.warning(f"Unable to import {e}", exc_info=True)

def _process_materializers(
self, materializers: typing.Sequence[MaterializerFactory | ExtractorFactory]
) -> tuple[list[MaterializerFactory], list[ExtractorFactory]]:
self, materializers: "typing.Sequence[MaterializerFactory | ExtractorFactory]"
) -> "tuple[list[MaterializerFactory], list[ExtractorFactory]]":
"""Processes materializers, splitting them into materializers and extractors.
Note that this also sanitizes the variable names in the materializer dependencies,
so one can pass in functions instead of strings.

:param materializers: Materializers to process
:return: Tuple of materializers and extractors
"""
MaterializerFactory, ExtractorFactory = _get_materializer_types()
module_set = {_module.__name__ for _module in self.graph_modules}
materializer_factories = [
m.sanitize_dependencies(module_set)
Expand All @@ -1368,7 +1427,7 @@ def _process_materializers(
@capture_function_usage
def materialize(
self,
*materializers: materialization.MaterializerFactory | materialization.ExtractorFactory,
*materializers: "materialization.MaterializerFactory | materialization.ExtractorFactory",
additional_vars: list[str | Callable | Variable] = None,
overrides: dict[str, Any] = None,
inputs: dict[str, Any] = None,
Expand Down Expand Up @@ -1546,6 +1605,7 @@ def materialize(
# This is so the finally logging statement does not accidentally die
materializer_vars = []
try:
materialization = _get_materialization()
materializer_factories, extractor_factories = self._process_materializers(materializers)
if len(materializer_factories) == len(final_vars) == 0:
raise ValueError(
Expand Down Expand Up @@ -1614,7 +1674,7 @@ def materialize(
@capture_function_usage
def visualize_materialization(
self,
*materializers: MaterializerFactory | ExtractorFactory,
*materializers: "MaterializerFactory | ExtractorFactory",
output_file_path: str = None,
render_kwargs: dict = None,
additional_vars: list[str | Callable | Variable] = None,
Expand Down Expand Up @@ -1655,6 +1715,7 @@ def visualize_materialization(
"""
if additional_vars is None:
additional_vars = []
materialization = _get_materialization()
materializer_factories, extractor_factories = self._process_materializers(materializers)
function_graph = materialization.modify_graph(
self.graph, materializer_factories, extractor_factories
Expand Down Expand Up @@ -1701,7 +1762,7 @@ def validate_execution(

def validate_materialization(
self,
*materializers: materialization.MaterializerFactory,
*materializers: "materialization.MaterializerFactory",
additional_vars: list[str | Callable | Variable] = None,
overrides: dict[str, Any] = None,
inputs: dict[str, Any] = None,
Expand All @@ -1719,6 +1780,7 @@ def validate_materialization(
additional_vars = []
final_vars = self._create_final_vars(additional_vars)
module_set = {_module.__name__ for _module in self.graph_modules}
materialization = _get_materialization()
materializer_factories, extractor_factories = self._process_materializers(materializers)
materializer_factories = [
m.sanitize_dependencies(module_set) for m in materializer_factories
Expand All @@ -1737,8 +1799,9 @@ def validate_materialization(
self.graph_executor.validate(list(all_nodes))

@property
def cache(self) -> HamiltonCacheAdapter:
def cache(self) -> "HamiltonCacheAdapter":
"""Directly access the cache adapter"""
HamiltonCacheAdapter = _get_hamilton_cache_adapter()
if self.adapter:
for adapter in self.adapter.adapters:
if isinstance(adapter, HamiltonCacheAdapter):
Expand Down Expand Up @@ -1835,6 +1898,7 @@ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> Self:
:param adapter: Adapter to use.
:return: self
"""
HamiltonCacheAdapter = _get_hamilton_cache_adapter()
if any(isinstance(adapter, HamiltonCacheAdapter) for adapter in adapters):
self._require_field_unset(
"cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`."
Expand All @@ -1843,13 +1907,14 @@ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> Self:
self.adapters.extend(adapters)
return self

def with_materializers(self, *materializers: ExtractorFactory | MaterializerFactory) -> Self:
def with_materializers(self, *materializers: "ExtractorFactory | MaterializerFactory") -> Self:
"""Add materializer nodes to the `Driver`
The generated nodes can be referenced by name in `.execute()`

:param materializers: materializers to add to the dataflow
:return: self
"""
MaterializerFactory, ExtractorFactory = _get_materializer_types()
if any(
m for m in materializers if not isinstance(m, (ExtractorFactory, MaterializerFactory))
):
Expand All @@ -1868,8 +1933,8 @@ def with_materializers(self, *materializers: ExtractorFactory | MaterializerFact
def with_cache(
self,
path: str | pathlib.Path = ".hamilton_cache",
metadata_store: MetadataStore | None = None,
result_store: ResultStore | None = None,
metadata_store: "MetadataStore" | None = None,
result_store: "ResultStore" | None = None,
default: Literal[True] | Sequence[str] | None = None,
recompute: Literal[True] | Sequence[str] | None = None,
ignore: Literal[True] | Sequence[str] | None = None,
Expand Down Expand Up @@ -1920,6 +1985,7 @@ def with_cache(
self._require_field_unset(
"cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`."
)
HamiltonCacheAdapter = _get_hamilton_cache_adapter()
adapter = HamiltonCacheAdapter(
path=path,
metadata_store=metadata_store,
Expand All @@ -1937,12 +2003,13 @@ def with_cache(
return self

@property
def cache(self) -> HamiltonCacheAdapter | None:
def cache(self) -> "HamiltonCacheAdapter" | None:
"""Attribute to check if a cache was set, either via `.with_cache()` or
`.with_adapters(SmartCacheAdapter())`

Required for the check `._require_field_unset()`
"""
HamiltonCacheAdapter = _get_hamilton_cache_adapter()
if self.adapters:
for adapter in self.adapters:
if isinstance(adapter, HamiltonCacheAdapter):
Expand Down Expand Up @@ -2083,6 +2150,8 @@ def copy(self) -> "Builder":
"""some example test code"""
import importlib

import pandas as pd

formatter = logging.Formatter("[%(levelname)s] %(asctime)s %(name)s(%(lineno)s): %(message)s")
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
Expand Down
Loading