/*
 * standalone-socket-activate
 *
 * Copyright (c) 2019  Peter Pentchev
 * Copyright (c) 2023- Mark Hindley
 *
 * License: GPL-3.0+
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <limits.h>
#include <netdb.h>
#include <poll.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <netinet/in.h>

/* First descriptor number handed to the activated program (the same
 * starting value as the sd_listen_fds(3) convention). */
#define SD_LISTEN_FDS_START 3

/*
 * Fallbacks for the BSD <sys/cdefs.h> annotations.  Each one is guarded on
 * its own so that a platform which already provides only one of them still
 * gets a definition of the other, and neither is ever redefined.
 */
#if defined(__GNUC__) && __GNUC__ >= 3
#ifndef __printflike
#define __printflike(x, y)	__attribute__((format(printf, (x), (y))))
#endif
#ifndef __unused
#define __unused		__attribute__((unused))
#endif
#else
#ifndef __printflike
#define __printflike(x, y)
#endif
#ifndef __unused
#define __unused
#endif
#endif

#define VERSION_STRING	"0.4unreleased"

/* Identifiers of the per-listener options given as "name=value,...". */
enum option_id_t {
	OPT_LABEL = 0,
	OPT_MODE,
	OPT_USER,
	OPT_GROUP,
	OPT_BACKLOG,
	OPTIONS_COUNT	/* number of options; also the "unknown" sentinel */
};

/* One option slot for a single listener. */
struct option_t {
	const char *value;	/* NULL when the option was not given */
	bool handled;		/* set once some code path consumed the value */
};

/* Option spellings, indexed by enum option_id_t. */
static const char * const option_names[OPTIONS_COUNT] = {
	[OPT_LABEL]	= "label",
	[OPT_MODE]	= "mode",
	[OPT_USER]	= "user",
	[OPT_GROUP]	= "group",
	[OPT_BACKLOG]	= "backlog",
};

/* Set by -v / --verbose; debug() prints only when this is true. */
static bool		verbose;

/*
 * Print the command-line synopsis.  When _ferr is true the text goes to
 * stderr and the program exits with status 1; otherwise it goes to stdout
 * and the function returns normally.
 */
static void
usage(const bool _ferr)
{
	static const char text[] =
	    "Usage:\tsocket-activate [-v | --verbose] --family:[options,...]:address... [--] program [arg...]\n"
	    "\tsocket-activate -V | -h | --version | --help\n"
	    "\tsocket-activate --features\n"
	    "\n"
	    "\t-h\tdisplay program usage information and exit\n"
	    "\t-V\tdisplay program version information and exit\n"
	    "\t-v\tverbose operation; display diagnostic output\n";
	FILE * const stream = _ferr ? stderr : stdout;

	fputs(text, stream);
	if (!_ferr)
		return;
	exit(1);
}

static void
version(void)
{
	puts("socket-activate " VERSION_STRING);
}

static void
features(void)
{
	puts("Features: socket-activate=" VERSION_STRING);
}

/*
 * printf-style diagnostic output to stderr, emitted only when verbose
 * mode is enabled; otherwise the arguments are ignored.
 */
__printflike(1, 2)
static void
debug(const char * const fmt, ...)
{
	if (!verbose)
		return;

	va_list ap;
	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
}

/*
 * Parse s as a long in the given base, as strtol(3) would.
 *
 * The string must contain at least one digit and be consumed entirely;
 * if it is empty, has trailing garbage, or is out of range for a long,
 * an error is printed and the program exits with status 1.
 */
static long
strtolong_fatal(const char * const s, const int base)
{
	char *endp;

	errno = 0;
	const long value = strtol(s, &endp, base);
	/* endp == s means no digits at all (e.g. ""), which strtol would
	 * otherwise silently report as 0; ERANGE only ever accompanies a
	 * clamped LONG_MIN/LONG_MAX result. */
	if (endp == s || *endp != '\0' || errno == ERANGE)
		errx(1, "Invalid base %d numeric value '%s'", base, s);
	return value;
}

/*
 * Create a Unix-domain socket of the given type and bind it to the
 * filesystem path in address, removing any existing file there first.
 *
 * Options consumed (each is marked as handled when used):
 *   mode  - octal permission bits for the socket file, applied through a
 *           temporary umask(2) around bind(2);
 *   user  - numeric owner uid; the group defaults to getgid() if absent;
 *   group - numeric group gid; the owner defaults to getuid() if absent.
 *
 * Returns the socket descriptor; exits via err(3)/errx(3) on failure.
 */
