Skip to content

Commit 297d3b5

Browse files
committed
Type testinfra;utils.ansible_runner
1 parent 8ecab66 commit 297d3b5

File tree

3 files changed

+61
-33
lines changed

3 files changed

+61
-33
lines changed

testinfra/host.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212

1313
import functools
1414
import os
15+
from typing import Any
1516

1617
import testinfra.backend
18+
import testinfra.backend.base
1719
import testinfra.modules
1820

1921

2022
class Host:
21-
_host_cache = {} # type: ignore[var-annotated]
23+
_host_cache: dict[tuple[str, frozenset[tuple[str, Any]]], "Host"] = {}
2224
_hosts_cache = {} # type: ignore[var-annotated]
2325

2426
def __init__(self, backend):
@@ -58,7 +60,9 @@ def find_command(self, command, extrapaths=("/sbin", "/usr/sbin")):
5860
return path
5961
raise ValueError('cannot find "{}" command'.format(command))
6062

61-
def run(self, command, *args, **kwargs):
63+
def run(
64+
self, command: str, *args: str, **kwargs: Any
65+
) -> testinfra.backend.base.CommandResult:
6266
"""Run given command and return rc (exit status), stdout and stderr
6367
6468
>>> cmd = host.run("ls -l /etc/passwd")
@@ -85,9 +89,11 @@ def run(self, command, *args, **kwargs):
8589
'ls: cannot access /;echo inject: No such file or directory\\n'),
8690
command="ls -l '/;echo inject'")
8791
"""
88-
return self.backend.run(command, *args, **kwargs)
92+
return self.backend.run(command, *args, **kwargs) # type: ignore[no-any-return]
8993

90-
def run_expect(self, expected, command, *args, **kwargs):
94+
def run_expect(
95+
self, expected: list[int], command: str, *args: str, **kwargs: Any
96+
) -> testinfra.backend.base.CommandResult:
9197
"""Run command and check it return an expected exit status
9298
9399
:param expected: A list of expected exit status
@@ -105,7 +111,7 @@ def run_test(self, command, *args, **kwargs):
105111
"""
106112
return self.run_expect([0, 1], command, *args, **kwargs)
107113

108-
def check_output(self, command, *args, **kwargs):
114+
def check_output(self, command: str, *args: str, **kwargs: Any) -> str:
109115
"""Get stdout of a command which has run successfully
110116
111117
:returns: stdout without trailing newline
@@ -127,7 +133,7 @@ def __getattr__(self, name):
127133
)
128134

129135
@classmethod
130-
def get_host(cls, hostspec, **kwargs):
136+
def get_host(cls, hostspec: str, **kwargs: Any) -> "Host":
131137
"""Return a Host instance from `hostspec`
132138
133139
`hostspec` should be like

testinfra/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
import pytest
1616

1717

18-
def main():
18+
def main() -> int:
1919
warnings.warn("calling testinfra is deprecated, call py.test instead", stacklevel=1)
2020
return pytest.main()

testinfra/utils/ansible_runner.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
import json
1818
import os
1919
import tempfile
20-
from typing import Any, Dict, List, Optional, Union
20+
from typing import Any, Callable, Iterator, Optional, Union
2121

2222
import testinfra
23+
import testinfra.host
2324

2425
__all__ = ["AnsibleRunner"]
2526

2627
local = testinfra.get_host("local://")
2728

2829

