Skip to content

Commit bf9153a

Browse files
authored
fix: handle list when filtering rows with dates (#4700)
## 📝 Summary <!-- Provide a concise summary of what this pull request is addressing. If this PR fixes any issues, list them here by number (e.g., Fixes #123). --> Filtering lists of dates would previously throw an error as they were not converted to dates. ## 🔍 Description of Changes <!-- Detail the specific changes made in this pull request. Explain the problem addressed and how it was resolved. If applicable, provide before and after comparisons, screenshots, or any relevant details to help reviewers understand the changes easily. --> ## 📋 Checklist - [X] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [X] I have added tests for the changes made. - [X] I have run the code and verified that it works as expected. ## 📜 Reviewers <!-- Tag potential reviewers from the community or maintainers who might be interested in reviewing this pull request. Your PR will be reviewed more quickly if you can figure out the right person to tag with @ --> @akshayka OR @mscolnick
1 parent 50ea876 commit bf9153a

File tree

2 files changed

+99
-7
lines changed

2 files changed

+99
-7
lines changed

marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import datetime
55
from collections.abc import Sequence
6-
from typing import TYPE_CHECKING, Any, Optional, cast
6+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
77

88
from marimo._plugins.ui._impl.dataframes.transforms.print_code import (
99
python_print_ibis,
@@ -268,6 +268,12 @@ def handle_filter_rows(
268268
# Start with no filter (all rows included)
269269
filter_expr: Optional[pl.Expr] = None
270270

271+
# Convert a value whether it's a list or single value
272+
def convert_value(v: Any, converter: Callable[[str], Any]) -> Any:
273+
if isinstance(v, (tuple, list)):
274+
return [converter(str(item)) for item in v]
275+
return converter(str(v))
276+
271277
# Iterate over all conditions and build the filter expression
272278
for condition in transform.where:
273279
column = col(str(condition.column_id))
@@ -276,12 +282,12 @@ def handle_filter_rows(
276282
value_str = str(value)
277283

278284
# If columns type is a Datetime, we need to convert the value to a datetime
279-
if dtype == pl.Datetime and isinstance(value, str):
280-
value = datetime.datetime.fromisoformat(value)
281-
elif dtype == pl.Date and isinstance(value, str):
282-
value = datetime.date.fromisoformat(value)
283-
elif dtype == pl.Time and isinstance(value, str):
284-
value = datetime.time.fromisoformat(value)
285+
if dtype == pl.Datetime:
286+
value = convert_value(value, datetime.datetime.fromisoformat)
287+
elif dtype == pl.Date:
288+
value = convert_value(value, datetime.date.fromisoformat)
289+
elif dtype == pl.Time:
290+
value = convert_value(value, datetime.time.fromisoformat)
285291

286292
# If columns type is a Categorical, we need to cast the value to a string
287293
if dtype == pl.Categorical:

tests/_plugins/ui/_impl/dataframes/test_handlers.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2024 Marimo. All rights reserved.
22
from __future__ import annotations
33

4+
from datetime import date, datetime
45
from typing import Any, cast
56
from unittest.mock import Mock
67

@@ -443,6 +444,39 @@ def test_handle_filter_rows_6(
443444
result = apply(df, transform)
444445
assert_frame_equal(result, expected)
445446

447+
@staticmethod
448+
@pytest.mark.parametrize(
449+
("df", "expected"),
450+
[
451+
(
452+
pd.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
453+
pd.DataFrame({"date": [date(2001, 1, 1)]}),
454+
),
455+
(
456+
pl.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
457+
pl.DataFrame({"date": [date(2001, 1, 1)]}),
458+
),
459+
(
460+
ibis.memtable({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
461+
ibis.memtable({"date": [date(2001, 1, 1)]}),
462+
),
463+
],
464+
)
465+
def test_handle_filter_rows_date(
466+
df: DataFrameType, expected: DataFrameType
467+
) -> None:
468+
transform = FilterRowsTransform(
469+
type=TransformType.FILTER_ROWS,
470+
operation="keep_rows",
471+
where=[
472+
Condition(
473+
column_id="date", operator="==", value=date(2001, 1, 1)
474+
)
475+
],
476+
)
477+
result = apply(df, transform)
478+
assert_frame_equal(result, expected)
479+
446480
@staticmethod
447481
@pytest.mark.parametrize(
448482
("df", "expected"),
@@ -472,6 +506,58 @@ def test_filter_rows_in_operator(
472506
result = apply(df, transform)
473507
assert_frame_equal(result, expected)
474508

509+
@staticmethod
510+
@pytest.mark.parametrize(
511+
("df", "expected", "column"),
512+
[
513+
# TODO: Pandas treats date objects as strings
514+
# (
515+
# pd.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
516+
# pd.DataFrame({"date": [date(2001, 1, 1)]}),
517+
# ),
518+
(
519+
pl.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
520+
pl.DataFrame({"date": [date(2001, 1, 1)]}),
521+
"date",
522+
),
523+
(
524+
pl.DataFrame(
525+
{"datetime": [datetime(2001, 1, 1), datetime(2001, 1, 2)]}
526+
),
527+
pl.DataFrame({"datetime": [datetime(2001, 1, 1)]}),
528+
"datetime",
529+
),
530+
(
531+
ibis.memtable({"date": [date(2001, 1, 1), date(2001, 1, 2)]}),
532+
ibis.memtable({"date": [date(2001, 1, 1)]}),
533+
"date",
534+
),
535+
(
536+
ibis.memtable(
537+
{"datetime": [datetime(2001, 1, 1), datetime(2001, 1, 2)]}
538+
),
539+
ibis.memtable({"datetime": [datetime(2001, 1, 1)]}),
540+
"datetime",
541+
),
542+
],
543+
)
544+
def test_filter_rows_in_dates(
545+
df: DataFrameType, expected: DataFrameType, column: str
546+
) -> None:
547+
transform = FilterRowsTransform(
548+
type=TransformType.FILTER_ROWS,
549+
operation="keep_rows",
550+
where=[
551+
Condition(
552+
column_id=column,
553+
operator="in",
554+
value=["2001-01-01"], # Backend will receive as string
555+
),
556+
],
557+
)
558+
result = apply(df, transform)
559+
assert_frame_equal(result, expected)
560+
475561
@staticmethod
476562
@pytest.mark.parametrize(
477563
("df", "expected"),

0 commit comments

Comments
 (0)