#include <stdlib.h>
#include <unistd.h>
#include <sys/types.h>
#include <stdio.h>
#include <string.h>
#include <err.h>
#include <ctype.h>
#include <pwd.h>
#include <errno.h>
#ifndef XMSH_PATH
#error must define XMSH_PATH
#endif
/* Two-level stringification: the argument is macro-expanded before quoting. */
#define STRINGIFY(x) XSTRINGIFY(x)
#define XSTRINGIFY(x) #x
/* Buffer sizes, in bytes including the terminating NUL, for names. */
#define USERNAME_MAX 64
#define VMNAME_MAX 64
/* printf-style format of the per-user file: <XMSH_PATH>/users/<username> */
#define USERFILE_PATH_FMT (STRINGIFY(XMSH_PATH) "/users/%s")
/*
 * Size of the buffer holding a formatted userfile path. The leading 64 is
 * presumably headroom for the XMSH_PATH prefix - confirm against the build.
 * Parenthesized so the macro is safe inside larger expressions.
 */
#define USERFILE_PATH_MAX (64 + USERNAME_MAX)
/* Absolute path of the xm binary that gets exec'd. */
#define XM_PATH "/usr/sbin/xm"
/*
 * Pick the requested command out of the argument vector.
 *
 * Two invocation shapes are accepted:
 *   xmsh <cmd>
 *   xmsh -c <cmd>
 *
 * Returns a pointer into argv (no copy is made). Any other argument shape
 * prints the usage message through errx() and exits with EXIT_FAILURE.
 */
char *parse_command (int argc, char **argv) {
    switch (argc) {
    case 2:
        /* plain form: the single argument is the command */
        return argv[1];
    case 3:
        /* "-c" form: the command follows the flag */
        if (!strcmp(argv[1], "-c"))
            return argv[2];
        break;
    default:
        break;
    }

    /* nothing matched */
    errx(EXIT_FAILURE, "usage: ssh [-t] <dom0> (list|reboot|console)");
}
/*
 * Validate that the given command is one of the permitted xm subcommands.
 *
 * Returns normally when command is exactly "list", "reboot" or "console".
 * Otherwise prints "invalid command: <command>" and exits with EXIT_FAILURE.
 *
 * errx() is used rather than err(): no system call has failed here, so
 * appending strerror(errno) would only print an unrelated, stale error.
 */
void validate_command (const char *command) {
    // whitelist of commands
    static const char *const allowed[] = { "list", "reboot", "console" };

    for (size_t i = 0; i < sizeof allowed / sizeof allowed[0]; i++) {
        if (strcmp(command, allowed[i]) == 0)
            return;
    }

    // else fail
    errx(EXIT_FAILURE, "invalid command: %s", command);
}
/*
 * Validate that the username is sane: non-empty and made up solely of
 * alphabetic characters (per isalpha() in the current locale) and '-'.
 *
 * Returns normally on success. Exits with EXIT_FAILURE via errx() with
 * "username length" for an empty string, or "username non-alpha" at the
 * first character that is not allowed.
 */
void validate_username (const char *c) {
    if (!(*c))
        errx(EXIT_FAILURE, "username length");

    for (; *c; c++) {
        // <ctype.h> functions require a value representable as unsigned char;
        // passing a negative plain char is undefined behavior.
        if (!isalpha((unsigned char) *c) && *c != '-')
            errx(EXIT_FAILURE, "username non-alpha");
    }
}
/*
 * Validate that the VM name is sane: non-empty and made up solely of
 * printable, non-whitespace characters (per isprint()/isspace()).
 *
 * Returns normally on success. Exits with EXIT_FAILURE via errx() with
 * "vmname length" for an empty string, or "vmname non-print/space" at the
 * first character that is not allowed.
 */
void validate_vmname (const char *c) {
    if (!(*c))
        errx(EXIT_FAILURE, "vmname length");

    for (; *c; c++) {
        // <ctype.h> functions require a value representable as unsigned char;
        // passing a negative plain char is undefined behavior.
        unsigned char ch = (unsigned char) *c;

        if (!isprint(ch) || isspace(ch))
            errx(EXIT_FAILURE, "vmname non-print/space");
    }
}
/*
 * Get the real uid's username - i.e. the user who executed this file.
 *
 * buf receives the NUL-terminated name and must hold USERNAME_MAX bytes.
 * Exits with EXIT_FAILURE if the passwd lookup fails, if the name does not
 * fit in buf, or if validate_username() rejects it.
 */
