Skip to content

Commit 29a4fbe

Browse files
authored
Merge pull request #910 from devlights/add-fd-passing-example
2 parents 55f83bf + 2ee9e7e commit 29a4fbe

File tree

5 files changed

+324
-0
lines changed

5 files changed

+324
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
tcp-server
2+
tcp-client
3+
uds-server
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# https://taskfile.dev
2+
3+
version: '3'
4+
5+
vars:
6+
GREETING: Hello, World!
7+
8+
tasks:
9+
default:
10+
cmds:
11+
- task: build
12+
- task: run
13+
build:
14+
cmds:
15+
- go build -o tcp-client tcpclient/main.go
16+
- go build -o tcp-server tcpserver/main.go
17+
- go build -o uds-server udsserver/main.go
18+
run:
19+
cmds:
20+
- ./uds-server &
21+
- sleep 1
22+
- ./tcp-server &
23+
- sleep 1
24+
- ./tcp-client
25+
- pkill tcp-server
26+
- pkill uds-server
27+
ignore_error: true
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
"io"
7+
"log"
8+
"net"
9+
"time"
10+
)
11+
12+
func main() {
13+
log.SetFlags(log.Lmicroseconds)
14+
15+
if err := run(); err != nil {
16+
panic(err)
17+
}
18+
}
19+
20+
func run() error {
21+
conn, err := net.Dial("tcp", ":8888")
22+
if err != nil {
23+
return err
24+
}
25+
defer func() {
26+
conn.Close()
27+
log.Println("[TCP-C] close")
28+
}()
29+
log.Println("[TCP-C] connect tcp-server")
30+
31+
buf := make([]byte, 5)
32+
n, err := conn.Read(buf)
33+
if err != nil {
34+
switch {
35+
case errors.Is(err, io.EOF):
36+
return nil
37+
default:
38+
return err
39+
}
40+
}
41+
42+
msg := buf[:n]
43+
log.Printf("[TCP-C] recv (%s)", msg)
44+
45+
msg = bytes.ToUpper(msg)
46+
_, err = conn.Write(msg)
47+
if err != nil {
48+
return err
49+
}
50+
log.Printf("[TCP-C] send (%s)", msg)
51+
52+
buf = make([]byte, 1)
53+
for {
54+
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
55+
56+
_, err = conn.Read(buf)
57+
if err != nil {
58+
if errors.Is(err, io.EOF) {
59+
log.Println("[TCP-C] disconnect")
60+
break
61+
}
62+
63+
return err
64+
}
65+
}
66+
67+
return nil
68+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"log"
6+
"net"
7+
"time"
8+
9+
"golang.org/x/sys/unix"
10+
)
11+
12+
func main() {
13+
log.SetFlags(log.Lmicroseconds)
14+
15+
if err := run(); err != nil {
16+
panic(err)
17+
}
18+
}
19+
20+
func run() error {
21+
var (
22+
udsConn net.Conn
23+
err error
24+
)
25+
for range 3 {
26+
udsConn, err = net.DialTimeout("unix", "@tcp_fd_passing", 1*time.Second)
27+
if err != nil {
28+
var netErr net.Error
29+
if errors.As(err, &netErr); netErr.Timeout() {
30+
continue
31+
}
32+
33+
return err
34+
}
35+
36+
break
37+
}
38+
defer udsConn.Close()
39+
log.Println("[TCP-S] connect uds-server")
40+
41+
ln, err := net.Listen("tcp", ":8888")
42+
if err != nil {
43+
return err
44+
}
45+
defer ln.Close()
46+
log.Println("[TCP-S] listen on :8888")
47+
48+
for {
49+
errCh := make(chan error, 1)
50+
func() {
51+
conn, err := ln.Accept()
52+
if err != nil {
53+
errCh <- err
54+
return
55+
}
56+
defer func() {
57+
conn.Close()
58+
log.Println("[TCP-S] close")
59+
}()
60+
log.Println("[TCP-S] accept client")
61+
62+
unixConn, _ := udsConn.(*net.UnixConn)
63+
tcpConn, _ := conn.(*net.TCPConn)
64+
file, _ := tcpConn.File()
65+
err = sendFD(unixConn, int(file.Fd()))
66+
if err != nil {
67+
errCh <- err
68+
return
69+
}
70+
log.Printf("[TCP-S] send fd=%d to uds-server", file.Fd())
71+
72+
errCh <- nil
73+
}()
74+
75+
err = <-errCh
76+
if err != nil {
77+
return err
78+
}
79+
}
80+
}
81+
82+
func sendFD(sock *net.UnixConn, fd int) error {
83+
var (
84+
dummy = make([]byte, 1)
85+
rights = unix.UnixRights(fd)
86+
err error
87+
)
88+
_, _, err = sock.WriteMsgUnix(dummy, rights, nil)
89+
if err != nil {
90+
return err
91+
}
92+
93+
return nil
94+
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"log"
8+
"net"
9+
"os"
10+
11+
"golang.org/x/sys/unix"
12+
)
13+
14+
func main() {
15+
log.SetFlags(log.Lmicroseconds)
16+
17+
if err := run(); err != nil {
18+
panic(err)
19+
}
20+
}
21+
22+
func run() error {
23+
ln, err := net.Listen("unix", "@tcp_fd_passing")
24+
if err != nil {
25+
return err
26+
}
27+
defer ln.Close()
28+
29+
log.Println("[UDS-S] server listening on")
30+
31+
udsConn, err := ln.Accept()
32+
if err != nil {
33+
return err
34+
}
35+
defer udsConn.Close()
36+
37+
log.Printf("[UDS-S] %v", udsConn.RemoteAddr())
38+
39+
unixConn, ok := udsConn.(*net.UnixConn)
40+
if !ok {
41+
return fmt.Errorf("not net.UnixConn")
42+
}
43+
44+
fd, err := recvFD(unixConn)
45+
if err != nil {
46+
return err
47+
}
48+
log.Printf("[UDS-S] recv fd=%d", fd)
49+
50+
file := os.NewFile(uintptr(fd), "client-socket")
51+
if file == nil {
52+
return fmt.Errorf("os.NewFile() failed")
53+
}
54+
defer file.Close()
55+
56+
conn, err := net.FileConn(file)
57+
if err != nil {
58+
return fmt.Errorf("net.FileConn() failed")
59+
}
60+
defer func() {
61+
conn.Close()
62+
log.Println("[UDS-S] close")
63+
}()
64+
65+
buf := []byte("hello")
66+
_, err = conn.Write(buf)
67+
if err != nil {
68+
return err
69+
}
70+
log.Printf("[UDS-S] send (%s)", buf)
71+
72+
buf = make([]byte, 5)
73+
n, err := conn.Read(buf)
74+
if err != nil {
75+
switch {
76+
case errors.Is(err, io.EOF):
77+
log.Println("[UDS-S] disconnect")
78+
default:
79+
return err
80+
}
81+
}
82+
log.Printf("[UDS-S] recv (%s)", buf[:n])
83+
84+
tcpConn, _ := conn.(*net.TCPConn)
85+
tcpConn.CloseWrite()
86+
log.Println("[UDS-S] shutdown(SHUT_WR)")
87+
88+
return nil
89+
}
90+
91+
func recvFD(sock *net.UnixConn) (int, error) {
92+
var (
93+
dummy = make([]byte, 1)
94+
oob = make([]byte, unix.CmsgSpace(4))
95+
flags int
96+
err error
97+
)
98+
_, _, flags, _, err = sock.ReadMsgUnix(dummy, oob)
99+
if err != nil {
100+
return -1, err
101+
}
102+
103+
if flags&unix.MSG_TRUNC != 0 {
104+
return -1, fmt.Errorf("control message is truncated")
105+
}
106+
107+
var (
108+
msgs []unix.SocketControlMessage
109+
)
110+
msgs, err = unix.ParseSocketControlMessage(oob)
111+
if err != nil {
112+
return -1, err
113+
}
114+
115+
if len(msgs) != 1 {
116+
return -1, fmt.Errorf("want: 1 control message; got: %d", len(msgs))
117+
}
118+
119+
var (
120+
fds []int
121+
)
122+
fds, err = unix.ParseUnixRights(&msgs[0])
123+
if err != nil {
124+
return -1, err
125+
}
126+
127+
if len(fds) != 1 {
128+
return -1, fmt.Errorf("want: 1 fd; got: %d", len(fds))
129+
}
130+
131+
return fds[0], nil
132+
}

0 commit comments

Comments
 (0)