/***
  This file is part of systemd.

  Copyright 2014 Lennart Poettering

  systemd is free software; you can redistribute it and/or modify it
  under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2.1 of the License, or
  (at your option) any later version.

  systemd is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with systemd; If not, see <http://www.gnu.org/licenses/>.
***/

#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <mqueue.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/sem.h>
#include <sys/shm.h>
#include <sys/stat.h>
#include <unistd.h>

#include "clean-ipc.h"
#include "dirent-util.h"
#include "fd-util.h"
#include "fileio.h"
#include "formats-util.h"
#include "log.h"
#include "macro.h"
#include "string-util.h"
#include "strv.h"
#include "user-util.h"

/* Decides whether an IPC object should be deleted.
 *
 * An object owned by 'subject_uid'/'subject_gid' matches when its owner equals 'delete_uid'
 * or its group equals 'delete_gid'. Passing an invalid ID (as judged by uid_is_valid() /
 * gid_is_valid()) for either delete_* argument disables that half of the test. */
static bool match_uid_gid(uid_t subject_uid, gid_t subject_gid, uid_t delete_uid, gid_t delete_gid) {
        bool owner_hit, group_hit;

        owner_hit = uid_is_valid(delete_uid) && subject_uid == delete_uid;
        if (owner_hit)
                return true;

        group_hit = gid_is_valid(delete_gid) && subject_gid == delete_gid;
        return group_hit;
}

/* Removes the SysV shared memory segments listed in /proc/sysvipc/shm that match
 * 'delete_uid'/'delete_gid' (see match_uid_gid()) and have no processes attached.
 *
 * A missing /proc/sysvipc/shm is treated as "nothing to do". Segments that disappear
 * between listing and removal (EIDRM/EINVAL) are skipped silently. Other removal failures
 * are logged, and scanning continues.
 *
 * Returns 0, or the last negative error code produced by a logged failure. */
static int clean_sysvipc_shm(uid_t delete_uid, gid_t delete_gid) {
        _cleanup_fclose_ FILE *f = NULL;
        char buf[LINE_MAX];
        bool header_skipped = false;
        int ret = 0;

        f = fopen("/proc/sysvipc/shm", "re");
        if (!f) {
                /* File absent (e.g. SysV IPC not available): not an error. */
                if (errno != ENOENT)
                        return log_warning_errno(errno, "Failed to open /proc/sysvipc/shm: %m");
                return 0;
        }

        FOREACH_LINE(buf, f, goto fail) {
                unsigned n_attached;
                pid_t cpid, lpid;
                uid_t uid, cuid;
                gid_t gid, cgid;
                int shmid;

                /* The first line is a column header, not a segment. */
                if (!header_skipped) {
                        header_skipped = true;
                        continue;
                }

                truncate_nl(buf);

                /* Parsed fields: key (skipped), shmid, perms (skipped), size (skipped),
                 * cpid, lpid, nattch, uid, gid, cuid, cgid. Malformed lines are ignored. */
                if (sscanf(buf, "%*i %i %*o %*u " PID_FMT " " PID_FMT " %u " UID_FMT " " GID_FMT " " UID_FMT " " GID_FMT,
                           &shmid, &cpid, &lpid, &n_attached, &uid, &gid, &cuid, &cgid) != 8)
                        continue;

                /* Leave segments that are still attached, and those belonging to someone else. */
                if (n_attached > 0 || !match_uid_gid(uid, gid, delete_uid, delete_gid))
                        continue;

                if (shmctl(shmid, IPC_RMID, NULL) >= 0) {
                        log_debug("Removed SysV shared memory segment %i.", shmid);
                        continue;
                }

                /* EIDRM/EINVAL mean the segment is already gone; anything else is worth reporting. */
                if (errno != EIDRM && errno != EINVAL)
                        ret = log_warning_errno(errno, "Failed to remove SysV shared memory segment %i: %m", shmid);
        }

        return ret;

fail:
        return log_warning_errno(errno, "Failed to read /proc/sysvipc/shm: %m");
}

