forked from argoverse/argoverse-api
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_argoverse_forecasting_loader.py
49 lines (31 loc) · 1.4 KB
/
test_argoverse_forecasting_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# <Copyright 2019, Argo AI, LLC. Released under the MIT license.>
"""Forecasting Loader unit tests"""
import glob
import pathlib
import numpy as np
import pytest
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
# Forecasting fixtures directory, resolved relative to this file so the tests
# do not depend on the current working directory.
TEST_DATA_LOC = pathlib.Path(__file__).parent.parent.joinpath(
    "tests", "test_data", "forecasting"
)
@pytest.fixture
def data_loader() -> ArgoverseForecastingLoader:
    """Provide an ``ArgoverseForecastingLoader`` rooted at ``TEST_DATA_LOC``."""
    loader = ArgoverseForecastingLoader(TEST_DATA_LOC)
    return loader
def test_id_list(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that ``track_id_list`` equals the expected track IDs, in order."""
    expected_track_ids = [
        "00000000-0000-0000-0000-000000000000",
        "00000000-0000-0000-0000-000000007735",
        "00000000-0000-0000-0000-000000008206",
    ]
    actual_track_ids = data_loader.track_id_list
    assert actual_track_ids == expected_track_ids
def test_city_name(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that the loader reports ``"MIA"`` as its ``city``."""
    expected_city = "MIA"
    assert data_loader.city == expected_city
def test_num_track(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that the loader reports three tracks via ``num_tracks``."""
    expected_track_count = 3
    assert data_loader.num_tracks == expected_track_count
def test_seq_df(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that ``seq_df`` is populated (not ``None``)."""
    sequence_frame = data_loader.seq_df
    assert sequence_frame is not None
def test_agent_traj(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that ``agent_traj`` matches the expected two-point trajectory."""
    expected_traj = [[10, 5], [10, 10]]
    # array_equal requires identical shape and element values.
    trajectories_match = np.array_equal(data_loader.agent_traj, expected_traj)
    assert trajectories_match
def test_get(data_loader: ArgoverseForecastingLoader) -> None:
    """Check that ``get("0")`` and index ``0`` agree on ``current_seq``."""
    # Lookup order matters if either call mutates the loader, so keep
    # ``get`` before indexing.
    via_get = data_loader.get("0")
    via_index = data_loader[0]
    # NOTE(review): if both lookups return the same loader object, this
    # comparison is trivially true -- confirm against the loader source.
    assert via_get.current_seq == via_index.current_seq