Skip to content

Commit 634e68c

Browse files
committed
Avoid manual polling
Rather than manually implementing non-blocking reads using polling, simply configure the underlying file descriptor to be non-blocking where necessary. Replace libc calls with the normal abstractions from the standard library. This makes the code less error prone and properly encodes type system invariants (such as `io::Read::read` requiring its receiver to be mutable) which were previously not preserved because the implementation used raw file descriptors.
1 parent b719bb4 commit 634e68c

File tree

1 file changed

+111
-173
lines changed

1 file changed

+111
-173
lines changed

src/unix_term.rs

Lines changed: 111 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use std::env;
2-
use std::convert::TryFrom as _;
32
use std::fmt::Display;
43
use std::fs;
5-
use std::io::{self, BufRead, BufReader};
4+
use std::io::{self, BufRead, BufReader, Read};
65
use std::mem;
76
use std::os::fd::{AsRawFd, RawFd};
87
use std::str;
@@ -95,6 +94,15 @@ impl Input<fs::File> {
9594
}
9695
}
9796

97+
impl<T: Read> Read for Input<T> {
98+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
99+
match self {
100+
Self::Stdin(s) => s.read(buf),
101+
Self::File(f) => f.read(buf),
102+
}
103+
}
104+
}
105+
98106
// NB: this is not a full BufRead implementation because io::Stdin does not implement BufRead.
99107
impl<T: BufRead> Input<T> {
100108
fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
@@ -145,202 +153,132 @@ pub(crate) fn read_secure() -> io::Result<String> {
145153
})
146154
}
147155

