Skip to content

compression.go: prevent flate decompression bomb attacks #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions minecraft/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ type Conn struct {
log *slog.Logger
authEnabled bool

proto Protocol
acceptedProto []Protocol
pool packet.Pool
enc *packet.Encoder
dec *packet.Decoder
compression packet.Compression
readerLimits bool
proto Protocol
acceptedProto []Protocol
pool packet.Pool
enc *packet.Encoder
dec *packet.Decoder
compression packet.Compression
maxDecompressedLen int
readerLimits bool

disconnectOnUnknownPacket bool
disconnectOnInvalidPacket bool
Expand Down Expand Up @@ -721,7 +722,7 @@ func (conn *Conn) handleRequestNetworkSettings(pk *packet.RequestNetworkSettings
}
_ = conn.Flush()
conn.enc.EnableCompression(conn.compression)
conn.dec.EnableCompression()
conn.dec.EnableCompression(conn.maxDecompressedLen)
return nil
}

Expand All @@ -732,7 +733,7 @@ func (conn *Conn) handleNetworkSettings(pk *packet.NetworkSettings) error {
return fmt.Errorf("unknown compression algorithm %v", pk.CompressionAlgorithm)
}
conn.enc.EnableCompression(alg)
conn.dec.EnableCompression()
conn.dec.EnableCompression(conn.maxDecompressedLen)
conn.readyToLogin = true
return nil
}
Expand Down
11 changes: 11 additions & 0 deletions minecraft/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/sandertv/gophertunnel/minecraft/protocol/packet"
"github.com/sandertv/gophertunnel/minecraft/resource"
"log/slog"
"math"
"net"
"slices"
"sync"
Expand Down Expand Up @@ -81,6 +82,10 @@ type ListenConfig struct {
// Login packet. The function is called with the header of the packet and its raw payload, the address
// from which the packet originated, and the destination address.
PacketFunc func(header packet.Header, payload []byte, src, dst net.Addr)

// MaxDecompressedLen is the maximum length of a decompressed packet to prevent potential exploits. If 0,
// the default value is 16MB (16 * 1024 * 1024). Setting this to a negative integer disables the limit.
MaxDecompressedLen int
}

// Listener implements a Minecraft listener on top of an unspecific net.Listener. It abstracts away the
Expand Down Expand Up @@ -120,6 +125,11 @@ func (cfg ListenConfig) Listen(network string, address string) (*Listener, error
if cfg.FlushRate == 0 {
cfg.FlushRate = time.Second / 20
}
if cfg.MaxDecompressedLen == 0 {
cfg.MaxDecompressedLen = 16 * 1024 * 1024 // 16MB
} else if cfg.MaxDecompressedLen < 0 {
cfg.MaxDecompressedLen = math.MaxInt
}

n, ok := networkByID(network, cfg.ErrorLog)
if !ok {
Expand Down Expand Up @@ -262,6 +272,7 @@ func (listener *Listener) createConn(netConn net.Conn) {
conn := newConn(netConn, listener.key, listener.cfg.ErrorLog, proto{}, listener.cfg.FlushRate, true)
conn.acceptedProto = append(listener.cfg.AcceptedProtocols, proto{})
conn.compression = listener.cfg.Compression
conn.maxDecompressedLen = listener.cfg.MaxDecompressedLen
conn.pool = conn.proto.Packets(true)

conn.packetFunc = listener.cfg.PacketFunc
Expand Down
20 changes: 15 additions & 5 deletions minecraft/protocol/packet/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Compression interface {
// Compress compresses the given data and returns the compressed data.
Compress(decompressed []byte) ([]byte, error)
// Decompress decompresses the given data and returns the decompressed data.
Decompress(compressed []byte) ([]byte, error)
Decompress(compressed []byte, limit int) ([]byte, error)
}

var (
Expand Down Expand Up @@ -67,7 +67,10 @@ func (nopCompression) Compress(decompressed []byte) ([]byte, error) {
}

// Decompress ...
func (nopCompression) Decompress(compressed []byte) ([]byte, error) {
func (nopCompression) Decompress(compressed []byte, limit int) ([]byte, error) {
if len(compressed) > limit {
return nil, fmt.Errorf("nop decompression: size %d exceeds limit %d", len(compressed), limit)
}
return compressed, nil
}

Expand Down Expand Up @@ -102,7 +105,7 @@ func (flateCompression) Compress(decompressed []byte) ([]byte, error) {
}

// Decompress ...
func (flateCompression) Decompress(compressed []byte) ([]byte, error) {
func (flateCompression) Decompress(compressed []byte, limit int) ([]byte, error) {
buf := bytes.NewReader(compressed)
c := flateDecompressPool.Get().(io.ReadCloser)
defer flateDecompressPool.Put(c)
Expand All @@ -114,7 +117,7 @@ func (flateCompression) Decompress(compressed []byte) ([]byte, error) {

// Guess an uncompressed size of 2*len(compressed).
decompressed := bytes.NewBuffer(make([]byte, 0, len(compressed)*2))
if _, err := io.Copy(decompressed, c); err != nil {
if _, err := io.Copy(decompressed, io.LimitReader(c, int64(limit))); err != nil {
return nil, fmt.Errorf("decompress flate: %w", err)
}
return decompressed.Bytes(), nil
Expand All @@ -135,10 +138,17 @@ func (snappyCompression) Compress(decompressed []byte) ([]byte, error) {
}

// Decompress ...
func (snappyCompression) Decompress(compressed []byte) ([]byte, error) {
func (snappyCompression) Decompress(compressed []byte, limit int) ([]byte, error) {
// Snappy writes a decoded data length prefix, so it can allocate the
// perfect size right away and only needs to allocate once. No need to pool
// byte slices here either.
decodedLen, err := snappy.DecodedLen(compressed)
if err != nil {
return nil, fmt.Errorf("snappy decoded length: %w", err)
}
if decodedLen > limit {
return nil, fmt.Errorf("snappy decoded size %d exceeds limit %d", decodedLen, limit)
}
decompressed, err := snappy.Decode(nil, compressed)
if err != nil {
return nil, fmt.Errorf("decompress snappy: %w", err)
Expand Down
10 changes: 6 additions & 4 deletions minecraft/protocol/packet/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ type Decoder struct {
// NewDecoder implements the packetReader interface.
pr packetReader

decompress bool
encrypt *encrypt
decompress bool
maxDecompressedLen int
encrypt *encrypt

checkPacketLimit bool
}
Expand Down Expand Up @@ -56,8 +57,9 @@ func (decoder *Decoder) EnableEncryption(keyBytes [32]byte) {
}

// EnableCompression enables compression for the Decoder.
func (decoder *Decoder) EnableCompression() {
func (decoder *Decoder) EnableCompression(maxDecompressedLen int) {
decoder.decompress = true
decoder.maxDecompressedLen = maxDecompressedLen
}

// DisableBatchPacketLimit disables the check that limits the number of packets allowed in a single packet
Expand Down Expand Up @@ -112,7 +114,7 @@ func (decoder *Decoder) Decode() (packets [][]byte, err error) {
if !ok {
return nil, fmt.Errorf("decompress batch: unknown compression algorithm %v", data[0])
}
data, err = compression.Decompress(data[1:])
data, err = compression.Decompress(data[1:], decoder.maxDecompressedLen)
if err != nil {
return nil, fmt.Errorf("decompress batch: %w", err)
}
Expand Down