Skip to content

Commit 3bafa14

Browse files
feature: Enable AgentSet.get to retrieve one or more than one attribute (#2044)
* new feature: AgentSet.get can retrieve one or more than one attribute * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor docstring update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add AttributeError to docstring and fixes capital L in List * Update agent.py * Update agent.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 913dfb8 commit 3bafa14

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

mesa/agent.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,17 +264,29 @@ def do(
264264

265265
return res if return_results else self
266266

267-
def get(self, attr_name: str) -> list[Any]:
267+
def get(self, attr_names: str | list[str]) -> list[Any]:
268268
"""
269-
Retrieve a specified attribute from each agent in the AgentSet.
269+
Retrieve the specified attribute(s) from each agent in the AgentSet.
270270
271271
Args:
272-
attr_name (str): The name of the attribute to retrieve from each agent.
272+
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
273273
274274
Returns:
275-
list[Any]: A list of attribute values from each agent in the set.
275+
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
276+
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
277+
278+
Raises:
279+
AttributeError if an agent does not have the specified attribute(s)
280+
276281
"""
277-
return [getattr(agent, attr_name) for agent in self._agents]
282+
283+
if isinstance(attr_names, str):
284+
return [getattr(agent, attr_names) for agent in self._agents]
285+
else:
286+
return [
287+
[getattr(agent, attr_name) for attr_name in attr_names]
288+
for agent in self._agents
289+
]
278290

279291
def __getitem__(self, item: int | slice) -> Agent:
280292
"""

tests/test_agent.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,24 @@ def test_agentset_get_attribute():
221221
with pytest.raises(AttributeError):
222222
agentset.get("non_existing_attribute")
223223

224+
model = Model()
225+
agents = []
226+
for i in range(10):
227+
agent = TestAgent(model.next_id(), model)
228+
agent.i = i**2
229+
agents.append(agent)
230+
agentset = AgentSet(agents, model)
231+
232+
values = agentset.get(["unique_id", "i"])
233+
234+
for value, agent in zip(values, agents):
235+
(
236+
unique_id,
237+
i,
238+
) = value
239+
assert agent.unique_id == unique_id
240+
assert agent.i == i
241+
224242

225243
class OtherAgentType(Agent):
226244
def get_unique_identifier(self):

0 commit comments

Comments
 (0)