148-
fn poll_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
149-
let mut pollfd = libc::pollfd {
150-
fd,
151-
events: libc::POLLIN,
152-
revents: 0,
153-
};
154-
let ret = unsafe { libc::poll(&mut pollfd as *mut _, 1, timeout) };
155-
if ret < 0 {
156-
Err(io::Error::last_os_error())
157-
} else {
158-
Ok(pollfd.revents & libc::POLLIN != 0)
159-
}
160-
}
161-
162-
#[cfg(target_os = "macos")]
163-
fn select_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
164-
unsafe {
165-
let mut read_fd_set: libc::fd_set = mem::zeroed();
166-
167-
let mut timeout_val;
168-
let timeout = if timeout < 0 {
169-
std::ptr::null_mut()
170-
} else {
171-
timeout_val = libc::timeval {
172-
tv_sec: (timeout / 1000) as _,
173-
tv_usec: (timeout * 1000) as _,
174-
};
175-
&mut timeout_val
176-
};
177-
178-
libc::FD_ZERO(&mut read_fd_set);
179-
libc::FD_SET(fd, &mut read_fd_set);
180-
let ret = libc::select(
181-
fd + 1,
182-
&mut read_fd_set,
183-
std::ptr::null_mut(),
184-
std::ptr::null_mut(),
185-
timeout,
186-
);
187-
if ret < 0 {
188-
Err(io::Error::last_os_error())
189-
} else {
190-
Ok(libc::FD_ISSET(fd, &read_fd_set))
156+
fn read_single_char<T: Read + AsRawFd>(input: &mut T) -> io::Result<Option<char>> {
157+
let original = unsafe { libc::fcntl(input.as_raw_fd(), libc::F_GETFL) };
158+
c_result(|| unsafe {
159+
libc::fcntl(
160+
input.as_raw_fd(),
161+
libc::F_SETFL,
162+
original | libc::O_NONBLOCK,
163+
)
164+
})?;
165+
let mut buf = [0u8; 1];
166+
let result = read_bytes(input, &mut buf);
167+
c_result(|| unsafe { libc::fcntl(input.as_raw_fd(), libc::F_SETFL, original) })?;
168+
match result {
169+
Ok(()) => {
170+
let [byte] = buf;
171+
Ok(Some(byte as char))
191172
}
192-
}
193-
}
194-
195-
fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
196-
// There is a bug on macos that ttys cannot be polled, only select()
197-
// works. However given how problematic select is in general, we
198-
// normally want to use poll there too.
199-
#[cfg(target_os = "macos")]
200-
{
201-
if unsafe { libc::isatty(fd) == 1 } {
202-
return select_fd(fd, timeout);
173+
Err(err) => {
174+
if err.kind() == io::ErrorKind::WouldBlock {
175+
Ok(None)
176+
} else {
177+
Err(err)
178+
}
203179
}
204180
}
205-
poll_fd(fd, timeout)
206181
}
207182

208-
fn read_single_char(fd: RawFd) -> io::Result<Option<char>> {
209-
// timeout of zero means that it will not block
210-
let is_ready = select_or_poll_term_fd(fd, 0)?;
211-
212-
if is_ready {
213-
// if there is something to be read, take 1 byte from it
214-
let mut buf: [u8; 1] = [0];
215-
216-
read_bytes(fd, &mut buf)?;
217-
Ok(Some(buf[0] as char))
183+
fn read_bytes(input: &mut impl Read, buf: &mut [u8]) -> io::Result<()> {
184+
input.read_exact(buf)?;
185+
if buf.starts_with(b"\x03") {
186+
Err(io::Error::new(
187+
io::ErrorKind::Interrupted,
188+
"read interrupted",
189+
))
218190
} else {
219-
//there is nothing to be read
220-
Ok(None)
221-
}
222-
}
223-
224-
// Similar to libc::read. Read count bytes into slice buf from descriptor fd.
225-
// If successful, return the number of bytes read.
226-
// Will return an error if nothing was read, i.e when called at end of file.
227-
fn read_bytes(fd: RawFd, buf: &mut [u8]) -> io::Result<()> {
228-
let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, buf.len()) };
229-
match usize::try_from(read) {
230-
Err(std::num::TryFromIntError { .. }) => Err(io::Error::last_os_error()),
231-
Ok(read) => {
232-
if read != buf.len() {
233-
Err(io::Error::new(
234-
io::ErrorKind::UnexpectedEof,
235-
"Reached end of file",
236-
))
237-
} else if buf.starts_with(b"\x03") {
238-
Err(io::Error::new(
239-
io::ErrorKind::Interrupted,
240-
"read interrupted",
241-
))
242-
} else {
243-
Ok(())
244-
}
245-
}
191+
Ok(())
246192
}
247193
}
248194

249-
fn read_single_key_impl(fd: RawFd) -> Result<Key, io::Error> {
250-
loop {
251-
match read_single_char(fd)? {
252-
Some('\x1b') => {
253-
// Escape was read, keep reading in case we find a familiar key
254-
break if let Some(c1) = read_single_char(fd)? {
255-
if c1 == '[' {
256-
if let Some(c2) = read_single_char(fd)? {
257-
match c2 {
258-
'A' => Ok(Key::ArrowUp),
259-
'B' => Ok(Key::ArrowDown),
260-
'C' => Ok(Key::ArrowRight),
261-
'D' => Ok(Key::ArrowLeft),
262-
'H' => Ok(Key::Home),
263-
'F' => Ok(Key::End),
264-
'Z' => Ok(Key::BackTab),
265-
_ => {
266-
let c3 = read_single_char(fd)?;
267-
if let Some(c3) = c3 {
268-
if c3 == '~' {
269-
match c2 {
270-
'1' => Ok(Key::Home), // tmux
271-
'2' => Ok(Key::Insert),
272-
'3' => Ok(Key::Del),
273-
'4' => Ok(Key::End), // tmux
274-
'5' => Ok(Key::PageUp),
275-
'6' => Ok(Key::PageDown),
276-
'7' => Ok(Key::Home), // xrvt
277-
'8' => Ok(Key::End), // xrvt
278-
_ => Ok(Key::UnknownEscSeq(vec![c1, c2, c3])),
279-
}
280-
} else {
281-
Ok(Key::UnknownEscSeq(vec![c1, c2, c3]))
195+
fn read_single_key_impl<T: Read + AsRawFd>(fd: &mut T) -> Result<Key, io::Error> {
196+
// NB: this doesn't use `read_single_char` because we want a blocking read here.
197+
let mut buf = [0u8; 1];
198+
read_bytes(fd, &mut buf)?;
199+
let [byte] = buf;
200+
match byte {
201+
b'\x1b' => {
202+
// Escape was read, keep reading in case we find a familiar key
203+
if let Some(c1) = read_single_char(fd)? {
204+
if c1 == '[' {
205+
if let Some(c2) = read_single_char(fd)? {
206+
match c2 {
207+
'A' => Ok(Key::ArrowUp),
208+
'B' => Ok(Key::ArrowDown),
209+
'C' => Ok(Key::ArrowRight),
210+
'D' => Ok(Key::ArrowLeft),
211+
'H' => Ok(Key::Home),
212+
'F' => Ok(Key::End),
213+
'Z' => Ok(Key::BackTab),
214+
_ => {
215+
let c3 = read_single_char(fd)?;
216+
if let Some(c3) = c3 {
217+
if c3 == '~' {
218+
match c2 {
219+
'1' => Ok(Key::Home), // tmux
220+
'2' => Ok(Key::Insert),
221+
'3' => Ok(Key::Del),
222+
'4' => Ok(Key::End), // tmux
223+
'5' => Ok(Key::PageUp),
224+
'6' => Ok(Key::PageDown),
225+
'7' => Ok(Key::Home), // xrvt
226+
'8' => Ok(Key::End), // xrvt
227+
_ => Ok(Key::UnknownEscSeq(vec![c1, c2, c3])),
282228
}
283229
} else {
284-
// \x1b[ and 1 more char
285-
Ok(Key::UnknownEscSeq(vec![c1, c2]))
230+
Ok(Key::UnknownEscSeq(vec![c1, c2, c3]))
286231
}
232+
} else {
233+
// \x1b[ and 1 more char
234+
Ok(Key::UnknownEscSeq(vec![c1, c2]))
287235
}
288236
}
289-
} else {
290-
// \x1b[ and no more input
291-
Ok(Key::UnknownEscSeq(vec![c1]))
292237
}
293238
} else {
294-
// char after escape is not [
239+
// \x1b[ and no more input
295240
Ok(Key::UnknownEscSeq(vec![c1]))
296241
}
297242
} else {
298-
//nothing after escape
299-
Ok(Key::Escape)
300-
};
301-
}
302-
Some(c) => {
303-
let byte = c as u8;
304-
let mut buf: [u8; 4] = [byte, 0, 0, 0];
305-
306-
break if byte & 224u8 == 192u8 {
307-
// a two byte unicode character
308-
read_bytes(fd, &mut buf[1..][..1])?;
309-
Ok(key_from_utf8(&buf[..2]))
310-
} else if byte & 240u8 == 224u8 {
311-
// a three byte unicode character
312-
read_bytes(fd, &mut buf[1..][..2])?;
313-
Ok(key_from_utf8(&buf[..3]))
314-
} else if byte & 248u8 == 240u8 {
315-
// a four byte unicode character
316-
read_bytes(fd, &mut buf[1..][..3])?;
317-
Ok(key_from_utf8(&buf[..4]))
318-
} else {
319-
Ok(match c {
320-
'\n' | '\r' => Key::Enter,
321-
'\x7f' => Key::Backspace,
322-
'\t' => Key::Tab,
323-
'\x01' => Key::Home, // Control-A (home)
324-
'\x05' => Key::End, // Control-E (end)
325-
'\x08' => Key::Backspace, // Control-H (8) (Identical to '\b')
326-
_ => Key::Char(c),
327-
})
328-
};
329-
}
330-
None => {
331-
// there is no subsequent byte ready to be read, block and wait for input
332-
// negative timeout means that it will block indefinitely
333-
match select_or_poll_term_fd(fd, -1) {
334-
Ok(_) => continue,
335-
Err(_) => break Err(io::Error::last_os_error()),
243+
// char after escape is not [
244+
Ok(Key::UnknownEscSeq(vec![c1]))
336245
}
246+
} else {
247+
//nothing after escape
248+
Ok(Key::Escape)
249+
}
250+
}
251+
byte => {
252+
let mut buf: [u8; 4] = [byte, 0, 0, 0];
253+
if byte & 224u8 == 192u8 {
254+
// a two byte unicode character
255+
read_bytes(fd, &mut buf[1..][..1])?;
256+
Ok(key_from_utf8(&buf[..2]))
257+
} else if byte & 240u8 == 224u8 {
258+
// a three byte unicode character
259+
read_bytes(fd, &mut buf[1..][..2])?;
260+
Ok(key_from_utf8(&buf[..3]))
261+
} else if byte & 248u8 == 240u8 {
262+
// a four byte unicode character
263+
read_bytes(fd, &mut buf[1..][..3])?;
264+
Ok(key_from_utf8(&buf[..4]))
265+
} else {
266+
Ok(match byte as char {
267+
'\n' | '\r' => Key::Enter,
268+
'\x7f' => Key::Backspace,
269+
'\t' => Key::Tab,
270+
'\x01' => Key::Home, // Control-A (home)
271+
'\x05' => Key::End, // Control-E (end)
272+
'\x08' => Key::Backspace, // Control-H (8) (Identical to '\b')
273+
c => Key::Char(c),
274+
})
337275
}
338276
}
339277
}
340278
}
341279

342280
pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
343-
let input = Input::unbuffered()?;
281+
let mut input = Input::unbuffered()?;
344282

345283
let mut termios = core::mem::MaybeUninit::uninit();
346284
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
@@ -349,7 +287,7 @@ pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
349287
unsafe { libc::cfmakeraw(&mut termios) };
350288
termios.c_oflag = original.c_oflag;
351289
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?;
352-
let rv = read_single_key_impl(input.as_raw_fd());
290+
let rv = read_single_key_impl(&mut input);
353291
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?;
354292

355293
// if the user hit ^C we want to signal SIGINT to ourselves.

0 commit comments

Comments
 (0)