Skip to content

Commit 0d4d5e5

Browse files
authored
Merge pull request #691 from projectmesa/merge-test
Resources #651 in docs branch
2 parents 52adffa + 333562e commit 0d4d5e5

File tree

2 files changed

+171
-44
lines changed

2 files changed

+171
-44
lines changed

mesa/batchrunner.py

Lines changed: 137 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from itertools import product, count
1111
import pandas as pd
1212
from tqdm import tqdm
13+
14+
import random
15+
1316
try:
1417
from pathos.multiprocessing import ProcessPool
1518
except ImportError:
@@ -18,9 +21,9 @@
1821
pathos_support = True
1922

2023

21-
class VariableParameterError(TypeError):
22-
MESSAGE = ('variable_parameters must map a name to a sequence of values. '
23-
'These parameters were given with non-sequence values: {}')
24+
class ParameterError(TypeError):
25+
MESSAGE = ('parameters must map a name to a value. '
26+
'These names did not match paramerets: {}')
2427

2528
def __init__(self, bad_names):
2629
self.bad_names = bad_names
@@ -29,7 +32,15 @@ def __str__(self):
2932
return self.MESSAGE.format(self.bad_names)
3033

3134

32-
class BatchRunner:
35+
class VariableParameterError(ParameterError):
36+
MESSAGE = ('variable_parameters must map a name to a sequence of values. '
37+
'These parameters were given with non-sequence values: {}')
38+
39+
def __init__(self, bad_names):
40+
super().__init__(bad_names)
41+
42+
43+
class FixedBatchRunner:
3344
""" This class is instantiated with a model class, and model parameters
3445
associated with one or more values. It is also instantiated with model and
3546
agent-level reporters, dictionaries mapping a variable name to a function
@@ -39,9 +50,8 @@ class BatchRunner:
3950
Note that by default, the reporters only collect data at the *end* of the
4051
run. To get step by step data, simply have a reporter store the model's
4152
entire DataCollector object.
42-
4353
"""
44-
def __init__(self, model_cls, variable_parameters=None,
54+
def __init__(self, model_cls, parameters_list=None,
4555
fixed_parameters=None, iterations=1, max_steps=1000,
4656
model_reporters=None, agent_reporters=None,
4757
display_progress=True):
@@ -50,20 +60,20 @@ def __init__(self, model_cls, variable_parameters=None,
5060
5161
Args:
5262
model_cls: The class of model to batch-run.
53-
variable_parameters: Dictionary of parameters to lists of values.
54-
The model will be run with every combo of these paramters.
55-
For example, given variable_parameters of
56-
{"param_1": range(5),
57-
"param_2": [1, 5, 10]}
58-
models will be run with {param_1=1, param_2=1},
59-
{param_1=2, param_2=1}, ..., {param_1=4, param_2=10}.
63+
parameters_list: A list of dictionaries of parameter sets.
64+
The model will be run with dictionary of paramters.
65+
For example, given parameters_list of
66+
[{"homophily": 3, "density": 0.8, "minority_pc": 0.2},
67+
{"homophily": 2, "density": 0.9, "minority_pc": 0.1},
68+
{"homophily": 4, "density": 0.6, "minority_pc": 0.5}]
69+
3 models will be run, one for each provided set of parameters.
6070
fixed_parameters: Dictionary of parameters that stay same through
6171
all batch runs. For example, given fixed_parameters of
6272
{"constant_parameter": 3},
6373
every instantiated model will be passed constant_parameter=3
6474
as a kwarg.
65-
iterations: The total number of times to run the model for each
66-
combination of parameters.
75+
iterations: The total number of times to run the model for each set
76+
of parameters.
6777
max_steps: Upper limit of steps above which each run will be halted
6878
if it hasn't halted on its own.
6979
model_reporters: The dictionary of variables to collect on each run
@@ -77,9 +87,9 @@ def __init__(self, model_cls, variable_parameters=None,
7787
7888
"""
7989
self.model_cls = model_cls
80-
if variable_parameters is None:
81-
variable_parameters = {}
82-
self.variable_parameters = self._process_parameters(variable_parameters)
90+
if parameters_list is None:
91+
parameters_list = []
92+
self.parameters_list = list(parameters_list)
8393
self.fixed_parameters = fixed_parameters or {}
8494
self._include_fixed = len(self.fixed_parameters.keys()) > 0
8595
self.iterations = iterations
@@ -96,16 +106,6 @@ def __init__(self, model_cls, variable_parameters=None,
96106

97107
self.display_progress = display_progress
98108

99-
def _process_parameters(self, params):
100-
params = copy.deepcopy(params)
101-
bad_names = []
102-
for name, values in params.items():
103-
if (isinstance(values, str) or not hasattr(values, "__iter__")):
104-
bad_names.append(name)
105-
if bad_names:
106-
raise VariableParameterError(bad_names)
107-
return params
108-
109109
def _make_model_args(self):
110110
"""Prepare all combinations of parameter values for `run_all`
111111
@@ -117,21 +117,20 @@ def _make_model_args(self):
117117
all_kwargs = []
118118
all_param_values = []
119119

120-
if len(self.variable_parameters) > 0:
121-
param_names, param_ranges = zip(*self.variable_parameters.items())
122-
for param_range in param_ranges:
123-
total_iterations *= len(param_range)
124-
125-
for param_values in product(*param_ranges):
126-
kwargs = dict(zip(param_names, param_values))
120+
count = len(self.parameters_list)
121+
if count:
122+
for params in self.parameters_list:
123+
kwargs = params.copy()
127124
kwargs.update(self.fixed_parameters)
128125
all_kwargs.append(kwargs)
129-
all_param_values.append(param_values)
130-
else:
131-
kwargs = self.fixed_parameters
132-
param_values = None
133-
all_kwargs = [kwargs]
134-
all_param_values = [None]
126+
all_param_values.append(params.values())
127+
elif len(self.fixed_parameters):
128+
count = 1
129+
kwargs = self.fixed_parameters.copy()
130+
all_kwargs.append(kwargs)
131+
all_param_values.append(kwargs.values())
132+
133+
total_iterations *= count
135134

136135
return (total_iterations, all_kwargs, all_param_values)
137136

@@ -154,7 +153,7 @@ def run_iteration(self, kwargs, param_values, run_count):
154153

155154
# Collect and store results:
156155
if param_values is not None:
157-
model_key = param_values + (run_count,)
156+
model_key = tuple(param_values) + (run_count,)
158157
else:
159158
model_key = (run_count,)
160159

@@ -215,7 +214,10 @@ def _prepare_report_table(self, vars_dict, extra_cols=None):
215214
column as a key.
216215
"""
217216
extra_cols = ['Run'] + (extra_cols or [])
218-
index_cols = list(self.variable_parameters.keys()) + extra_cols
217+
index_cols = set()
218+
for params in self.parameters_list:
219+
index_cols |= params.keys()
220+
index_cols = list(index_cols) + extra_cols
219221

220222
records = []
221223
for param_key, values in vars_dict.items():
@@ -237,6 +239,98 @@ def _prepare_report_table(self, vars_dict, extra_cols=None):
237239
return ordered
238240

239241

242+
# This is kind of a useless class, but it does carry the 'source' parameters with it
243+
class ParameterProduct:
244+
def __init__(self, variable_parameters):
245+
self.param_names, self.param_lists = \
246+
zip(*(copy.deepcopy(variable_parameters)).items())
247+
self._product = product(*self.param_lists)
248+
249+
def __iter__(self):
250+
return self
251+
252+
def __next__(self):
253+
return dict(zip(self.param_names, next(self._product)))
254+
255+
256+
# Roughly inspired by sklearn.model_selection.ParameterSampler. Does not handle
257+
# distributions, only lists.
258+
class ParameterSampler:
259+
def __init__(self, parameter_lists, n, random_state=None):
260+
self.param_names, self.param_lists = \
261+
zip(*(copy.deepcopy(parameter_lists)).items())
262+
self.n = n
263+
if random_state is None:
264+
self.random_state = random.Random()
265+
elif isinstance(random_state, int):
266+
self.random_state = random.Random(random_state)
267+
else:
268+
self.random_state = random_state
269+
self.count = 0
270+
271+
def __iter__(self):
272+
return self
273+
274+
def __next__(self):
275+
self.count += 1
276+
if self.count <= self.n:
277+
return dict(zip(self.param_names, [self.random_state.choice(l) for l in self.param_lists]))
278+
raise StopIteration()
279+
280+
281+
class BatchRunner(FixedBatchRunner):
282+
""" This class is instantiated with a model class, and model parameters
283+
associated with one or more values. It is also instantiated with model and
284+
agent-level reporters, dictionaries mapping a variable name to a function
285+
which collects some data from the model or its agents at the end of the run
286+
and stores it.
287+
288+
Note that by default, the reporters only collect data at the *end* of the
289+
run. To get step by step data, simply have a reporter store the model's
290+
entire DataCollector object.
291+
292+
"""
293+
def __init__(self, model_cls, variable_parameters=None,
294+
fixed_parameters=None, iterations=1, max_steps=1000,
295+
model_reporters=None, agent_reporters=None,
296+
display_progress=True):
297+
""" Create a new BatchRunner for a given model with the given
298+
parameters.
299+
300+
Args:
301+
model_cls: The class of model to batch-run.
302+
variable_parameters: Dictionary of parameters to lists of values.
303+
The model will be run with every combo of these paramters.
304+
For example, given variable_parameters of
305+
{"param_1": range(5),
306+
"param_2": [1, 5, 10]}
307+
models will be run with {param_1=1, param_2=1},
308+
{param_1=2, param_2=1}, ..., {param_1=4, param_2=10}.
309+
fixed_parameters: Dictionary of parameters that stay same through
310+
all batch runs. For example, given fixed_parameters of
311+
{"constant_parameter": 3},
312+
every instantiated model will be passed constant_parameter=3
313+
as a kwarg.
314+
iterations: The total number of times to run the model for each
315+
combination of parameters.
316+
max_steps: Upper limit of steps above which each run will be halted
317+
if it hasn't halted on its own.
318+
model_reporters: The dictionary of variables to collect on each run
319+
at the end, with variable names mapped to a function to collect
320+
them. For example:
321+
{"agent_count": lambda m: m.schedule.get_agent_count()}
322+
agent_reporters: Like model_reporters, but each variable is now
323+
collected at the level of each agent present in the model at
324+
the end of the run.
325+
display_progress: Display progresss bar with time estimation?
326+
327+
"""
328+
super().__init__(model_cls, ParameterProduct(variable_parameters),
329+
fixed_parameters, iterations, max_steps,
330+
model_reporters, agent_reporters,
331+
display_progress)
332+
333+
240334
class MPSupport(Exception):
241335
def __str__(self):
242336
return ("BatchRunnerMP depends on pathos, which is either not "

tests/test_batchrunner.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mesa import Agent, Model
99
from mesa.time import BaseScheduler
10-
from mesa.batchrunner import BatchRunner
10+
from mesa.batchrunner import BatchRunner, ParameterProduct, ParameterSampler
1111

1212

1313
NUM_AGENTS = 7
@@ -163,5 +163,38 @@ def test_model_with_variable_and_fixed_kwargs(self):
163163
self.fixed_params['fixed_name'])
164164

165165

166+
class TestParameters(unittest.TestCase):
167+
def test_product(self):
168+
params = ParameterProduct({
169+
"var_alpha": ['a', 'b', 'c'],
170+
"var_num": [10, 20]
171+
})
172+
173+
lp = list(params)
174+
self.assertCountEqual(lp, [{'var_alpha': 'a', 'var_num': 10},
175+
{'var_alpha': 'a', 'var_num': 20},
176+
{'var_alpha': 'b', 'var_num': 10},
177+
{'var_alpha': 'b', 'var_num': 20},
178+
{'var_alpha': 'c', 'var_num': 10},
179+
{'var_alpha': 'c', 'var_num': 20}])
180+
181+
def test_sampler(self):
182+
params1 = ParameterSampler({
183+
"var_alpha": ['a', 'b', 'c', 'd', 'e'],
184+
"var_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
185+
n=10,
186+
random_state=1)
187+
params2 = ParameterSampler({
188+
"var_alpha": ['a', 'b', 'c', 'd', 'e'],
189+
"var_num": range(16)},
190+
n=10,
191+
random_state=1
192+
)
193+
194+
lp = list(params1)
195+
self.assertEqual(10, len(lp))
196+
self.assertEqual(lp, list(params2))
197+
198+
166199
if __name__ == '__main__':
167200
unittest.main()

0 commit comments

Comments
 (0)