summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 22:00:38 -0400
committerLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 22:00:38 -0400
commit79229a92c3836ee70f238c3f8906abf91e4e46f6 (patch)
tree1a7b0718bee989c41f70d1c203a4fc344e4ca659
parentee701cc53db14144df5321e5861e5bcbde220193 (diff)
nslcd_server: Add a request size limit
-rw-r--r--nslcd_proto/io.go23
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen29
-rw-r--r--nslcd_systemd/misc_test.go3
3 files changed, 42 insertions, 13 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go
index daced37..bf59282 100644
--- a/nslcd_proto/io.go
+++ b/nslcd_proto/io.go
@@ -76,7 +76,9 @@ func Write(fd io.Writer, data interface{}) (err error) {
}
// Read an object from a stream. Any errors returned are of type
-// NslcdError.
+// NslcdError. If the type assertion succeeds, then
+// fd.(*io.LimitedReader).N is used to prevent an overly-large buffer
+// from being allocated.
func Read(fd io.Reader, data interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
@@ -156,6 +158,16 @@ func write(fd io.Writer, data interface{}) {
}
}
+// Assert that we *will* read n bytes. If we know now that < n bytes
+// will be available, then this will let us avoid an allocation.
+func willread(fd io.Reader, n int64) {
+ if lfd, ok := fd.(*io.LimitedReader); ok {
+ if n > lfd.N {
+ npanic(NslcdError(io.EOF.Error()))
+ }
+ }
+}
+
// Read an object from a stream. In the event of an error, this
// function may panic(NslcdError)! Handle it!
func read(fd io.Reader, data interface{}) {
@@ -179,13 +191,15 @@ func read(fd io.Reader, data interface{}) {
case *string:
var len int32
read(fd, &len)
- buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked
+ willread(fd, int64(len))
+ buf := make([]byte, len)
read(fd, &buf)
*data = string(buf)
case *[]string:
var num int32
read(fd, &num)
- *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked
+ willread(fd, int64(num * /* min size of a string is: */4))
+ *data = make([]string, num)
for i := 0; i < int(num); i++ {
read(fd, &((*data)[i]))
}
@@ -212,7 +226,8 @@ func read(fd io.Reader, data interface{}) {
case *[]net.IP:
var num int32
read(fd, &num)
- *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked
+ willread(fd, int64(num * /* min size of an IP is: */net.IPv4len))
+ *data = make([]net.IP, num)
for i := 0; i < int(num); i++ {
read(fd, &((*data)[i]))
}
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index af36e84..7c28e7c 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -26,6 +26,7 @@ package nslcd_server
import (
"fmt"
+ "io"
"os"
"time"
@@ -53,6 +54,10 @@ type Limits struct {
// How long can we spend writing a response?
WriteTimeout time.Duration
+
+ // What is the maximum request length in bytes that we are
+ // willing to handle?
+ RequestMaxSize int64
}
type Conn interface {
@@ -92,13 +97,19 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (
}
}
+ var in io.Reader = conn
+ if limits.RequestMaxSize > 0 {
+ in = &io.LimitedReader{R: in, N: limits.RequestMaxSize}
+ }
+ out := conn
+
var version int32
- maybePanic(p.Read(conn, &version))
+ maybePanic(p.Read(in, &version))
if version != p.NSLCD_VERSION {
return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version))
}
var action int32
- maybePanic(p.Read(conn, &action))
+ maybePanic(p.Read(in, &action))
switch action {
$(
@@ -106,7 +117,7 @@ while read -r request; do
cat <<EOT
case p.NSLCD_ACTION_${request^^}:
var req p.Request_${request}
- maybePanic(p.Read(conn, &req))
+ maybePanic(p.Read(in, &req))
$(
case "$request" in
PAM_Authentication)
@@ -133,24 +144,24 @@ while read -r request; do
esac
)
if limits.WriteTimeout != 0 {
- err = conn.SetWriteDeadline(time.Now().Add(limits.WriteTimeout))
+ err = out.SetWriteDeadline(time.Now().Add(limits.WriteTimeout))
if err != nil {
return err
}
}
- maybePanic(p.Write(conn, p.NSLCD_VERSION))
- maybePanic(p.Write(conn, action))
+ maybePanic(p.Write(out, p.NSLCD_VERSION))
+ maybePanic(p.Write(out, action))
ch := backend.${request}(cred, req)
for result := range ch {
if err == nil {
- err = p.Write(conn, p.NSLCD_RESULT_BEGIN)
+ err = p.Write(out, p.NSLCD_RESULT_BEGIN)
}
if err == nil {
- err = p.Write(conn, result)
+ err = p.Write(out, result)
}
}
maybePanic(err)
- maybePanic(p.Write(conn, p.NSLCD_RESULT_END))
+ maybePanic(p.Write(out, p.NSLCD_RESULT_END))
return nil
EOT
done < "$requests"
diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go
index 0e6f2e9..be6d40c 100644
--- a/nslcd_systemd/misc_test.go
+++ b/nslcd_systemd/misc_test.go
@@ -293,13 +293,16 @@ func TestLargeRequest(t *testing.T) {
defer sdActivatedReset()
t.Run("large-request", func(t *testing.T) {
testWithTimeout(t, 2*time.Second, func(t *testing.T, s chan<- string) {
+ KiB := 1024
GiB := 1024*1024*1024
+
ctx := &testContext{T: t, tmpdir: tmpdir, status: s}
backend := &LockingBackend{}
limits := nslcd_server.Limits{
Timeout: 1 * time.Second,
+ RequestMaxSize: int64(1*KiB),
}
notifyHandler := func(dat []byte, oob []byte) error { return nil }