Skip to content

Commit b79a7ff

Browse files
GarrettWuTrevorBergeron
authored andcommitted
fix!: exclude remote models for .register() (#465)
* fix: exclude remote models for .register() * fix mypy
1 parent fbea9df commit b79a7ff

File tree

3 files changed

+8
-16
lines changed

3 files changed

+8
-16
lines changed

bigframes/ml/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __repr__(self):
9090
return prettyprinter.pformat(self)
9191

9292

93+
# TODO(garrettwu): refactor to reflect the actual property. Now the class contains .register() method.
9394
class Predictor(BaseEstimator):
9495
"""A BigQuery DataFrames ML Model base class that can be used to predict outputs."""
9596

bigframes/ml/llm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949

5050
@log_adapter.class_logger
51-
class PaLM2TextGenerator(base.Predictor):
51+
class PaLM2TextGenerator(base.BaseEstimator):
5252
"""PaLM2 text generator LLM model.
5353
5454
Args:
@@ -258,7 +258,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
258258

259259

260260
@log_adapter.class_logger
261-
class PaLM2TextEmbeddingGenerator(base.Predictor):
261+
class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
262262
"""PaLM2 text embedding generator LLM model.
263263
264264
Args:
@@ -418,7 +418,7 @@ def to_gbq(
418418

419419

420420
@log_adapter.class_logger
421-
class GeminiTextGenerator(base.Predictor):
421+
class GeminiTextGenerator(base.BaseEstimator):
422422
"""Gemini text generator LLM model.
423423
424424
Args:

tests/system/small/ml/test_register.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from typing import cast
1616

17+
import pytest
18+
1719
from bigframes.ml import core, imported, linear_model, llm
1820

1921

@@ -54,19 +56,8 @@ def test_linear_reg_register_with_params(
5456
def test_palm2_text_generator_register(
5557
ephemera_palm2_text_generator_model: llm.PaLM2TextGenerator,
5658
):
57-
model = ephemera_palm2_text_generator_model
58-
model.register()
59-
60-
model_name = "bigframes_" + cast(
61-
str, cast(core.BqmlModel, model._bqml_model).model.model_id
62-
)
63-
# Only registered model contains the field, and the field includes project/dataset. Here only check model_id.
64-
assert (
65-
model_name[:63] # truncated
66-
in cast(core.BqmlModel, model._bqml_model).model.training_runs[-1][
67-
"vertexAiModelId"
68-
]
69-
)
59+
with pytest.raises(AttributeError):
60+
ephemera_palm2_text_generator_model.register() # type: ignore
7061

7162

7263
def test_imported_tensorflow_register(

0 commit comments

Comments
 (0)