/* Removes the SysV semaphore sets listed in /proc/sysvipc/sem that match
 * 'delete_uid'/'delete_gid' (see match_uid_gid()).
 *
 * A missing /proc/sysvipc/sem is treated as "nothing to do". Sets that disappear between
 * listing and removal (EIDRM/EINVAL) are skipped silently. Other removal failures are
 * logged, and scanning continues.
 *
 * Returns 0, or the last negative error code produced by a logged failure. */
static int clean_sysvipc_sem(uid_t delete_uid, gid_t delete_gid) {
        _cleanup_fclose_ FILE *f = NULL;
        char buf[LINE_MAX];
        bool header_skipped = false;
        int ret = 0;

        f = fopen("/proc/sysvipc/sem", "re");
        if (!f) {
                /* File absent (e.g. SysV IPC not available): not an error. */
                if (errno != ENOENT)
                        return log_warning_errno(errno, "Failed to open /proc/sysvipc/sem: %m");
                return 0;
        }

        FOREACH_LINE(buf, f, goto fail) {
                uid_t uid, cuid;
                gid_t gid, cgid;
                int semid;

                /* The first line is a column header, not a semaphore set. */
                if (!header_skipped) {
                        header_skipped = true;
                        continue;
                }

                truncate_nl(buf);

                /* Parsed fields: key (skipped), semid, perms (skipped), nsems (skipped),
                 * uid, gid, cuid, cgid. Malformed lines are ignored. */
                if (sscanf(buf, "%*i %i %*o %*u " UID_FMT " " GID_FMT " " UID_FMT " " GID_FMT,
                           &semid, &uid, &gid, &cuid, &cgid) != 5)
                        continue;

                if (!match_uid_gid(uid, gid, delete_uid, delete_gid))
                        continue;

                if (semctl(semid, 0, IPC_RMID) >= 0) {
                        log_debug("Removed SysV semaphore %i.", semid);
                        continue;
                }

                /* EIDRM/EINVAL mean the set is already gone; anything else is worth reporting. */
                if (errno != EIDRM && errno != EINVAL)
                        ret = log_warning_errno(errno, "Failed to remove SysV semaphores object %i: %m", semid);
        }

        return ret;

fail:
        return log_warning_errno(errno, "Failed to read /proc/sysvipc/sem: %m");
}

/* Removes the SysV message queues listed in /proc/sysvipc/msg that match
 * 'delete_uid'/'delete_gid' (see match_uid_gid()).
 *
 * A missing /proc/sysvipc/msg is treated as "nothing to do". Queues that disappear between
 * listing and removal (EIDRM/EINVAL) are skipped silently. Other removal failures are
 * logged, and scanning continues.
 *
 * Returns 0, or the last negative error code produced by a logged failure. */
static int clean_sysvipc_msg(uid_t delete_uid, gid_t delete_gid) {
        _cleanup_fclose_ FILE *f = NULL;
        char buf[LINE_MAX];
        bool header_skipped = false;
        int ret = 0;

        f = fopen("/proc/sysvipc/msg", "re");
        if (!f) {
                /* File absent (e.g. SysV IPC not available): not an error. */
                if (errno != ENOENT)
                        return log_warning_errno(errno, "Failed to open /proc/sysvipc/msg: %m");
                return 0;
        }

        FOREACH_LINE(buf, f, goto fail) {
                uid_t uid, cuid;
                gid_t gid, cgid;
                pid_t cpid, lpid;
                int msgid;

                /* The first line is a column header, not a queue. */
                if (!header_skipped) {
                        header_skipped = true;
                        continue;
                }

                truncate_nl(buf);

                /* Parsed fields: key (skipped), msqid, perms (skipped), two counters (skipped),
                 * two PIDs, uid, gid, cuid, cgid. Malformed lines are ignored. */
                if (sscanf(buf, "%*i %i %*o %*u %*u " PID_FMT " " PID_FMT " " UID_FMT " " GID_FMT " " UID_FMT " " GID_FMT,
                           &msgid, &cpid, &lpid, &uid, &gid, &cuid, &cgid) != 7)
                        continue;

                if (!match_uid_gid(uid, gid, delete_uid, delete_gid))
                        continue;

                if (msgctl(msgid, IPC_RMID, NULL) >= 0) {
                        log_debug("Removed SysV message queue %i.", msgid);
                        continue;
                }

                /* EIDRM/EINVAL mean the queue is already gone; anything else is worth reporting. */
                if (errno != EIDRM && errno != EINVAL)
                        ret = log_warning_errno(errno, "Failed to remove SysV message queue %i: %m", msgid);
        }

        return ret;

fail:
        return log_warning_errno(errno, "Failed to read /proc/sysvipc/msg: %m");
}

