diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ask-password.c | 153 | ||||
-rw-r--r-- | src/util.c | 10 |
2 files changed, 143 insertions, 20 deletions
diff --git a/src/ask-password.c b/src/ask-password.c index 2c9b027d00..7247c7eda6 100644 --- a/src/ask-password.c +++ b/src/ask-password.c @@ -31,6 +31,8 @@ #include <sys/stat.h> #include <sys/signalfd.h> #include <getopt.h> +#include <termios.h> +#include <limits.h> #include "log.h" #include "macro.h" @@ -38,6 +40,7 @@ static const char *arg_icon = NULL; static const char *arg_message = NULL; +static bool arg_use_tty = true; static usec_t arg_timeout = 60 * USEC_PER_SEC; static int create_socket(char **name) { @@ -93,7 +96,8 @@ static int help(void) { "Query the user for a passphrase.\n\n" " -h --help Show this help\n" " --icon=NAME Icon name\n" - " --timeout=USEC Timeout in usec\n", + " --timeout=USEC Timeout in usec\n" + " --no-tty Ask question via agent even on TTY\n", program_invocation_short_name); return 0; @@ -103,13 +107,15 @@ static int parse_argv(int argc, char *argv[]) { enum { ARG_ICON = 0x100, - ARG_TIMEOUT + ARG_TIMEOUT, + ARG_NO_TTY }; static const struct option options[] = { { "help", no_argument, NULL, 'h' }, { "icon", required_argument, NULL, ARG_ICON }, { "timeout", required_argument, NULL, ARG_TIMEOUT }, + { "no-tty", no_argument, NULL, ARG_NO_TTY }, { NULL, 0, NULL, 0 } }; @@ -137,6 +143,10 @@ static int parse_argv(int argc, char *argv[]) { } break; + case ARG_NO_TTY: + arg_use_tty = false; + break; + case '?': return -EINVAL; @@ -152,29 +162,22 @@ static int parse_argv(int argc, char *argv[]) { } arg_message = argv[optind]; - return 0; + return 1; } -int main(int argc, char *argv[]) { +static int ask_agent(void) { char temp[] = "/dev/.systemd/ask-password/tmp.XXXXXX"; char final[sizeof(temp)] = ""; - int fd = -1, r = EXIT_FAILURE, k; + int fd = -1, r; FILE *f = NULL; char *socket_name = NULL; int socket_fd, signal_fd; sigset_t mask; usec_t not_after; - log_parse_environment(); - log_open(); - - if ((k = parse_argv(argc, argv)) < 0) { - r = k < 0 ? EXIT_FAILURE : EXIT_SUCCESS; - goto finish; - } - if ((fd = mkostemp(temp, O_CLOEXEC|O_CREAT|O_WRONLY)) < 0) { log_error("Failed to create password file: %m"); + r = -errno; goto finish; } @@ -182,6 +185,7 @@ int main(int argc, char *argv[]) { if (!(f = fdopen(fd, "w"))) { log_error("Failed to allocate FILE: %m"); + r = -errno; goto finish; } @@ -193,11 +197,14 @@ int main(int argc, char *argv[]) { if ((signal_fd = signalfd(-1, &mask, SFD_NONBLOCK|SFD_CLOEXEC)) < 0) { log_error("signalfd(): %m"); + r = -errno; goto finish; } - if ((socket_fd = create_socket(&socket_name)) < 0) + if ((socket_fd = create_socket(&socket_name)) < 0) { + r = socket_fd; goto finish; + } not_after = now(CLOCK_MONOTONIC) + arg_timeout; @@ -218,6 +225,7 @@ int main(int argc, char *argv[]) { if (ferror(f)) { log_error("Failed to write query file: %m"); + r = -errno; goto finish; } @@ -229,6 +237,7 @@ int main(int argc, char *argv[]) { if (rename(temp, final) < 0) { log_error("Failed to rename query file: %m"); + r = -errno; goto finish; } @@ -249,6 +258,7 @@ int main(int argc, char *argv[]) { } control; ssize_t n; struct pollfd pollfd[_FD_MAX]; + int k; zero(pollfd); pollfd[FD_SOCKET].fd = socket_fd; @@ -261,12 +271,14 @@ int main(int argc, char *argv[]) { if (errno == EINTR) continue; - log_error("poll() failed: %s", strerror(-r)); + log_error("poll() failed: %m"); + r = -errno; goto finish; } if (k <= 0) { log_notice("Timed out"); + r = -ETIME; goto finish; } @@ -275,6 +287,7 @@ int main(int argc, char *argv[]) { if (pollfd[FD_SOCKET].revents != POLLIN) { log_error("Unexpected poll() event."); + r = -EIO; goto finish; } @@ -296,6 +309,7 @@ int main(int argc, char *argv[]) { continue; log_error("recvmsg() failed: %m"); + r = -errno; goto finish; } @@ -321,9 +335,11 @@ int main(int argc, char *argv[]) { if (passphrase[0] == '+') { passphrase[n] = 0; fputs(passphrase+1, stdout); - } else if (passphrase[0] == '-') + fflush(stdout); + } else if (passphrase[0] == '-') { + r = -ECANCELED; goto finish; - else { + } else { log_error("Invalid packet"); continue; } @@ -331,7 +347,7 @@ int main(int argc, char *argv[]) { break; } - r = EXIT_SUCCESS; + r = 0; finish: if (fd >= 0) @@ -350,3 +366,104 @@ finish: return r; } + +static int ask_tty(void) { + struct termios old_termios, new_termios; + char passphrase[LINE_MAX]; + FILE *ttyf; + + if (!(ttyf = fopen("/dev/tty", "w"))) { + log_error("Failed to open /dev/tty: %m"); + return -errno; + } + + fputs("\x1B[1m", ttyf); + fprintf(ttyf, "%s: ", arg_message); + fputs("\x1B[0m", ttyf); + fflush(ttyf); + + if (tcgetattr(STDIN_FILENO, &old_termios) >= 0) { + + new_termios = old_termios; + + new_termios.c_lflag &= ~(ICANON|ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + + if (tcsetattr(STDIN_FILENO, TCSADRAIN, &new_termios) >= 0) { + size_t p = 0; + int r = 0; + + for (;;) { + size_t k; + char c; + + k = fread(&c, 1, 1, stdin); + + if (k <= 0) { + r = -EIO; + break; + } + + if (c == '\n') + break; + else if (c == '\b' || c == 127) { + if (p > 0) { + p--; + fputs("\b \b", ttyf); + } + } else { + passphrase[p++] = c; + fputc('*', ttyf); + } + + fflush(ttyf); + } + + fputc('\n', ttyf); + fclose(ttyf); + tcsetattr(STDIN_FILENO, TCSADRAIN, &old_termios); + + if (r < 0) + return -EIO; + + passphrase[p] = 0; + + fputs(passphrase, stdout); + fflush(stdout); + return 0; + } + + } + + fclose(ttyf); + + if (!fgets(passphrase, sizeof(passphrase), stdin)) { + log_error("Failed to read password."); + return -EIO; + } + + truncate_nl(passphrase); + fputs(passphrase, stdout); + fflush(stdout); + + return 0; +} + +int main(int argc, char *argv[]) { + int r; + + log_parse_environment(); + log_open(); + + if ((r = parse_argv(argc, argv)) <= 0) + goto finish; + + if (arg_use_tty && isatty(STDIN_FILENO)) + r = ask_tty(); + else + r = ask_agent(); + +finish: + return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS; +} diff --git a/src/util.c b/src/util.c index c1ee936aaf..bdc194e80c 100644 --- a/src/util.c +++ b/src/util.c @@ -2009,23 +2009,29 @@ int read_one_char(FILE *f, char *ret, bool *need_nl) { } int ask(char *ret, const char *replies, const char *text, ...) { + bool on_tty; + assert(ret); assert(replies); assert(text); + on_tty = isatty(STDOUT_FILENO); + for (;;) { va_list ap; char c; int r; bool need_nl = true; - fputs("\x1B[1m", stdout); + if (on_tty) + fputs("\x1B[1m", stdout); va_start(ap, text); vprintf(text, ap); va_end(ap); - fputs("\x1B[0m", stdout); + if (on_tty) + fputs("\x1B[0m", stdout); fflush(stdout); |