static int
connect_unix(const int socktype, struct option_t * const options,
		char * const address)
{
	if (address == NULL || address[0] == '\0')
		errx(1, "No path specified for the Unix-domain socket");

	const int sock = socket(AF_UNIX, socktype, 0);
	if (sock == -1)
		err(EXIT_FAILURE, "Could not create a Unix-domain socket");

	struct sockaddr_un addr = { 0 };
	addr.sun_family = AF_UNIX;
	const int written = snprintf(addr.sun_path, sizeof(addr.sun_path),
	    "%s", address);
	/* Truncation does not set errno, so report it without strerror(). */
	if (written < 0 || (size_t)written >= sizeof(addr.sun_path))
		errx(1, "Path too long for a Unix-domain socket");
	if (unlink(address) == -1 && errno != ENOENT)
		err(EXIT_FAILURE, "Failed to remove existing Unix-domain socket %s", address);

	mode_t oldmask = 0;
	if (options[OPT_MODE].value != NULL) {
		options[OPT_MODE].handled = true;
		const long mode = strtolong_fatal(options[OPT_MODE].value, 8);
		if (mode < 0 || mode > 07777)
			errx(1, "Invalid access mode '%s' for '%s'",
			    options[OPT_MODE].value, address);
		/* The socket file created by bind(2) takes its permissions
		 * from the umask (at least on Linux), so mask out exactly the
		 * bits that should not be set. */
		oldmask = umask((mode_t)~mode);
	}
	if (bind(sock, (const struct sockaddr *)&addr, sizeof(addr)) == -1)
		err(EXIT_FAILURE, "Could not bind the Unix-domain socket to %s", address);
	if (options[OPT_MODE].value != NULL)
		umask(oldmask);

	uid_t uid = 0;
	gid_t gid = 0;
	if (options[OPT_USER].value != NULL) {
		options[OPT_USER].handled = true;
		const long luid = strtolong_fatal(options[OPT_USER].value, 10);
		if (luid < 0)
			errx(1, "The user id must be non-negative");
		uid = (uid_t)luid;
		/* Reject ids that do not fit in uid_t, and (uid_t)-1, which
		 * chown(2) would treat as "leave the owner unchanged". */
		if ((long)uid != luid || uid == (uid_t)-1)
			errx(1, "Invalid user id '%s'", options[OPT_USER].value);

		if (options[OPT_GROUP].value == NULL)
			gid = getgid();
	}
	if (options[OPT_GROUP].value != NULL) {
		options[OPT_GROUP].handled = true;
		const long lgid = strtolong_fatal(options[OPT_GROUP].value, 10);
		if (lgid < 0)
			errx(1, "The group id must be non-negative");
		gid = (gid_t)lgid;
		/* Same range / "unchanged" sentinel check as for the uid. */
		if ((long)gid != lgid || gid == (gid_t)-1)
			errx(1, "Invalid group id '%s'", options[OPT_GROUP].value);

		if (options[OPT_USER].value == NULL)
			uid = getuid();
	}
	if (options[OPT_USER].value != NULL ||
	    options[OPT_GROUP].value != NULL)
		if (chown(address, uid, gid) == -1)
			err(EXIT_FAILURE, "Could not set the Unix socket ownership to "
			    "%jd:%jd", (intmax_t)uid, (intmax_t)gid);

	return sock;
}

/*
 * Create a TCP (SOCK_STREAM) or UDP (SOCK_DGRAM) socket and bind it.
 *
 * address is either "port", bound to 0.0.0.0 (IPv4 only), or "host/port",
 * where host may be an IPv4 or IPv6 address.  Both parts must be numeric
 * (AI_NUMERICHOST | AI_NUMERICSERV).  address is split in place by
 * strsep(3), so the '/' separator is overwritten.
 *
 * No per-listener options are consumed here.  Returns the socket
 * descriptor; exits via err(3)/errx(3) on failure.
 */
