Skip to content

Commit d31186e

Browse files
committed
feat(wrap_stdio): sepalate wrap stdio module
1 parent 51d11a0 commit d31186e

File tree

6 files changed

+156
-93
lines changed

6 files changed

+156
-93
lines changed

commitizen/commands/commit.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import contextlib
22
import os
3-
import selectors
4-
import sys
53
import tempfile
6-
from asyncio import DefaultEventLoopPolicy, get_event_loop_policy, set_event_loop_policy
7-
from io import IOBase
84

95
import questionary
106

@@ -20,38 +16,7 @@
2016
NotAGitProjectError,
2117
NothingToCommitError,
2218
)
23-
24-
25-
class CZEventLoopPolicy(DefaultEventLoopPolicy): # type: ignore
26-
def get_event_loop(self):
27-
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
28-
return self._local._loop
29-
30-
31-
class WrapStdx:
32-
def __init__(self, stdx: IOBase):
33-
self._fileno = stdx.fileno()
34-
if sys.platform == "linux":
35-
if self._fileno == 0:
36-
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
37-
tty = open(fd, "wb+", buffering=0)
38-
else:
39-
tty = open("/dev/tty", "w") # type: ignore
40-
else:
41-
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
42-
if self._fileno == 0:
43-
tty = open(fd, "wb+", buffering=0)
44-
else:
45-
tty = open(fd, "rb+", buffering=0)
46-
self.tty = tty
47-
48-
def __getattr__(self, key):
49-
if key == "encoding" and (sys.platform != "linux" or self._fileno == 0):
50-
return "UTF-8"
51-
return getattr(self.tty, key)
52-
53-
def __del__(self):
54-
self.tty.close()
19+
from commitizen.wrap_stdio import unwrap_stdio, wrap_stdio
5520

5621

5722
class Commit:
@@ -101,14 +66,7 @@ def __call__(self):
10166

10267
commit_msg_file: str = self.arguments.get("commit_msg_file")
10368
if commit_msg_file:
104-
old_stdin = sys.stdin
105-
old_stdout = sys.stdout
106-
old_stderr = sys.stderr
107-
old_event_loop_policy = get_event_loop_policy()
108-
set_event_loop_policy(CZEventLoopPolicy())
109-
sys.stdin = WrapStdx(sys.stdin)
110-
sys.stdout = WrapStdx(sys.stdout)
111-
sys.stderr = WrapStdx(sys.stderr)
69+
wrap_stdio()
11270

11371
if git.is_staging_clean() and not dry_run:
11472
raise NothingToCommitError("No files added to staging!")
@@ -121,13 +79,7 @@ def __call__(self):
12179
m = self.prompt_commit_questions()
12280

12381
if commit_msg_file:
124-
sys.stdin.close()
125-
sys.stdout.close()
126-
sys.stderr.close()
127-
set_event_loop_policy(old_event_loop_policy)
128-
sys.stdin = old_stdin
129-
sys.stdout = old_stdout
130-
sys.stderr = old_stderr
82+
unwrap_stdio()
13183

13284
out.info(f"\n{m}\n")
13385

commitizen/wrap_stdio.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import sys
2+
3+
if sys.platform == "win32": # pragma: no cover
4+
from .wrap_stdio_windows import _unwrap_stdio, _wrap_stdio
5+
elif sys.platform == "linux":
6+
from .wrap_stdio_linux import _unwrap_stdio, _wrap_stdio # pragma: no cover
7+
else:
8+
from .wrap_stdio_unix import _unwrap_stdio, _wrap_stdio # pragma: no cover
9+
10+
11+
def wrap_stdio():
12+
_wrap_stdio()
13+
return None
14+
15+
16+
def unwrap_stdio():
17+
_unwrap_stdio()
18+
return None

commitizen/wrap_stdio_linux.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
import sys
3+
from io import IOBase
4+
5+
6+
class WrapStdioLinux:
7+
def __init__(self, stdx: IOBase):
8+
self._fileno = stdx.fileno()
9+
if self._fileno == 0:
10+
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
11+
tty = open(fd, "wb+", buffering=0)
12+
else:
13+
tty = open("/dev/tty", "w") # type: ignore
14+
self.tty = tty
15+
16+
def __getattr__(self, key):
17+
if key == "encoding" and self._fileno == 0:
18+
return "UTF-8"
19+
return getattr(self.tty, key)
20+
21+
def __del__(self):
22+
self.tty.close()
23+
24+
25+
backup_stdin = None
26+
backup_stdout = None
27+
backup_stderr = None
28+
29+
30+
def _wrap_stdio():
31+
global backup_stdin
32+
backup_stdin = sys.stdin
33+
sys.stdin = WrapStdioLinux(sys.stdin)
34+
35+
global backup_stdout
36+
backup_stdout = sys.stdout
37+
sys.stdout = WrapStdioLinux(sys.stdout)
38+
39+
global backup_stderr
40+
backup_stderr = sys.stderr
41+
sys.stderr = WrapStdioLinux(sys.stderr)
42+
43+
44+
def _unwrap_stdio():
45+
global backup_stdin
46+
sys.stdin.close()
47+
sys.stdin = backup_stdin
48+
49+
global backup_stdout
50+
sys.stdout.close()
51+
sys.stdout = backup_stdout
52+
53+
global backup_stderr
54+
sys.stderr.close()
55+
sys.stderr = backup_stderr

