Skip to content

Commit 1fda830

Browse files
saygoxLee-W
authored andcommitted
feat(wrap_stdio): separate wrap stdio module
1 parent e224d8e commit 1fda830

File tree

6 files changed

+156
-93
lines changed

6 files changed

+156
-93
lines changed

commitizen/commands/commit.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@
22

33
import contextlib
44
import os
5-
import selectors
65
import shutil
76
import subprocess
8-
import sys
97
import tempfile
10-
from asyncio import DefaultEventLoopPolicy, get_event_loop_policy, set_event_loop_policy
11-
from io import IOBase
128

139
import questionary
1410

@@ -28,38 +24,7 @@
2824
NothingToCommitError,
2925
)
3026
from commitizen.git import smart_open
31-
32-
33-
class CZEventLoopPolicy(DefaultEventLoopPolicy): # type: ignore
34-
def get_event_loop(self):
35-
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
36-
return self._local._loop
37-
38-
39-
class WrapStdx:
40-
def __init__(self, stdx: IOBase):
41-
self._fileno = stdx.fileno()
42-
if sys.platform == "linux":
43-
if self._fileno == 0:
44-
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
45-
tty = open(fd, "wb+", buffering=0)
46-
else:
47-
tty = open("/dev/tty", "w") # type: ignore
48-
else:
49-
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
50-
if self._fileno == 0:
51-
tty = open(fd, "wb+", buffering=0)
52-
else:
53-
tty = open(fd, "rb+", buffering=0)
54-
self.tty = tty
55-
56-
def __getattr__(self, key):
57-
if key == "encoding" and (sys.platform != "linux" or self._fileno == 0):
58-
return "UTF-8"
59-
return getattr(self.tty, key)
60-
61-
def __del__(self):
62-
self.tty.close()
27+
from commitizen.wrap_stdio import unwrap_stdio, wrap_stdio
6328

6429

6530
class Commit:
@@ -143,14 +108,7 @@ def __call__(self):
143108

144109
commit_msg_file: str = self.arguments.get("commit_msg_file")
145110
if commit_msg_file:
146-
old_stdin = sys.stdin
147-
old_stdout = sys.stdout
148-
old_stderr = sys.stderr
149-
old_event_loop_policy = get_event_loop_policy()
150-
set_event_loop_policy(CZEventLoopPolicy())
151-
sys.stdin = WrapStdx(sys.stdin)
152-
sys.stdout = WrapStdx(sys.stdout)
153-
sys.stderr = WrapStdx(sys.stderr)
111+
wrap_stdio()
154112

155113
if git.is_staging_clean() and not (dry_run or allow_empty):
156114
raise NothingToCommitError("No files added to staging!")
@@ -174,13 +132,7 @@ def __call__(self):
174132
m = self.prompt_commit_questions()
175133

176134
if commit_msg_file:
177-
sys.stdin.close()
178-
sys.stdout.close()
179-
sys.stderr.close()
180-
set_event_loop_policy(old_event_loop_policy)
181-
sys.stdin = old_stdin
182-
sys.stdout = old_stdout
183-
sys.stderr = old_stderr
135+
unwrap_stdio()
184136

185137
if manual_edit:
186138
m = self.manual_edit(m)

commitizen/wrap_stdio.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import sys
2+
3+
if sys.platform == "win32": # pragma: no cover
4+
from .wrap_stdio_windows import _unwrap_stdio, _wrap_stdio
5+
elif sys.platform == "linux":
6+
from .wrap_stdio_linux import _unwrap_stdio, _wrap_stdio # pragma: no cover
7+
else:
8+
from .wrap_stdio_unix import _unwrap_stdio, _wrap_stdio # pragma: no cover
9+
10+
11+
def wrap_stdio():
12+
_wrap_stdio()
13+
return None
14+
15+
16+
def unwrap_stdio():
17+
_unwrap_stdio()
18+
return None

commitizen/wrap_stdio_linux.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
import sys
3+
from io import IOBase
4+
5+
6+
class WrapStdioLinux:
7+
def __init__(self, stdx: IOBase):
8+
self._fileno = stdx.fileno()
9+
if self._fileno == 0:
10+
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
11+
tty = open(fd, "wb+", buffering=0)
12+
else:
13+
tty = open("/dev/tty", "w") # type: ignore
14+
self.tty = tty
15+
16+
def __getattr__(self, key):
17+
if key == "encoding" and self._fileno == 0:
18+
return "UTF-8"
19+
return getattr(self.tty, key)
20+
21+
def __del__(self):
22+
self.tty.close()
23+
24+
25+
backup_stdin = None
26+
backup_stdout = None
27+
backup_stderr = None
28+
29+
30+
def _wrap_stdio():
31+
global backup_stdin
32+
backup_stdin = sys.stdin
33+
sys.stdin = WrapStdioLinux(sys.stdin)
34+
35+
global backup_stdout
36+
backup_stdout = sys.stdout
37+
sys.stdout = WrapStdioLinux(sys.stdout)
38+
39+
global backup_stderr
40+
backup_stderr = sys.stderr
41+
sys.stderr = WrapStdioLinux(sys.stderr)
42+
43+
44+
def _unwrap_stdio():
45+
global backup_stdin
46+
sys.stdin.close()
47+
sys.stdin = backup_stdin
48+
49+
global backup_stdout
50+
sys.stdout.close()
51+
sys.stdout = backup_stdout
52+
53+
global backup_stderr
54+
sys.stderr.close()
55+
sys.stderr = backup_stderr

