From 79229a92c3836ee70f238c3f8906abf91e4e46f6 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Fri, 8 Sep 2017 22:00:38 -0400 Subject: nslcd_server: Add a request size limit --- nslcd_proto/io.go | 23 +++++++++++++++++++---- nslcd_server/func_handlerequest.go.gen | 29 ++++++++++++++++++++--------- nslcd_systemd/misc_test.go | 3 +++ 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go index daced37..bf59282 100644 --- a/nslcd_proto/io.go +++ b/nslcd_proto/io.go @@ -76,7 +76,9 @@ func Write(fd io.Writer, data interface{}) (err error) { } // Read an object from a stream. Any errors returned are of type -// NslcdError. +// NslcdError. If the type assertion succeeds, then +// fd.(*io.LimitedReader).N is used to prevent an overly-large buffer +// from being allocated. func Read(fd io.Reader, data interface{}) (err error) { defer func() { if r := recover(); r != nil { @@ -156,6 +158,16 @@ func write(fd io.Writer, data interface{}) { } } +// Assert that we *will* read n bytes. If we know now that < n bytes +// will be available, then this will let us avoid an allocation. +func willread(fd io.Reader, n int64) { + if lfd, ok := fd.(*io.LimitedReader); ok { + if n > lfd.N { + npanic(NslcdError(io.EOF.Error())) + } + } +} + // Read an object from a stream. In the event of an error, this // function may panic(NslcdError)! Handle it! func read(fd io.Reader, data interface{}) { @@ -179,13 +191,15 @@ func read(fd io.Reader, data interface{}) { case *string: var len int32 read(fd, &len) - buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked + willread(fd, int64(len)) + buf := make([]byte, len) read(fd, &buf) *data = string(buf) case *[]string: var num int32 read(fd, &num) - *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked + willread(fd, int64(num * /* min size of a string is: */4)) + *data = make([]string, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } @@ -212,7 +226,8 @@ func read(fd io.Reader, data interface{}) { case *[]net.IP: var num int32 read(fd, &num) - *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked + willread(fd, int64(num * /* min size of an IP is: */net.IPv4len)) + *data = make([]net.IP, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen index af36e84..7c28e7c 100755 --- a/nslcd_server/func_handlerequest.go.gen +++ b/nslcd_server/func_handlerequest.go.gen @@ -26,6 +26,7 @@ package nslcd_server import ( "fmt" + "io" "os" "time" @@ -53,6 +54,10 @@ type Limits struct { // How long can we spend writing a response? WriteTimeout time.Duration + + // What is the maximum request length in bytes that we are + // willing to handle? + RequestMaxSize int64 } type Conn interface { @@ -92,13 +97,19 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) ( } } + var in io.Reader = conn + if limits.RequestMaxSize > 0 { + in = &io.LimitedReader{R: in, N: limits.RequestMaxSize} + } + out := conn + var version int32 - maybePanic(p.Read(conn, &version)) + maybePanic(p.Read(in, &version)) if version != p.NSLCD_VERSION { return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version)) } var action int32 - maybePanic(p.Read(conn, &action)) + maybePanic(p.Read(in, &action)) switch action { $( @@ -106,7 +117,7 @@ while read -r request; do cat <