Skip to content

Commit 8ee333b

Browse files
committed
Added most voting method
1 parent 91ebec9 commit 8ee333b

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tensorcircuit/templates/ensemble.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Useful utilities for ensemble
33
"""
44

5-
from typing import Any, List
5+
from typing import Any, List, Optional
66
import tensorflow as tf
77
import keras
88
import numpy as np
@@ -94,7 +94,19 @@ def __voting_weight_single(self, array: NDArray) -> float:
9494
result = array * weight
9595
return float(np.sum(result))
9696

97-
def predict(self, input_data: NDArray, voting_policy: str = "None") -> NDArray:
97+
def __voting_most(self, array: NDArray) -> NDArray:
98+
return_result = []
99+
for items in array:
100+
items_ = items > 0.5
101+
result = 0
102+
for i in items_:
103+
result += 1 if i else -1
104+
return_result.append(1 if result > 0 else 0)
105+
return np.array(return_result)
106+
107+
def predict(
108+
self, input_data: NDArray, voting_policy: Optional[str] = "None"
109+
) -> NDArray:
98110
"""
99111
Input data is expected to be a 2D array that the first layer is different input data (into the trained models)
100112
"""
@@ -105,9 +117,11 @@ def predict(self, input_data: NDArray, voting_policy: str = "None") -> NDArray:
105117
self.predictions = np.transpose(np.array(predictions))
106118
if voting_policy == "weight":
107119
return self.__voting_weight(self.predictions)
120+
elif voting_policy == "most":
121+
return self.__voting_most(self.predictions)
108122
elif voting_policy == "average":
109123
return self.__voting_average(self.predictions)
110-
elif voting_policy == "None":
124+
elif voting_policy is None:
111125
return self.predictions
112126
else:
113127
raise ValueError("voting_policy must be none, weight, most, or average")

tests/test_ensemble.py

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def model():
6161

6262
v_weight = obj_bagging.predict(x_train, "weight")
6363
v_average = obj_bagging.predict(x_train, "average")
64+
v_most = obj_bagging.predict(x_train, "most")
6465
validation_data = []
6566
validation_data.append(obj_bagging.eval([y_train, v_weight], "acc"))
6667
validation_data.append(obj_bagging.eval([y_train, v_average], "auc"))
68+
validation_data.append(obj_bagging.eval([y_train, v_most], "acc"))

0 commit comments

Comments
 (0)