commitizen/wrap_stdio_unix.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
import selectors
3+
import sys
4+
from asyncio import DefaultEventLoopPolicy, get_event_loop_policy, set_event_loop_policy
5+
from io import IOBase
6+
7+
8+
class CZEventLoopPolicy(DefaultEventLoopPolicy): # type: ignore
9+
def get_event_loop(self):
10+
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
11+
return self._local._loop
12+
13+
14+
class WrapStdioLinux:
15+
def __init__(self, stdx: IOBase):
16+
self._fileno = stdx.fileno()
17+
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
18+
if self._fileno == 0:
19+
tty = open(fd, "wb+", buffering=0)
20+
else:
21+
tty = open(fd, "rb+", buffering=0)
22+
self.tty = tty
23+
24+
def __getattr__(self, key):
25+
if key == "encoding":
26+
return "UTF-8"
27+
return getattr(self.tty, key)
28+
29+
def __del__(self):
30+
self.tty.close()
31+
32+
33+
backup_event_loop_policy = None
34+
backup_stdin = None
35+
backup_stdout = None
36+
backup_stderr = None
37+
38+
39+
def _wrap_stdio():
40+
global backup_event_loop_policy
41+
backup_event_loop_policy = get_event_loop_policy()
42+
set_event_loop_policy(CZEventLoopPolicy())
43+
44+
global backup_stdin
45+
backup_stdin = sys.stdin
46+
sys.stdin = WrapStdioLinux(sys.stdin)
47+
48+
global backup_stdout
49+
backup_stdout = sys.stdout
50+
sys.stdout = WrapStdioLinux(sys.stdout)
51+
52+
global backup_stderr
53+
backup_stdout = sys.stderr
54+
sys.stderr = WrapStdioLinux(sys.stderr)
55+
56+
57+
def _unwrap_stdio():
58+
global backup_event_loop_policy
59+
set_event_loop_policy(backup_event_loop_policy)
60+
61+
global backup_stdin
62+
sys.stdin.close()
63+
sys.stdin = backup_stdin
64+
65+
global backup_stdout
66+
sys.stdout.close()
67+
sys.stdout = backup_stdout
68+
69+
global backup_stderr
70+
sys.stderr.close()
71+
sys.stderr = backup_stderr

commitizen/wrap_stdio_windows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def _wrap_stdio():
2+
pass
3+
4+
5+
def _unwrap_stdio():
6+
pass

tests/commands/test_commit_command.py

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ def test_commit_from_pre_commit_msg_hook(config, mocker, capsys):
198198

199199
commit_mock = mocker.patch("commitizen.git.commit")
200200
commit_mock.return_value = cmd.Command("success", "", "", "", 0)
201-
mocker.patch("commitizen.commands.commit.WrapStdx")
202-
mocker.patch("os.open")
201+
202+
mocker.patch("commitizen.commands.commit.wrap_stdio")
203+
mocker.patch("commitizen.commands.commit.unwrap_stdio")
203204
reader_mock = mocker.mock_open(read_data="\n\n#test\n")
204205
mocker.patch("builtins.open", reader_mock, create=True)
205206

@@ -208,43 +209,3 @@ def test_commit_from_pre_commit_msg_hook(config, mocker, capsys):
208209
out, _ = capsys.readouterr()
209210
assert "Commit message is successful!" in out
210211
commit_mock.assert_not_called()
211-
212-
213-
def test_WrapStdx(mocker):
214-
mocker.patch("os.open")
215-
reader_mock = mocker.mock_open(read_data="data")
216-
mocker.patch("builtins.open", reader_mock, create=True)
217-
218-
stdin_mock_fileno = mocker.patch.object(sys.stdin, "fileno")
219-
stdin_mock_fileno.return_value = 0
220-
wrap_stdin = commands.commit.WrapStdx(sys.stdin)
221-
222-
assert wrap_stdin.encoding == "UTF-8"
223-
assert wrap_stdin.read() == "data"
224-
225-
writer_mock = mocker.mock_open(read_data="data")
226-
mocker.patch("builtins.open", writer_mock, create=True)
227-
stdout_mock_fileno = mocker.patch.object(sys.stdout, "fileno")
228-
stdout_mock_fileno.return_value = 1
229-
wrap_stout = commands.commit.WrapStdx(sys.stdout)
230-
wrap_stout.write("data")
231-
232-
if sys.platform == "linux":
233-
writer_mock.assert_called_once_with("/dev/tty", "w")
234-
else:
235-
pass
236-
writer_mock().write.assert_called_once_with("data")
237-
238-
writer_mock = mocker.mock_open(read_data="data")
239-
mocker.patch("builtins.open", writer_mock, create=True)
240-
stderr_mock_fileno = mocker.patch.object(sys.stdout, "fileno")
241-
stderr_mock_fileno.return_value = 2
242-
wrap_sterr = commands.commit.WrapStdx(sys.stderr)
243-
244-
wrap_sterr.write("data")
245-
246-
if sys.platform == "linux":
247-
writer_mock.assert_called_once_with("/dev/tty", "w")
248-
else:
249-
pass
250-
writer_mock().write.assert_called_once_with("data")

0 commit comments

Comments
 (0)