/* Recursively removes POSIX shared memory entries below 'dir' (initially /dev/shm) that
 * match 'uid'/'gid' (see match_uid_gid()).
 *
 * Subdirectories are always descended into, whoever owns them. A user's segments may live
 * below a directory created by someone else, and checking ownership before recursing would
 * leave those segments behind. A directory itself is removed with AT_REMOVEDIR only when it
 * matches; that removal fails (and is logged) if non-matching entries remain inside it.
 *
 * Entries vanishing concurrently (ENOENT) are ignored. Other failures are logged, and the
 * walk continues.
 *
 * Returns 0, or the last negative error code produced by a logged failure. */
static int clean_posix_shm_internal(DIR *dir, uid_t uid, gid_t gid) {
        struct dirent *de;
        int ret = 0, r;

        assert(dir);

        FOREACH_DIRENT_ALL(de, dir, goto fail) {
                struct stat st;

                if (STR_IN_SET(de->d_name, "..", "."))
                        continue;

                /* Don't follow symlinks: only look at the entry itself. */
                if (fstatat(dirfd(dir), de->d_name, &st, AT_SYMLINK_NOFOLLOW) < 0) {
                        if (errno == ENOENT)
                                continue;

                        ret = log_warning_errno(errno, "Failed to stat() POSIX shared memory segment %s: %m", de->d_name);
                        continue;
                }

                if (S_ISDIR(st.st_mode)) {
                        _cleanup_closedir_ DIR *kid = NULL;

                        /* Recurse first, regardless of the directory's owner (see above). */
                        kid = xopendirat(dirfd(dir), de->d_name, O_NOFOLLOW|O_NOATIME);
                        if (!kid) {
                                if (errno != ENOENT)
                                        ret = log_warning_errno(errno, "Failed to enter shared memory directory %s: %m", de->d_name);
                        } else {
                                r = clean_posix_shm_internal(kid, uid, gid);
                                if (r < 0)
                                        ret = r;
                        }

                        /* Only the directory itself is subject to the ownership filter. */
                        if (!match_uid_gid(st.st_uid, st.st_gid, uid, gid))
                                continue;

                        if (unlinkat(dirfd(dir), de->d_name, AT_REMOVEDIR) < 0) {

                                if (errno == ENOENT)
                                        continue;

                                ret = log_warning_errno(errno, "Failed to remove POSIX shared memory directory %s: %m", de->d_name);
                        } else
                                log_debug("Removed POSIX shared memory directory %s", de->d_name);
                } else {

                        if (!match_uid_gid(st.st_uid, st.st_gid, uid, gid))
                                continue;

                        if (unlinkat(dirfd(dir), de->d_name, 0) < 0) {

                                if (errno == ENOENT)
                                        continue;

                                ret = log_warning_errno(errno, "Failed to remove POSIX shared memory segment %s: %m", de->d_name);
                        } else
                                log_debug("Removed POSIX shared memory segment %s", de->d_name);
                }
        }

        return ret;

fail:
        return log_warning_errno(errno, "Failed to read /dev/shm: %m");
}

static int clean_posix_shm(uid_t uid, gid_t gid) {
        _cleanup_closedir_ DIR *dir = NULL;

        dir = opendir("/dev/shm");
        if (!dir) {
                if (errno == ENOENT)
                        return 0;

                return log_warning_errno(errno, "Failed to open /dev/shm: %m");
        }

        return clean_posix_shm_internal(dir, uid, gid);
}