static int
connect_inet(const int socktype, struct option_t * const options __unused,
		char * const address)
{
	const char * const kind = socktype == SOCK_STREAM ? "tcp" : "udp";
	if (address == NULL)
		errx(1, "No address specified for the %s socket", kind);

	const int proto = socktype == SOCK_STREAM ? IPPROTO_TCP : IPPROTO_UDP;
	/* "host/port" splits into two parts; a lone "port" leaves next NULL. */
	char *next = address;
	const char * const first = strsep(&next, "/");
	const char * const addrstr = next == NULL ? "0.0.0.0" : first;
	const char * const portstr = next == NULL ? first : next;
	struct addrinfo hints = {
		.ai_family = next == NULL ? AF_INET : AF_UNSPEC,
		.ai_socktype = socktype,
		.ai_protocol = proto,
		.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV | AI_PASSIVE,
	};
	struct addrinfo *res = NULL;
	const int ret = getaddrinfo(addrstr, portstr, &hints, &res);
	if (ret != 0)
		errx(1, "Could not parse the %s address: %s",
		    address, gai_strerror(ret));
	if (res == NULL)
		errx(1, "Could not parse the %s address", address);

	if (verbose) {
		char buf[INET6_ADDRSTRLEN];
		in_port_t port;
		const void *vaddr;
		if (res->ai_family == AF_INET) {
			const struct sockaddr_in * const iaddr =
			    (const struct sockaddr_in *)res->ai_addr;
			vaddr = &iaddr->sin_addr;
			port = iaddr->sin_port;
		} else {
			const struct sockaddr_in6 * const iaddr =
			    (const struct sockaddr_in6 *)res->ai_addr;
			vaddr = &iaddr->sin6_addr;
			port = iaddr->sin6_port;
		}
		/* inet_ntop() takes the size of the output buffer (not of the
		 * socket address) and returns NULL if it does not fit. */
		const char * const shown = inet_ntop(res->ai_family, vaddr,
		    buf, sizeof(buf));
		debug("got family %d type %d proto %d len %u "
		    "address %s port %d\n",
		    res->ai_family, res->ai_socktype, res->ai_protocol,
		    (unsigned int)res->ai_addrlen,
		    shown != NULL ? shown : "(unknown)",
		    (int)ntohs(port));
	}

	const int sock = socket(res->ai_family, res->ai_socktype,
	    res->ai_protocol);
	if (sock == -1)
		err(EXIT_FAILURE, "Could not create a %s socket", kind);
	const int optval = 1;
	if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR,
		       &optval, sizeof(optval)) == -1
#ifndef __gnu_hurd__
	    || setsockopt(sock, SOL_SOCKET, SO_REUSEPORT,
			  &optval, sizeof(optval)) == -1
#endif
	    )
		err(EXIT_FAILURE, "Could not set the socket options for %s:%s",
		    addrstr, portstr);
	if (bind(sock, res->ai_addr, res->ai_addrlen) == -1)
		err(EXIT_FAILURE, "Could not bind to %s:%s", addrstr, portstr);

	/* addrstr/portstr point into address, not into res, so this is safe. */
	freeaddrinfo(res);
	return sock;
}


/* Associates a socket family name from the command line with its setup
 * routine and the socket type that routine should create. */
struct connect_func_t {
	const char * const family;	/* text before the first ':' */
	/* Creates and binds the socket, returning its descriptor. */
	int (* const func)(int socktype, struct option_t * const options,
	    char *address);
	const int socktype;		/* passed as func's first argument */
};

/* Supported socket families, searched by get_connect_function(). */
static const struct connect_func_t connect_functions[] = {
	{ .family = "unix",       .func = connect_unix, .socktype = SOCK_STREAM },
	{ .family = "unix-dgram", .func = connect_unix, .socktype = SOCK_DGRAM },
	{ .family = "tcp",        .func = connect_inet, .socktype = SOCK_STREAM },
	{ .family = "udp",        .func = connect_inet, .socktype = SOCK_DGRAM },
};

static inline const struct connect_func_t *
get_connect_function(const char * const family)
{
	for (size_t idx = 0; idx < sizeof(connect_functions) / sizeof(connect_functions[0]); idx++)
		if (strcmp(family, connect_functions[idx].family) == 0)
			return &connect_functions[idx];
	errx(1, "Invalid socket family '%s'", family);
}

/*
 * Look up an option by its name.  Returns the matching enum option_id_t,
 * or OPTIONS_COUNT when name is not a recognised option.
 */
static inline enum option_id_t
get_option_id(const char * const name)
{
	size_t id = 0;

	while (id < (size_t)OPTIONS_COUNT &&
	    strcmp(name, option_names[id]) != 0)
		id++;
	/* Leaves the loop with id == OPTIONS_COUNT when nothing matched. */
	return (enum option_id_t)id;
}

