Skip to content

Commit f9b2f6f

Browse files
committed
Type testinfra.host
1 parent 297d3b5 commit f9b2f6f

File tree

5 files changed

+61
-40
lines changed

5 files changed

+61
-40
lines changed

testinfra/backend/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
import importlib
1414
import os
1515
import urllib.parse
16+
from typing import TYPE_CHECKING, Any, Iterable
17+
18+
if TYPE_CHECKING:
19+
import testinfra.backend.base
1620

1721
BACKENDS = {
1822
"local": "testinfra.backend.local.LocalBackend",
@@ -31,13 +35,13 @@
3135
}
3236

3337

34-
def get_backend_class(connection):
38+
def get_backend_class(connection: str) -> type["testinfra.backend.base.BaseBackend"]:
3539
try:
3640
classpath = BACKENDS[connection]
3741
except KeyError:
3842
raise RuntimeError("Unknown connection type '{}'".format(connection))
3943
module, name = classpath.rsplit(".", 1)
40-
return getattr(importlib.import_module(module), name)
44+
return getattr(importlib.import_module(module), name) # type: ignore[no-any-return]
4145

4246

4347
def parse_hostspec(hostspec):
@@ -75,7 +79,7 @@ def parse_hostspec(hostspec):
7579
return host, kw
7680

7781

78-
def get_backend(hostspec, **kwargs):
82+
def get_backend(hostspec: str, **kwargs: Any) -> "testinfra.backend.base.BaseBackend":
7983
host, kw = parse_hostspec(hostspec)
8084
for k, v in kwargs.items():
8185
kw.setdefault(k, v)
@@ -86,7 +90,9 @@ def get_backend(hostspec, **kwargs):
8690
return klass(host, **kw)
8791

8892

89-
def get_backends(hosts, **kwargs):
93+
def get_backends(
94+
hosts: Iterable[str], **kwargs: Any
95+
) -> list["testinfra.backend.base.BaseBackend"]:
9096
backends = {}
9197
for hostspec in hosts:
9298
host, kw = parse_hostspec(hostspec)

testinfra/backend/base.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
import shlex
1818
import subprocess
1919
import urllib.parse
20-
from typing import Any, Optional
20+
from typing import TYPE_CHECKING, Any, Optional
21+
22+
if TYPE_CHECKING:
23+
import testinfra.host
2124

2225
logger = logging.getLogger("testinfra")
2326

@@ -117,29 +120,29 @@ def __init__(
117120
self,
118121
hostname: str,
119122
sudo: bool = False,
120-
sudo_user: Optional[bool] = None,
123+
sudo_user: Optional[str] = None,
121124
*args: Any,
122125
**kwargs: Any,
123126
):
124-
self._encoding = None
125-
self._host = None
127+
self._encoding: Optional[str] = None
128+
self._host: Optional["testinfra.host.Host"] = None
126129
self.hostname = hostname
127130
self.sudo = sudo
128131
self.sudo_user = sudo_user
129132
super().__init__()
130133

131-
def set_host(self, host):
134+
def set_host(self, host: "testinfra.host.Host") -> None:
132135
self._host = host
133136

134137
@classmethod
135-
def get_connection_type(cls):
138+
def get_connection_type(cls) -> str:
136139
"""Return the connection backend used as string.
137140
138141
Can be local, paramiko, ssh, docker, salt or ansible
139142
"""
140143
return cls.NAME
141144

142-
def get_hostname(self):
145+
def get_hostname(self) -> str:
143146
"""Return the hostname (for testinfra) of the remote or local system
144147
145148
@@ -166,11 +169,11 @@ def test(TestinfraBackend):
166169
"""
167170
return self.hostname
168171

169-
def get_pytest_id(self):
172+
def get_pytest_id(self) -> str:
170173
return self.get_connection_type() + "://" + self.get_hostname()
171174

