Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autogluon #33

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
update fit_model for new prediction method with X_prediction
  • Loading branch information
diegomarvid committed Jul 17, 2024
commit 38a68bf503970b7163b82f56bca1de96c6e131c5
78 changes: 72 additions & 6 deletions ml_garden/core/model_registry.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,71 @@
import importlib
import logging
import pkgutil
from typing import Dict, Type

from ml_garden.core.model import Model


class ModelClassNotFoundError(Exception):
"""Exception raised when a model class is not found in the registry."""

pass


class ModelRegistry:
def __init__(self):
self._model_registry = {}
"""
Initialize a new ModelRegistry instance.

Attributes
----------
_model_registry : dict
A dictionary mapping model names to model classes.
logger : logging.Logger
Logger for the class.
"""
self._model_registry: Dict[str, Type[Model]] = {}
self.logger = logging.getLogger(__name__)

def register_model(self, model_class: type):
model_name = model_class.__name__
def register_model(self, model_class: Type[Model]) -> None:
"""
Register a model class in the registry.

Parameters
----------
model_class : Type[Model]
The model class to be registered.

Raises
------
ValueError
If the model_class is not a subclass of Model.
"""
model_name = model_class.__name__.lower()
if not issubclass(model_class, Model):
raise ValueError(f"{model_class} must be a subclass of Model")
self._model_registry[model_name] = model_class

def get_model_class(self, model_name: str) -> type:
def get_model_class(self, model_name: str) -> Type[Model]:
"""
Retrieve a model class from the registry.

Parameters
----------
model_name : str
The name of the model class to retrieve.

Returns
-------
Type[Model]
The model class.

Raises
------
ModelClassNotFoundError
If the model class is not found in the registry.
"""
model_name = model_name.lower()
if model_name in self._model_registry:
return self._model_registry[model_name]
else:
Expand All @@ -29,10 +74,31 @@ def get_model_class(self, model_name: str) -> type:
f" {list(self._model_registry.keys())}"
)

def get_all_model_classes(self) -> dict:
def get_all_model_classes(self) -> Dict[str, Type[Model]]:
"""
Get all registered model classes.

Returns
-------
dict
A dictionary of all registered model classes.
"""
return self._model_registry

def auto_register_models_from_package(self, package_name: str):
def auto_register_models_from_package(self, package_name: str) -> None:
"""
Automatically register all model classes from a given package.

Parameters
----------
package_name : str
The name of the package to search for model classes.

Raises
------
ImportError
If the package cannot be imported.
"""
try:
package = importlib.import_module(package_name)
prefix = package.__name__ + "."
Expand Down
1 change: 1 addition & 0 deletions ml_garden/core/steps/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def predict(self, data: DataContainer) -> DataContainer:
The updated data container
"""
self.logger.info(f"Predicting with {self.model_class.__name__} model")
data.X_prediction = data.flow.drop(columns=data.columns_to_ignore_for_training)
data.flow[data.prediction_column] = data.model.predict(data.X_prediction)
data.predictions = data.flow[data.prediction_column]
return data
3 changes: 2 additions & 1 deletion ml_garden/implementation/tabular/autogluon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pandas as pd
from autogluon.tabular import TabularPredictor

from ml_garden.core.constants import Task
from ml_garden.core.model import Model

logger = logging.getLogger(__file__)


class AutoGluon(Model):
TASKS = ["regression", "classification"]
TASKS = [Task.REGRESSION, Task.CLASSIFICATION]

def __init__(self, **params):
self.params = params
Expand Down