void get_username (char buf[USERNAME_MAX]) {
    uid_t uid;
    struct passwd *passwd;
    size_t len;

    // get the real uid
    uid = getuid();

    // get the passwd entry; getpwuid() leaves errno untouched when the uid
    // simply has no entry, so clear it first to tell the two cases apart
    errno = 0;
    if ((passwd = getpwuid(uid)) == NULL) {
        if (errno != 0)
            err(EXIT_FAILURE, "getpwuid");
        errx(EXIT_FAILURE, "getpwuid: no passwd entry for uid %lu", (unsigned long) uid);
    }

    if (passwd->pw_name == NULL)
        errx(EXIT_FAILURE, "passwd->pw_name");

    // fail too-long usernames
    len = strlen(passwd->pw_name);
    if (len >= USERNAME_MAX)
        errx(EXIT_FAILURE, "strlen(passwd->pw_name) >= USERNAME_MAX");

    // copy the username including its terminator; len + 1 <= USERNAME_MAX
    memcpy(buf, passwd->pw_name, len + 1);

    // sanity-check the username (alphabetic characters and '-' only)
    validate_username(buf);
}
/*
 * Get the virtual machine name for the given user.
 *
 * Reads the first line of the file named by USERFILE_PATH_FMT (formatted
 * with username) into buf, strips the trailing newline and runs
 * validate_vmname() on the result. buf must hold VMNAME_MAX bytes.
 *
 * Exits with EXIT_FAILURE if the path does not fit, the file cannot be
 * opened or read, the file is empty, the first line does not fit in buf,
 * or the name fails validation.
 */
void get_vmname (const char *username, char buf[VMNAME_MAX]) {
    // the path to the userfile
    char path[USERFILE_PATH_MAX], *nl;
    FILE *fh;
    int len;

    // format the userfile path; a negative return is an output error
    len = snprintf(path, sizeof path, USERFILE_PATH_FMT, username);
    if (len < 0 || (size_t) len >= sizeof path)
        errx(EXIT_FAILURE, "USERFILE_PATH_MAX");

    // open the userfile
    if ((fh = fopen(path, "r")) == NULL) {
        if (errno == ENOENT)
            errx(EXIT_FAILURE, "no vm defined for user: %s", username);
        else
            err(EXIT_FAILURE, "fopen: %s", path);
    }

    // read the vmname; NULL means either a read error (errno is meaningful)
    // or immediate EOF (errno is stale, so do not report it)
    if (fgets(buf, VMNAME_MAX, fh) == NULL) {
        if (ferror(fh))
            err(EXIT_FAILURE, "fgets: %s", path);
        errx(EXIT_FAILURE, "fgets: %s: empty file", path);
    }

    // kill the newline
    if ((nl = strchr(buf, '\n')) != NULL) {
        *nl = '\0';
    } else {
        // no newline seen: either the file ended, or the line was cut off at
        // the buffer size. Refuse a cut-off name rather than act on a prefix.
        int next = fgetc(fh);

        if (next != EOF && next != '\n')
            errx(EXIT_FAILURE, "vmname too long: %s", path);
    }

    // close the file so the descriptor is not inherited across a later exec
    fclose(fh);

    // sanity-check the vmname
    validate_vmname(buf);
}
/*
 * Become root and replace this process with "xm <command> <vmname>".
 *
 * The xm binary at XM_PATH is run with an empty environment. This function
 * never returns: on any failure it exits with EXIT_FAILURE via err().
 */
void __attribute__ ((noreturn)) xm_exec (const char *vmname, const char *command) {
    // empty environment for the child; typed as execle() expects its envp
    char *const env[] = { NULL };

    // setuid to root
    if (setuid(0))
        err(EXIT_FAILURE, "setuid: 0");

    // exec; the variadic argument list must end with a null *pointer*, and a
    // bare NULL may be a plain integer 0, hence the explicit cast
    execle(XM_PATH, "xm", command, vmname, (char *) NULL, env);

    // if we're still here, an error has occurred
    err(EXIT_FAILURE, "%s: %s %s", XM_PATH, command, vmname);
}
/*
 * Entry point: work out the requested command, check it against the
 * whitelist, map the invoking user to their VM name, and exec xm on it.
 *
 * Every step exits with EXIT_FAILURE on error, and xm_exec() does not
 * return, so control never reaches the end of this function.
 */
int main (int argc, char **argv) {
    char username[USERNAME_MAX], vmname[VMNAME_MAX], *command;

    // get command
    command = parse_command(argc, argv);

    // reject anything outside the whitelist before it can reach xm, which
    // is executed as root
    validate_command(command);

    // get username
    get_username(username);

    // get vmname
    get_vmname(username, vmname);

    // execute xm
    xm_exec(vmname, command);
}