// Copyright (C) 2015, 2017 Luke Shumaker // // This library is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public // License as published by the Free Software Foundation; either // version 2.1 of the License, or (at your option) any later version. // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA // 02110-1301 USA package nslcd_proto import ( "encoding/binary" "fmt" "io" "net" "reflect" "golang.org/x/sys/unix" ) // NslcdError represents a normal, expected error when dealing with // the nslcd protocol. Passing invalid data to a Write operation is // *not* an NslcdError, nor is passing a non-pointer to a Read // operation; those are programming errors, and result in a panic(). type NslcdError string func (o NslcdError) Error() string { return string(o) } func npanic(err NslcdError) { panic(err) } // An nslcdObject is an object with a different network representation // than a naive structure. type nslcdObject interface { // May panic(interface{}) if given invalid data. // // May panic(NslcdError) if encountering a network error. nslcdWrite(fd io.Writer) } // An nslcdObjectPtr is a pointer to an object with a different // network representation than a naive structure. type nslcdObjectPtr interface { // May panic(NslcdError) if encountering a network error. nslcdRead(fd io.Reader) } // Write an object to a stream. Any errors returned are of type // NslcdError. func Write(fd io.Writer, data interface{}) (err error) { defer func() { if r := recover(); r != nil { switch r := r.(type) { case NslcdError: err = r default: panic(r) } } }() write(fd, data) return err } // Read an object from a stream. Any errors returned are of type // NslcdError. If the type assertion succeeds, then // fd.(*io.LimitedReader).N is used to prevent an overly-large buffer // from being allocated. func Read(fd io.Reader, data interface{}) (err error) { defer func() { if r := recover(); r != nil { switch r := r.(type) { case NslcdError: err = r default: panic(r) } } }() read(fd, data) return err } // Write an object to a stream. In the event of an error, this // function may panic(NslcdError)! Handle it! func write(fd io.Writer, data interface{}) { switch data := data.(type) { // basic data types case nslcdObject: data.nslcdWrite(fd) case []byte: if len(data) > 0 { _, err := fd.Write(data) if err != nil { npanic(NslcdError(err.Error())) } } case int32: err := binary.Write(fd, binary.BigEndian, data) if err != nil { npanic(NslcdError(err.Error())) } // composite datatypes case string: write(fd, int32(len(data))) write(fd, []byte(data)) case []string: write(fd, int32(len(data))) for _, item := range data { write(fd, item) } case net.IP: var af int32 = -1 switch len(data) { case net.IPv4len: af = unix.AF_INET case net.IPv6len: af = unix.AF_INET6 } var bytes []byte if af < 0 { bytes = make([]byte, 0) } else { bytes = data } write(fd, af) write(fd, int32(len(bytes))) write(fd, bytes) case []net.IP: write(fd, int32(len(data))) for _, item := range data { write(fd, item) } default: v := reflect.ValueOf(data) switch v.Kind() { case reflect.Struct: for i, n := 0, v.NumField(); i < n; i++ { write(fd, v.Field(i).Interface()) } default: panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data)) } } } // Assert that we *will* read n bytes. If we know now that < n bytes // will be available, then this will let us avoid an allocation. func willread(fd io.Reader, n int64) { if lfd, ok := fd.(*io.LimitedReader); ok { if n > lfd.N { npanic(NslcdError(io.EOF.Error())) } } } // Read an object from a stream. In the event of an error, this // function may panic(NslcdError)! Handle it! func read(fd io.Reader, data interface{}) { switch data := data.(type) { // basic data types case nslcdObjectPtr: data.nslcdRead(fd) case *[]byte: if len(*data) > 0 { _, err := io.ReadFull(fd, *data) if err != nil { npanic(NslcdError(err.Error())) } } case *int32: err := binary.Read(fd, binary.BigEndian, data) if err != nil { npanic(NslcdError(err.Error())) } // composite datatypes case *string: var len int32 read(fd, &len) willread(fd, int64(len)) buf := make([]byte, len) read(fd, &buf) *data = string(buf) case *[]string: var num int32 read(fd, &num) willread(fd, int64(num * /* min size of a string is: */4)) *data = make([]string, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } case *net.IP: var af int32 read(fd, &af) var _len int32 switch af { case unix.AF_INET: _len = net.IPv4len case unix.AF_INET6: _len = net.IPv6len default: npanic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af))) } var len int32 read(fd, &len) if len != _len { npanic(NslcdError(fmt.Sprintf("address length incorrect: %d", len))) } buf := make([]byte, len) read(fd, &buf) *data = buf case *[]net.IP: var num int32 read(fd, &num) willread(fd, int64(num * /* min size of an IP is: */net.IPv4len)) *data = make([]net.IP, num) for i := 0; i < int(num); i++ { read(fd, &((*data)[i])) } default: p := reflect.ValueOf(data) v := reflect.Indirect(p) if p == v || v.Kind() != reflect.Struct { panic(fmt.Sprintf("The argument to nslcd_proto.Read() must be a pointer: %T ( %#v )", data, data)) } for i, n := 0, v.NumField(); i < n; i++ { read(fd, v.Field(i).Addr().Interface()) } } }