172175
@classmethod
173-
def get_hosts(cls, host, **kwargs):
176+
def get_hosts(cls, host: str, **kwargs: Any) -> list[str]:
174177
if host is None:
175178
raise RuntimeError(
176179
"One or more hosts is required with the {} backend".format(
@@ -180,41 +183,41 @@ def get_hosts(cls, host, **kwargs):
180183
return [host]
181184

182185
@staticmethod
183-
def quote(command, *args):
186+
def quote(command: str, *args: str) -> str:
184187
if args:
185188
return command % tuple(shlex.quote(a) for a in args) # noqa: S001
186189
return command
187190

188-
def get_sudo_command(self, command, sudo_user):
191+
def get_sudo_command(self, command: str, sudo_user: Optional[str]) -> str:
189192
if sudo_user is None:
190193
return self.quote("sudo /bin/sh -c %s", command)
191194
return self.quote("sudo -u %s /bin/sh -c %s", sudo_user, command)
192195

193-
def get_command(self, command, *args):
196+
def get_command(self, command: str, *args: str) -> str:
194197
command = self.quote(command, *args)
195198
if self.sudo:
196199
command = self.get_sudo_command(command, self.sudo_user)
197200
return command
198201

199-
def run(self, command, *args, **kwargs):
202+
def run(self, command: str, *args: str, **kwargs: Any) -> CommandResult:
200203
raise NotImplementedError
201204

202-
def run_local(self, command, *args):
205+
def run_local(self, command: str, *args: str) -> CommandResult:
203206
command = self.quote(command, *args)
204-
command = self.encode(command)
207+
cmd = self.encode(command)
205208
p = subprocess.Popen(
206-
command,
209+
cmd,
207210
shell=True,
208211
stdin=subprocess.PIPE,
209212
stdout=subprocess.PIPE,
210213
stderr=subprocess.PIPE,
211214
)
212215
stdout, stderr = p.communicate()
213-
result = self.result(p.returncode, command, stdout, stderr)
216+
result = self.result(p.returncode, cmd, stdout, stderr)
214217
return result
215218

216219
@staticmethod
217-
def parse_hostspec(hostspec):
220+
def parse_hostspec(hostspec: str) -> HostSpec:
218221
name = hostspec
219222
port = None
220223
user = None
@@ -246,14 +249,14 @@ def parse_hostspec(hostspec):
246249
return HostSpec(name, port, user, password)
247250

248251
@staticmethod
249-
def parse_containerspec(containerspec):
252+
def parse_containerspec(containerspec: str) -> tuple[str, Optional[str]]:
250253
name = containerspec
251254
user = None
252255
if "@" in name:
253256
user, name = name.split("@", 1)
254257
return name, user
255258

256-
def get_encoding(self):
259+
def get_encoding(self) -> str:
257260
encoding = None
258261
for python in ("python3", "python"):
259262
cmd = self.run(
@@ -274,7 +277,7 @@ def get_encoding(self):
274277
return encoding
275278

276279
@property
277-
def encoding(self):
280+
def encoding(self) -> str:
278281
if self._encoding is None:
279282
self._encoding = self.get_encoding()
280283
return self._encoding

testinfra/host.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,44 @@
1212

1313
import functools
1414
import os
15+
from collections.abc import Iterable
1516
from typing import Any
1617

1718
import testinfra.backend
1819
import testinfra.backend.base
1920
import testinfra.modules
21+
import testinfra.modules.base
2022

2123

2224
class Host:
2325
_host_cache: dict[tuple[str, frozenset[tuple[str, Any]]], "Host"] = {}
24-
_hosts_cache = {} # type: ignore[var-annotated]
26+
_hosts_cache: dict[
27+
tuple[frozenset[str], frozenset[tuple[str, Any]]], list["Host"]
28+
] = {}
2529

26-
def __init__(self, backend):
30+
def __init__(self, backend: testinfra.backend.base.BaseBackend):
2731
self.backend = backend
2832
super().__init__()
2933

30-
def __repr__(self):
34+
def __repr__(self) -> str:
3135
return "<testinfra.host.Host {}>".format(self.backend.get_pytest_id())
3236

3337
@functools.cached_property
34-
def has_command_v(self):
38+
def has_command_v(self) -> bool:
3539
"""Return True if `command -v` is available"""
3640
return self.run("command -v command").rc == 0
3741

38-
def exists(self, command):
42+
def exists(self, command: str) -> bool:
3943
"""Return True if given command exist in $PATH"""
4044
if self.has_command_v:
4145
out = self.run("command -v %s", command)
4246
else:
4347
out = self.run_expect([0, 1], "which %s", command)
4448
return out.rc == 0
4549

46-
def find_command(self, command, extrapaths=("/sbin", "/usr/sbin")):
50+
def find_command(
51+
self, command: str, extrapaths: Iterable[str] = ("/sbin", "/usr/sbin")
52+
) -> str:
4753
"""Return path of given command
4854
4955
raise ValueError if command cannot be found
@@ -89,7 +95,7 @@ def run(
8995
'ls: cannot access /;echo inject: No such file or directory\\n'),
9096
command="ls -l '/;echo inject'")
9197
"""
92-
return self.backend.run(command, *args, **kwargs) # type: ignore[no-any-return]
98+
return self.backend.run(command, *args, **kwargs)
9399

94100
def run_expect(
95101
self, expected: list[int], command: str, *args: str, **kwargs: Any
@@ -104,7 +110,9 @@ def run_expect(
104110
assert out.rc in expected, "Unexpected exit code {} for {}".format(out.rc, out)
105111
return out
106112

107-
def run_test(self, command, *args, **kwargs):
113+
def run_test(
114+
self, command: str, *args: str, **kwargs: Any
115+
) -> testinfra.backend.base.CommandResult:
108116
"""Run command and check it return an exit status of 0 or 1
109117
110118
:raises: AssertionError
@@ -122,7 +130,7 @@ def check_output(self, command: str, *args: str, **kwargs: Any) -> str:
122130
assert out.rc == 0, "Unexpected exit code {} for {}".format(out.rc, out)
123131
return out.stdout.rstrip("\r\n")
124132

125-
def __getattr__(self, name):
133+
def __getattr__(self, name: str) -> type[testinfra.modules.base.Module]:
126134
if name in testinfra.modules.modules:
127135
module_class = testinfra.modules.get_module_class(name)
128136
obj = module_class.get_module(self)
@@ -157,7 +165,7 @@ def get_host(cls, hostspec: str, **kwargs: Any) -> "Host":
157165
return cache[key]
158166

159167
@classmethod
160-
def get_hosts(cls, hosts, **kwargs):
168+
def get_hosts(cls, hosts: Iterable[str], **kwargs: Any) -> list["Host"]:
161169
key = (frozenset(hosts), frozenset(kwargs.items()))
162170
cache = cls._hosts_cache
163171
if key not in cache:

testinfra/modules/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# limitations under the License.
1212

1313
import importlib
14+
from typing import TYPE_CHECKING
15+
16+
if TYPE_CHECKING:
17+
import testinfra.modules.base
1418

1519
modules = {
1620
"addr": "addr:Addr",
@@ -42,8 +46,8 @@
4246
}
4347

4448

45-
def get_module_class(name):
49+
def get_module_class(name: str) -> type["testinfra.modules.base.Module"]:
4650
modname, classname = modules[name].split(":")
4751
modname = ".".join([__name__, modname])
4852
module = importlib.import_module(modname)
49-
return getattr(module, classname)
53+
return getattr(module, classname) # type: ignore[no-any-return]

testinfra/modules/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
import typing
12+
from typing import TYPE_CHECKING
1313

1414

1515
class Module:
16-
if typing.TYPE_CHECKING:
16+
if TYPE_CHECKING:
1717
import testinfra.host
1818

1919
_host: testinfra.host.Host
2020

2121
@classmethod
22-
def get_module(cls, _host):
22+
def get_module(cls, _host: "testinfra.host.Host") -> type["Module"]:
2323
klass = cls.get_module_class(_host)
2424
return type(
2525
klass.__name__,

0 commit comments

Comments
 (0)