|
1 | 1 | # Copyright 2024 Marimo. All rights reserved. |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +from datetime import date, datetime |
4 | 5 | from typing import Any, cast |
5 | 6 | from unittest.mock import Mock |
6 | 7 |
|
@@ -443,6 +444,39 @@ def test_handle_filter_rows_6( |
443 | 444 | result = apply(df, transform) |
444 | 445 | assert_frame_equal(result, expected) |
445 | 446 |
|
| 447 | + @staticmethod |
| 448 | + @pytest.mark.parametrize( |
| 449 | + ("df", "expected"), |
| 450 | + [ |
| 451 | + ( |
| 452 | + pd.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 453 | + pd.DataFrame({"date": [date(2001, 1, 1)]}), |
| 454 | + ), |
| 455 | + ( |
| 456 | + pl.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 457 | + pl.DataFrame({"date": [date(2001, 1, 1)]}), |
| 458 | + ), |
| 459 | + ( |
| 460 | + ibis.memtable({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 461 | + ibis.memtable({"date": [date(2001, 1, 1)]}), |
| 462 | + ), |
| 463 | + ], |
| 464 | + ) |
| 465 | + def test_handle_filter_rows_date( |
| 466 | + df: DataFrameType, expected: DataFrameType |
| 467 | + ) -> None: |
| 468 | + transform = FilterRowsTransform( |
| 469 | + type=TransformType.FILTER_ROWS, |
| 470 | + operation="keep_rows", |
| 471 | + where=[ |
| 472 | + Condition( |
| 473 | + column_id="date", operator="==", value=date(2001, 1, 1) |
| 474 | + ) |
| 475 | + ], |
| 476 | + ) |
| 477 | + result = apply(df, transform) |
| 478 | + assert_frame_equal(result, expected) |
| 479 | + |
446 | 480 | @staticmethod |
447 | 481 | @pytest.mark.parametrize( |
448 | 482 | ("df", "expected"), |
@@ -472,6 +506,58 @@ def test_filter_rows_in_operator( |
472 | 506 | result = apply(df, transform) |
473 | 507 | assert_frame_equal(result, expected) |
474 | 508 |
|
| 509 | + @staticmethod |
| 510 | + @pytest.mark.parametrize( |
| 511 | + ("df", "expected", "column"), |
| 512 | + [ |
| 513 | + # TODO: Pandas treats date objects as strings |
| 514 | + # ( |
| 515 | + # pd.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 516 | + # pd.DataFrame({"date": [date(2001, 1, 1)]}), |
| 517 | + # ), |
| 518 | + ( |
| 519 | + pl.DataFrame({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 520 | + pl.DataFrame({"date": [date(2001, 1, 1)]}), |
| 521 | + "date", |
| 522 | + ), |
| 523 | + ( |
| 524 | + pl.DataFrame( |
| 525 | + {"datetime": [datetime(2001, 1, 1), datetime(2001, 1, 2)]} |
| 526 | + ), |
| 527 | + pl.DataFrame({"datetime": [datetime(2001, 1, 1)]}), |
| 528 | + "datetime", |
| 529 | + ), |
| 530 | + ( |
| 531 | + ibis.memtable({"date": [date(2001, 1, 1), date(2001, 1, 2)]}), |
| 532 | + ibis.memtable({"date": [date(2001, 1, 1)]}), |
| 533 | + "date", |
| 534 | + ), |
| 535 | + ( |
| 536 | + ibis.memtable( |
| 537 | + {"datetime": [datetime(2001, 1, 1), datetime(2001, 1, 2)]} |
| 538 | + ), |
| 539 | + ibis.memtable({"datetime": [datetime(2001, 1, 1)]}), |
| 540 | + "datetime", |
| 541 | + ), |
| 542 | + ], |
| 543 | + ) |
| 544 | + def test_filter_rows_in_dates( |
| 545 | + df: DataFrameType, expected: DataFrameType, column: str |
| 546 | + ) -> None: |
| 547 | + transform = FilterRowsTransform( |
| 548 | + type=TransformType.FILTER_ROWS, |
| 549 | + operation="keep_rows", |
| 550 | + where=[ |
| 551 | + Condition( |
| 552 | + column_id=column, |
| 553 | + operator="in", |
| 554 | + value=["2001-01-01"], # Backend will receive as string |
| 555 | + ), |
| 556 | + ], |
| 557 | + ) |
| 558 | + result = apply(df, transform) |
| 559 | + assert_frame_equal(result, expected) |
| 560 | + |
475 | 561 | @staticmethod |
476 | 562 | @pytest.mark.parametrize( |
477 | 563 | ("df", "expected"), |
|
0 commit comments