/*
 * Set up one listener from a "family:options:address" specification and
 * place its socket on descriptor nextfd.
 *
 * optarg is the specification without the leading "--"; the options field
 * is a comma-separated list of name=value pairs and may be empty.  For
 * SOCK_STREAM families the socket is also put into listening state, using
 * the "backlog" option or SOMAXCONN.  Any option the family did not consume
 * is reported as an error.  Exits via err(3)/errx(3) on failure.
 *
 * Returns the listener's label, or "" when none was given.  The label
 * points into a private copy of optarg that is intentionally never freed,
 * so it stays valid for the rest of the program.
 */
static const char *
add_listener(const char * const optarg, const int nextfd)
{
	/* Writable copy split in place by strsep(3); see the note above. */
	char *next = strdup(optarg);
	if (next == NULL)
		err(EXIT_FAILURE, "Could not allocate memory");

	const char * const family = strsep(&next, ":");
	const struct connect_func_t * const func = get_connect_function(family);

	struct option_t * const options = calloc(OPTIONS_COUNT, sizeof(*options));
	if (options == NULL)
		err(EXIT_FAILURE, "Could not allocate memory");

	char *optstr = strsep(&next, ":");
	char *current;
	while ((current = strsep(&optstr, ",")) != NULL) {
		/* Skip empty entries ("a=1,,b=2", trailing comma) instead of
		 * silently dropping every option that follows them. */
		if (current[0] == '\0')
			continue;
		const char * const name = strsep(&current, "=");
		if (name[0] == '\0')
			errx(1, "Empty option name in '%s'", optarg);
		const enum option_id_t idx = get_option_id(name);
		if (idx == OPTIONS_COUNT)
			errx(1, "Unknown option '%s' in '%s'", name, optarg);

		/* strsep() leaves current NULL when there was no '=' at all. */
		if (current == NULL || current[0] == '\0')
			errx(1, "No value for option '%s' in '%s'",
			    name, optarg);
		options[idx].value = current;
	}

	/* Without a second ':' there is no address field. */
	if (next == NULL)
		errx(1, "No address specified in '%s'", optarg);
	char * const address = next;

	const char * const label = options[OPT_LABEL].value;
	options[OPT_LABEL].handled = true;
	const int fd = func->func(func->socktype, options, address);
	if (fd != nextfd) {
		debug("moving %s fd %d to %d\n", optarg, fd, nextfd);
		if (dup2(fd, nextfd) == -1)
			err(EXIT_FAILURE, "Could not move %s fd %d to %d",
			    optarg, fd, nextfd);
		if (close(fd) == -1)
			err(EXIT_FAILURE, "Could not close %s fd %d after "
			    "moving it to %d",
			    optarg, fd, nextfd);
		/* fd now closed, don't use anymore */
	} else {
		debug("no need to move %s fd %d\n", optarg, fd);
	}
	/* Passing NULL for %s is undefined; print what glibc would. */
	debug("fd %d: label %s, family %s, address %s\n",
	    nextfd, label != NULL ? label : "(null)", family, address);

	if (func->socktype == SOCK_STREAM) {
		int backlog = SOMAXCONN;
		if (options[OPT_BACKLOG].value != NULL) {
			options[OPT_BACKLOG].handled = true;
			const long value = strtolong_fatal(
			    options[OPT_BACKLOG].value, 10);
			/* Validation failures: errno is meaningless, so errx(). */
			if (value < 1)
				errx(EXIT_FAILURE, "The listen backlog for %s must be "
				    "a positive number", optarg);
#if LONG_MAX > INT_MAX
			if (value > INT_MAX)
				errx(EXIT_FAILURE, "The listen backlog for %s must be "
				    "no more than %d", optarg, INT_MAX);
#endif
			backlog = (int)value;
		}
		if (listen(nextfd, backlog) == -1)
			err(EXIT_FAILURE, "Could not listen for %s", optarg);
	}

	for (size_t idx = 0; idx < OPTIONS_COUNT; idx++)
		if (options[idx].value != NULL && !options[idx].handled)
			errx(1, "Unexpected option '%s' in '%s'",
			    option_names[idx], optarg);
	free(options);

	return label != NULL ? label : "";
}

/*
 * Append label to the colon-separated LISTEN_FDNAMES list.  names is the
 * list built so far (NULL before the first listener) and is freed; the
 * returned list is newly allocated.  Exits if allocation fails.
 */