29-
def get_ansible_config():
30+
def get_ansible_config() -> configparser.ConfigParser:
3031
fname = os.environ.get("ANSIBLE_CONFIG")
3132
if not fname:
3233
for possible in (
@@ -44,18 +45,29 @@ def get_ansible_config():
4445
return config
4546

4647

47-
def get_ansible_inventory(config, inventory_file):
48+
Inventory = dict[str, Any]
49+
50+
51+
def get_ansible_inventory(
52+
config: configparser.ConfigParser, inventory_file: Optional[str]
53+
) -> Inventory:
4854
# Disable ansible verbosity to avoid
4955
# https://github.com/ansible/ansible/issues/59973
5056
cmd = "ANSIBLE_VERBOSITY=0 ansible-inventory --list"
5157
args = []
5258
if inventory_file:
5359
cmd += " -i %s"
5460
args += [inventory_file]
55-
return json.loads(local.check_output(cmd, *args))
61+
return json.loads(local.check_output(cmd, *args)) # type: ignore[no-any-return]
5662

5763

58-
def get_ansible_host(config, inventory, host, ssh_config=None, ssh_identity_file=None):
64+
def get_ansible_host(
65+
config: configparser.ConfigParser,
66+
inventory: Inventory,
67+
host: str,
68+
ssh_config: Optional[str] = None,
69+
ssh_identity_file: Optional[str] = None,
70+
) -> Optional[testinfra.host.Host]:
5971
if is_empty_inventory(inventory):
6072
if host == "localhost":
6173
return testinfra.get_host("local://")
@@ -81,7 +93,7 @@ def get_ansible_host(config, inventory, host, ssh_config=None, ssh_identity_file
8193
"smart": "ssh",
8294
}.get(connection, connection)
8395

84-
options: Dict[str, Any] = {
96+
options: dict[str, Any] = {
8597
"ansible_become": {
8698
"ini": {
8799
"section": "privilege_escalation",
@@ -126,7 +138,9 @@ def get_ansible_host(config, inventory, host, ssh_config=None, ssh_identity_file
126138
},
127139
}
128140

129-
def get_config(name, default=None):
141+
def get_config(
142+
name: str, default: Union[None, bool, str] = None
143+
) -> Union[None, bool, str]:
130144
value = default
131145
option = options.get(name, {})
132146

@@ -144,11 +158,12 @@ def get_config(name, default=None):
144158
return value
145159

146160
testinfra_host = get_config("ansible_host", host)
161+
assert isinstance(testinfra_host, str), testinfra_host
147162
user = get_config("ansible_user")
148163
password = get_config("ansible_ssh_pass")
149164
port = get_config("ansible_port")
150165

151-
kwargs: Dict[str, Union[str, bool]] = {}
166+
kwargs: dict[str, Union[None, str, bool]] = {}
152167
if get_config("ansible_become", False):
153168
kwargs["sudo"] = True
154169
kwargs["sudo_user"] = get_config("ansible_become_user")
@@ -165,8 +180,8 @@ def get_config(name, default=None):
165180
kwargs["ssh_extra_args"] = " ".join(
166181
[
167182
config.get("ssh_connection", "ssh_args", fallback=""),
168-
get_config("ansible_ssh_common_args", ""),
169-
get_config("ansible_ssh_extra_args", ""),
183+
get_config("ansible_ssh_common_args", ""), # type: ignore[list-item]
184+
get_config("ansible_ssh_extra_args", ""), # type: ignore[list-item]
170185
]
171186
).strip()
172187

@@ -191,20 +206,20 @@ def get_config(name, default=None):
191206
return testinfra.get_host(spec, **kwargs)
192207

193208

194-
def itergroup(inventory, group):
209+
def itergroup(inventory: Inventory, group: str) -> Iterator[str]:
195210
for host in inventory.get(group, {}).get("hosts", []):
196211
yield host
197212
for g in inventory.get(group, {}).get("children", []):
198213
for host in itergroup(inventory, g):
199214
yield host
200215

201216

202-
def is_empty_inventory(inventory):
217+
def is_empty_inventory(inventory: Inventory) -> bool:
203218
return not any(True for _ in itergroup(inventory, "all"))
204219

205220

206221
class AnsibleRunner:
207-
_runners: Dict[Optional[str], "AnsibleRunner"] = {}
222+
_runners: dict[Optional[str], "AnsibleRunner"] = {}
208223
_known_options = {
209224
# Boolean arguments.
210225
"become": {
@@ -243,12 +258,12 @@ class AnsibleRunner:
243258
},
244259
}
245260

246-
def __init__(self, inventory_file=None):
261+
def __init__(self, inventory_file: Optional[str] = None):
247262
self.inventory_file = inventory_file
248-
self._host_cache = {}
263+
self._host_cache: dict[str, Optional[testinfra.host.Host]] = {}
249264
super().__init__()
250265

251-
def get_hosts(self, pattern="all"):
266+
def get_hosts(self, pattern: str = "all") -> list[str]:
252267
inventory = self.inventory
253268
result = set()
254269
if is_empty_inventory(inventory):
@@ -271,18 +286,18 @@ def get_hosts(self, pattern="all"):
271286
return sorted(result)
272287

273288
@functools.cached_property
274-
def inventory(self):
289+
def inventory(self) -> Inventory:
275290
return get_ansible_inventory(self.ansible_config, self.inventory_file)
276291

277292
@functools.cached_property
278-
def ansible_config(self):
293+
def ansible_config(self) -> configparser.ConfigParser:
279294
return get_ansible_config()
280295

281-
def get_variables(self, host):
296+
def get_variables(self, host: str) -> dict[str, Any]:
282297
inventory = self.inventory
283298
# inventory_hostname, group_names and groups are for backward
284299
# compatibility with testinfra 2.X
285-
hostvars = inventory["_meta"].get("hostvars", {}).get(host, {})
300+
hostvars: dict[str, Any] = inventory["_meta"].get("hostvars", {}).get(host, {})
286301
hostvars.setdefault("inventory_hostname", host)
287302
group_names = []
288303
groups = {}
@@ -296,7 +311,7 @@ def get_variables(self, host):
296311
hostvars.setdefault("groups", groups)
297312
return hostvars
298313

299-
def get_host(self, host, **kwargs):
314+
def get_host(self, host: str, **kwargs: Any) -> Optional[testinfra.host.Host]:
300315
try:
301316
return self._host_cache[host]
302317
except KeyError:
@@ -305,14 +320,14 @@ def get_host(self, host, **kwargs):
305320
)
306321
return self._host_cache[host]
307322

308-
def options_to_cli(self, options):
323+
def options_to_cli(self, options: dict[str, Any]) -> tuple[str, list[str]]:
309324
verbose = options.pop("verbose", 0)
310325

311326
args = {"become": False, "check": True}
312327
args.update(options)
313328

314-
cli: List[str] = []
315-
cli_args: List[str] = []
329+
cli: list[str] = []
330+
cli_args: list[str] = []
316331
if verbose:
317332
cli.append("-" + "v" * verbose)
318333
for arg_name, value in args.items():
@@ -334,7 +349,14 @@ def options_to_cli(self, options):
334349
raise TypeError("Unsupported argument type '{}'.".format(opt_type))
335350
return " ".join(cli), cli_args
336351

337-
def run_module(self, host, module_name, module_args, get_encoding=None, **options):
352+
def run_module(
353+
self,
354+
host: str,
355+
module_name: str,
356+
module_args: str,
357+
get_encoding: Optional[Callable[[], str]] = None,
358+
**options: Any,
359+
) -> Any:
338360
cmd, args = "ansible --tree %s", []
339361
if self.inventory_file:
340362
cmd += " -i %s"
@@ -375,7 +397,7 @@ def run_module(self, host, module_name, module_args, get_encoding=None, **option
375397
return json.load(f)
376398

377399
@classmethod
378-
def get_runner(cls, inventory):
400+
def get_runner(cls, inventory: Optional[str]) -> "AnsibleRunner":
379401
try:
380402
return cls._runners[inventory]
381403
except KeyError:

0 commit comments

Comments
 (0)