|
7 | 7 | from mpl_toolkits.mplot3d import Axes3D
|
8 | 8 | from data_processing import DataProcessor, DataImporter
|
9 | 9 | import randomcolor
|
10 |
| - |
| 10 | +import pprint |
| 11 | +from collections import OrderedDict |
11 | 12 |
|
12 | 13 | class KMeansClusterer:
|
13 | 14 |
|
@@ -115,21 +116,54 @@ def _make_ellipses(self, gmm, ax, colors):
|
115 | 116 | importer = DataImporter()
|
116 | 117 | processor = DataProcessor()
|
117 | 118 |
|
118 |
| - trajectories = importer.import_csv_to_list('../toy_data/*.csv') |
119 |
| - observations = processor.concatenate_trajectory_observations(trajectories) |
| 119 | + trajectories_dict = importer.import_csv_to_dict('../toy_data/raw_trajectories/*.csv') |
| 120 | + observations = [] |
| 121 | + for t in trajectories_dict["trajectories"]: |
| 122 | + for observation in t["observations"]: |
| 123 | + observations.append(processor.convert_trajectory_dict_to_list(observation)) |
120 | 124 | observations = [[entry[1], entry[2], entry[3], entry[4], entry[5], entry[6], entry[7]] for entry in observations]
|
121 | 125 | np_observation = processor.convert_to_numpy(observations)
|
122 | 126 |
|
123 | 127 | km_clusterer = KMeansClusterer(np_observation, n_clusters=5)
|
124 | 128 | km_clusterer.kmeans_fit()
|
125 |
| - km_clusterer.view_XYZ_clusters() |
126 |
| - cluster_data = km_clusterer.get_cluster_samples() |
| 129 | + |
| 130 | + cluster_data = OrderedDict |
| 131 | + for t in trajectories_dict["trajectories"]: |
| 132 | + for observation in t["observations"]: |
| 133 | + sample = processor.convert_to_numpy([processor.convert_trajectory_dict_to_list(observation, key_order=["PoseX", "PoseY", "PoseZ", "OrienX", "OrienY", "OrienZ", "OrienW"])]) |
| 134 | + cluster = km_clusterer.kmeans.predict(sample)[0] |
| 135 | + observation["cluster"] = cluster |
| 136 | + if cluster in cluster_data.keys(): |
| 137 | + cluster_data[cluster].append(observation) |
| 138 | + else: |
| 139 | + cluster_data[cluster] = [] |
| 140 | + |
| 141 | + key_frame_data = OrderedDict |
| 142 | + counter = 1 |
| 143 | + for cluster_number, observations in cluster_data.items(): |
| 144 | + # standard keyframe data |
| 145 | + keyframe_data = [] |
| 146 | + for observation in observations: |
| 147 | + keyframe_data.append(processor.convert_trajectory_dict_to_list(observation)) |
| 148 | + key_frame_data[counter] = processor.convert_to_numpy(keyframe_data) |
| 149 | + |
| 150 | + # transition keyframe data |
| 151 | + # Preceeding 12 samples |
| 152 | + key_frame_data[counter] |
| 153 | + |
| 154 | + ## |
| 155 | + |
| 156 | + |
| 157 | + |
| 158 | + pprint.pprint(cluster_data) |
| 159 | + |
| 160 | + # km_clusterer.view_XYZ_clusters(1, 2, 3) |
127 | 161 |
|
128 | 162 | counter = 0
|
129 | 163 | for cluster_number, np_array in cluster_data.items():
|
130 | 164 | gmm_keyframer = GMMKeyframe(np_array[0])
|
131 | 165 | gmm_keyframer.gmm_fit()
|
132 |
| - gmm_keyframer.view_2D_gaussians() |
| 166 | + gmm_keyframer.view_2D_gaussians(1, 2) |
133 | 167 | points, labels = gmm_keyframer.gmm.sample(500)
|
134 | 168 | samples = list(zip(points, labels))
|
135 | 169 | print(samples)
|
|
0 commit comments