commitizen/wrap_stdio_unix.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
import selectors
3+
import sys
4+
from asyncio import DefaultEventLoopPolicy, get_event_loop_policy, set_event_loop_policy
5+
from io import IOBase
6+
7+
8+
class CZEventLoopPolicy(DefaultEventLoopPolicy): # type: ignore
9+
def get_event_loop(self):
10+
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
11+
return self._local._loop
12+
13+
14+
class WrapStdioLinux:
15+
def __init__(self, stdx: IOBase):
16+
self._fileno = stdx.fileno()
17+
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
18+
if self._fileno == 0:
19+
tty = open(fd, "wb+", buffering=0)
20+
else:
21+
tty = open(fd, "rb+", buffering=0)
22+
self.tty = tty
23+
24+
def __getattr__(self, key):
25+
if key == "encoding":
26+
return "UTF-8"
27+
return getattr(self.tty, key)
28+
29+
def __del__(self):
30+
self.tty.close()
31+
32+
33+
backup_event_loop_policy = None
34+
backup_stdin = None
35+
backup_stdout = None
36+
backup_stderr = None
37+
38+
39+
def _wrap_stdio():
40+
global backup_event_loop_policy
41+
backup_event_loop_policy = get_event_loop_policy()
42+
set_event_loop_policy(CZEventLoopPolicy())
43+
44+
global backup_stdin
45+
backup_stdin = sys.stdin
46+
sys.stdin = WrapStdioLinux(sys.stdin)
47+
48+
global backup_stdout
49+
backup_stdout = sys.stdout
50+
sys.stdout = WrapStdioLinux(sys.stdout)
51+
52+
global backup_stderr
53+
backup_stdout = sys.stderr
54+
sys.stderr = WrapStdioLinux(sys.stderr)
55+
56+
57+
def _unwrap_stdio():
58+
global backup_event_loop_policy
59+
set_event_loop_policy(backup_event_loop_policy)
60+
61+
global backup_stdin
62+
sys.stdin.close()
63+
sys.stdin = backup_stdin
64+
65+
global backup_stdout
66+
sys.stdout.close()
67+
sys.stdout = backup_stdout
68+
69+
global backup_stderr
70+
sys.stderr.close()
71+
sys.stderr = backup_stderr

commitizen/wrap_stdio_windows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def _wrap_stdio():
2+
pass
3+
4+
5+
def _unwrap_stdio():
6+
pass

tests/commands/test_commit_command.py

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -543,8 +543,9 @@ def test_commit_from_pre_commit_msg_hook(config, mocker, capsys):
543543

544544
commit_mock = mocker.patch("commitizen.git.commit")
545545
commit_mock.return_value = cmd.Command("success", "", "", "", 0)
546-
mocker.patch("commitizen.commands.commit.WrapStdx")
547-
mocker.patch("os.open")
546+
547+
mocker.patch("commitizen.commands.commit.wrap_stdio")
548+
mocker.patch("commitizen.commands.commit.unwrap_stdio")
548549
reader_mock = mocker.mock_open(read_data="\n\n#test\n")
549550
mocker.patch("builtins.open", reader_mock, create=True)
550551

@@ -553,43 +554,3 @@ def test_commit_from_pre_commit_msg_hook(config, mocker, capsys):
553554
out, _ = capsys.readouterr()
554555
assert "Commit message is successful!" in out
555556
commit_mock.assert_not_called()
556-
557-
558-
def test_WrapStdx(mocker):
559-
mocker.patch("os.open")
560-
reader_mock = mocker.mock_open(read_data="data")
561-
mocker.patch("builtins.open", reader_mock, create=True)
562-
563-
stdin_mock_fileno = mocker.patch.object(sys.stdin, "fileno")
564-
stdin_mock_fileno.return_value = 0
565-
wrap_stdin = commands.commit.WrapStdx(sys.stdin)
566-
567-
assert wrap_stdin.encoding == "UTF-8"
568-
assert wrap_stdin.read() == "data"
569-
570-
writer_mock = mocker.mock_open(read_data="data")
571-
mocker.patch("builtins.open", writer_mock, create=True)
572-
stdout_mock_fileno = mocker.patch.object(sys.stdout, "fileno")
573-
stdout_mock_fileno.return_value = 1
574-
wrap_stout = commands.commit.WrapStdx(sys.stdout)
575-
wrap_stout.write("data")
576-
577-
if sys.platform == "linux":
578-
writer_mock.assert_called_once_with("/dev/tty", "w")
579-
else:
580-
pass
581-
writer_mock().write.assert_called_once_with("data")
582-
583-
writer_mock = mocker.mock_open(read_data="data")
584-
mocker.patch("builtins.open", writer_mock, create=True)
585-
stderr_mock_fileno = mocker.patch.object(sys.stdout, "fileno")
586-
stderr_mock_fileno.return_value = 2
587-
wrap_sterr = commands.commit.WrapStdx(sys.stderr)
588-
589-
wrap_sterr.write("data")
590-
591-
if sys.platform == "linux":
592-
writer_mock.assert_called_once_with("/dev/tty", "w")
593-
else:
594-
pass
595-
writer_mock().write.assert_called_once_with("data")

0 commit comments

Comments
 (0)