static char *
append_fd_name(char * const names, const char * const label)
{
	char *joined = NULL;
	const int res = asprintf(&joined, "%s%s%s",
	    names == NULL ? "" : names,
	    names == NULL ? "" : ":",
	    label);
	if (res < 0 || joined == NULL)
		err(EXIT_FAILURE, "Could not allocate "
		    "the names list");
	free(names);
	return joined;
}

/*
 * Block until at least one of the descriptors SD_LISTEN_FDS_START ..
 * nextfd - 1 is readable, retrying if poll(2) is interrupted by a signal.
 * The caller must ensure there is at least one descriptor.  Returns the
 * number of ready descriptors; exits on failure.
 */
static int
wait_for_activity(const int nextfd)
{
	const nfds_t nfds = (nfds_t)(nextfd - SD_LISTEN_FDS_START);
	struct pollfd * const pfds = calloc(nfds, sizeof(*pfds));
	if (pfds == NULL)
		err(EXIT_FAILURE, "Malloc failed");

	for (nfds_t idx = 0; idx < nfds; idx++) {
		pfds[idx].fd = SD_LISTEN_FDS_START + (int)idx;
		pfds[idx].events = POLLIN;
	}

	debug("Polling %lu fds\n", (unsigned long)nfds);
	int n;
	do {
		n = poll(pfds, nfds, -1);
	} while (n == -1 && errno == EINTR);
	if (n == -1)
		err(EXIT_FAILURE, "Could not wait for incoming connections");
	free(pfds);
	return n;
}

/*
 * Parse the listener specifications, bind each socket to consecutive
 * descriptors starting at SD_LISTEN_FDS_START, wait until one of them is
 * readable, then export LISTEN_PID / LISTEN_FDS / LISTEN_FDNAMES and
 * execvp() the program named by the remaining arguments.
 */
int
main(int argc, char * const argv[])
{
	char *fdnames = NULL;
	bool hflag = false, Vflag = false, show_features = false;
	int nextfd = SD_LISTEN_FDS_START;
	int ch;

	while ((ch = getopt(argc, argv, "hVv-:")) != -1) {
		switch ((char)ch) {
			case 'h':
				hflag = true;
				break;

			case 'V':
				Vflag = true;
				break;

			case 'v':
				verbose = true;
				break;

			case '-':
				/* "--word" arrives here with optarg "word". */
				if (strcmp(optarg, "help") == 0)
					hflag = true;
				else if (strcmp(optarg, "version") == 0)
					Vflag = true;
				else if (strcmp(optarg, "verbose") == 0)
					verbose = true;
				else if (strcmp(optarg, "features") == 0)
					show_features = true;
				else {
					/* Anything else is a listener spec. */
					fdnames = append_fd_name(fdnames,
					    add_listener(optarg, nextfd));
					nextfd++;
				}
				break;

			default:
				usage(true);
				/* NOTREACHED */
		}
	}
	if (Vflag)
		version();
	if (hflag)
		usage(false);
	if (show_features)
		features();
	if (Vflag || hflag || show_features)
		return (0);

	argc -= optind;
	argv += optind;
	if (argc < 1)
		usage(true);
	/* With no listener there is nothing to poll (poll would block
	 * forever) and fdnames would still be NULL. */
	if (nextfd == SD_LISTEN_FDS_START)
		errx(1, "No listening sockets specified");

	const int n = wait_for_activity(nextfd);
	debug("got %d fds from poll()\n", n);

	/* execvp() keeps the PID, so the child sees its own PID here. */
	char buf[32];
	const pid_t pid = getpid();
	snprintf(buf, sizeof(buf), "%jd", (intmax_t)pid);
	if (setenv("LISTEN_PID", buf, 1) == -1)
		err(EXIT_FAILURE, "Could not set LISTEN_PID");
	snprintf(buf, sizeof(buf), "%d", nextfd - SD_LISTEN_FDS_START);
	if (setenv("LISTEN_FDS", buf, 1) == -1)
		err(EXIT_FAILURE, "Could not set LISTEN_FDS");
	if (setenv("LISTEN_FDNAMES", fdnames, 1) == -1)
		err(EXIT_FAILURE, "Could not set LISTEN_FDNAMES");
	debug("Exec %s, PID %jd\n", argv[0], (intmax_t)pid);
	execvp(argv[0], argv);
	err(EXIT_FAILURE, "Could not execute '%s'", argv[0]);
}
