summaryrefslogtreecommitdiff
path: root/nslcd_systemd
diff options
context:
space:
mode:
Diffstat (limited to 'nslcd_systemd')
-rw-r--r--nslcd_systemd/misc_test.go45
-rw-r--r--nslcd_systemd/util_test.go19
2 files changed, 64 insertions, 0 deletions
diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go
index 8ef697a..0e6f2e9 100644
--- a/nslcd_systemd/misc_test.go
+++ b/nslcd_systemd/misc_test.go
@@ -283,3 +283,48 @@ func TestClientHang(t *testing.T) {
}()
}
}
+
+func TestLargeRequest(t *testing.T) {
+ tmpdir, err := ioutil.TempDir("", "go-test-libnslcd-large-request.")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpdir)
+ defer sdActivatedReset()
+ t.Run("large-request", func(t *testing.T) {
+ testWithTimeout(t, 2*time.Second, func(t *testing.T, s chan<- string) {
+ GiB := 1024*1024*1024
+ ctx := &testContext{T: t, tmpdir: tmpdir, status: s}
+
+ backend := &LockingBackend{}
+
+ limits := nslcd_server.Limits{
+ Timeout: 1 * time.Second,
+ }
+
+ notifyHandler := func(dat []byte, oob []byte) error { return nil }
+
+ client := func(sockname string) {
+ errfatal := func(err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ conn, err := net.Dial("unix", sockname)
+ errfatal(err)
+ errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_VERSION))
+ errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_ACTION_PASSWD_BYNAME))
+ errfatal(nslcd_proto.Write(conn, int32(1*GiB)))
+ }
+
+ testDriver(ctx, backend, limits, notifyHandler, client)
+
+ var memstats runtime.MemStats
+ runtime.ReadMemStats(&memstats)
+ if memstats.HeapSys > uint64(1*GiB) {
+ t.Fatalf("Used more than 1 GiB heap: %s B", humanizeU64(memstats.HeapSys))
+ }
+ })
+ })
+}
diff --git a/nslcd_systemd/util_test.go b/nslcd_systemd/util_test.go
index 15147b7..89cf7cb 100644
--- a/nslcd_systemd/util_test.go
+++ b/nslcd_systemd/util_test.go
@@ -191,3 +191,22 @@ func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) e
}
}
}
+
+func humanizeU64(n uint64) string {
+ str := fmt.Sprintf("%d", n)
+ bts := make([]byte, len(str)+(len(str)-1)/3)
+
+ s := 0
+ b := 0
+ for s < len(str) && b < len(bts) {
+ if (s % 3 == 0 && s > 0) {
+ bts[len(bts)-1-b] = ','
+ b++
+ }
+ bts[len(bts)-1-b] = str[len(str)-1-s]
+ b++
+ s++
+ }
+
+ return string(bts)
+}