/* Unlinks the POSIX message queues found in /dev/mqueue that match 'uid'/'gid' (see
 * match_uid_gid()).
 *
 * A missing /dev/mqueue counts as "nothing to do". Queues vanishing concurrently (ENOENT)
 * are ignored. Other failures are logged, and the scan continues.
 *
 * Returns 0, or the last negative error code produced by a logged failure. */
static int clean_posix_mq(uid_t uid, gid_t gid) {
        _cleanup_closedir_ DIR *dir = NULL;
        struct dirent *de;
        int ret = 0;

        dir = opendir("/dev/mqueue");
        if (!dir) {
                if (errno != ENOENT)
                        return log_warning_errno(errno, "Failed to open /dev/mqueue: %m");
                return 0;
        }

        FOREACH_DIRENT_ALL(de, dir, goto fail) {
                /* Room for a leading '/', the entry name and the terminating NUL. */
                char fn[1+strlen(de->d_name)+1];
                struct stat st;

                if (STR_IN_SET(de->d_name, "..", "."))
                        continue;

                if (fstatat(dirfd(dir), de->d_name, &st, AT_SYMLINK_NOFOLLOW) < 0) {
                        if (errno != ENOENT)
                                ret = log_warning_errno(errno, "Failed to stat() MQ segment %s: %m", de->d_name);
                        continue;
                }

                if (!match_uid_gid(st.st_uid, st.st_gid, uid, gid))
                        continue;

                /* mq_unlink() takes the queue name with a leading slash, not the file path. */
                fn[0] = '/';
                strcpy(fn+1, de->d_name);

                if (mq_unlink(fn) >= 0) {
                        log_debug("Removed POSIX message queue %s", fn);
                        continue;
                }

                if (errno != ENOENT)
                        ret = log_warning_errno(errno, "Failed to unlink POSIX message queue %s: %m", fn);
        }

        return ret;

fail:
        return log_warning_errno(errno, "Failed to read /dev/mqueue: %m");
}

/* Removes SysV (shared memory, semaphores, message queues) and POSIX (shared memory,
 * message queues) IPC objects owned by user 'uid' or by group 'gid'.
 *
 * Either ID may be invalid to disable that filter; if both are invalid, nothing is done.
 * Requests involving UID 0 or GID 0 are refused outright.
 *
 * Every cleanup pass runs even if an earlier one failed. Returns 0, or the negative error
 * code of the last pass that failed. */
int clean_ipc(uid_t uid, gid_t gid) {
        int ret = 0, r;

        /* Anything to do? */
        if (!uid_is_valid(uid) && !gid_is_valid(gid))
                return 0;

        /* Refuse to clean IPC of the root user or root group. Each ID must be checked on its own.
         * Requiring both to be 0 would let a request for UID 0 with an invalid GID (or GID 0 with
         * an invalid UID) through, and that would remove every IPC object owned by root. */
        if (uid == 0 || gid == 0)
                return 0;

        r = clean_sysvipc_shm(uid, gid);
        if (r < 0)
                ret = r;

        r = clean_sysvipc_sem(uid, gid);
        if (r < 0)
                ret = r;

        r = clean_sysvipc_msg(uid, gid);
        if (r < 0)
                ret = r;

        r = clean_posix_shm(uid, gid);
        if (r < 0)
                ret = r;

        r = clean_posix_mq(uid, gid);
        if (r < 0)
                ret = r;

        return ret;
}

/* Cleans up the IPC objects owned by user 'uid', whatever their group. The group filter is
 * disabled by passing GID_INVALID. */
int clean_ipc_by_uid(uid_t uid) {
        const gid_t no_gid_filter = GID_INVALID;

        return clean_ipc(uid, no_gid_filter);
}

/* Cleans up the IPC objects owned by group 'gid', whatever their owning user. The user
 * filter is disabled by passing UID_INVALID. */
int clean_ipc_by_gid(gid_t gid) {
        const uid_t no_uid_filter = UID_INVALID;

        return clean_ipc(no_uid_filter, gid);
}