summaryrefslogtreecommitdiff
path: root/nslcd_proto
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 22:00:38 -0400
committerLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 22:00:38 -0400
commit79229a92c3836ee70f238c3f8906abf91e4e46f6 (patch)
tree1a7b0718bee989c41f70d1c203a4fc344e4ca659 /nslcd_proto
parentee701cc53db14144df5321e5861e5bcbde220193 (diff)
nslcd_server: Add a request size limit
Diffstat (limited to 'nslcd_proto')
-rw-r--r--nslcd_proto/io.go23
1 files changed, 19 insertions, 4 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go
index daced37..bf59282 100644
--- a/nslcd_proto/io.go
+++ b/nslcd_proto/io.go
@@ -76,7 +76,9 @@ func Write(fd io.Writer, data interface{}) (err error) {
}
// Read an object from a stream. Any errors returned are of type
-// NslcdError.
+// NslcdError. If the type assertion succeeds, then
+// fd.(*io.LimitedReader).N is used to prevent an overly-large buffer
+// from being allocated.
func Read(fd io.Reader, data interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
@@ -156,6 +158,16 @@ func write(fd io.Writer, data interface{}) {
}
}
+// Assert that we *will* read n bytes. If we know now that < n bytes
+// will be available, then this will let us avoid an allocation.
+func willread(fd io.Reader, n int64) {
+ if lfd, ok := fd.(*io.LimitedReader); ok {
+ if n > lfd.N {
+ npanic(NslcdError(io.EOF.Error()))
+ }
+ }
+}
+
// Read an object from a stream. In the event of an error, this
// function may panic(NslcdError)! Handle it!
func read(fd io.Reader, data interface{}) {
@@ -179,13 +191,15 @@ func read(fd io.Reader, data interface{}) {
case *string:
var len int32
read(fd, &len)
- buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked
+ willread(fd, int64(len))
+ buf := make([]byte, len)
read(fd, &buf)
*data = string(buf)
case *[]string:
var num int32
read(fd, &num)
- *data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked
+ willread(fd, int64(num * /* min size of a string is: */4))
+ *data = make([]string, num)
for i := 0; i < int(num); i++ {
read(fd, &((*data)[i]))
}
@@ -212,7 +226,8 @@ func read(fd io.Reader, data interface{}) {
case *[]net.IP:
var num int32
read(fd, &num)
- *data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked
+ willread(fd, int64(num * /* min size of an IP is: */net.IPv4len))
+ *data = make([]net.IP, num)
for i := 0; i < int(num); i++ {
read(fd, &((*data)[i]))
}