Skip to content

Instantly share code, notes, and snippets.

@shimarin
Created January 19, 2021 12:53
Show Gist options
  • Save shimarin/43ac99debbdc468b39763298e241a30f to your computer and use it in GitHub Desktop.
Save shimarin/43ac99debbdc468b39763298e241a30f to your computer and use it in GitHub Desktop.
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <pty.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/signalfd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>

#include <filesystem>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include <systemd/sd-bus.h>
// State for one running VM: the child process on its pty, the AF_UNIX socket
// that clients connect to, and the byte queues between the two.
// Descriptor and pid members default to -1 ("none"), so a default-constructed
// entry (e.g. from std::map::operator[]) never refers to a real fd such as 0.
struct Vm {
    std::string name;           // key under which the VM is registered
    pid_t pid = -1;             // child process started on the pty
    int fd = -1;                // pty master of the child
    int listening_socket = -1;  // AF_UNIX socket accepting a peer
    int peer_socket = -1;       // connected peer, or -1 when none
    std::vector<char> inbuf;    // bytes read from the pty, pending for the peer
    std::vector<char> outbuf;   // bytes from the peer / D-Bus, pending for the pty
};
// Registry of VMs started through the D-Bus "Start" method, keyed by VM name.
std::map<std::string, Vm> vms{};
// Directory holding one AF_UNIX listening socket per VM, named after the VM.
std::filesystem::path socket_base_dir{"/run/whitebase"};
// Sets O_NONBLOCK on `fd` while keeping its other file status flags.
// Returns 0 on success, or -1 with errno set by fcntl() (e.g. EBADF).
int make_nonblocking(int fd)
{
    const int flags = fcntl(fd, F_GETFL, 0);
    // A failed F_GETFL returns -1; OR-ing that into F_SETFL would request
    // every status flag bit, so bail out instead.
    if (flags < 0) return -1;
    if (flags & O_NONBLOCK) return 0;  // already non-blocking
    return fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
// Forks a child attached to a new pseudo-terminal and execs `prog` in it.
//
// rows, cols  initial terminal size given to forkpty().
// prog        program to run, searched in PATH by execvp(); also argv[0].
// args        extra arguments, passed as argv[1..].
// TERM        value exported as $TERM in the child.
//
// Returns {child pid, pty master fd}; the caller owns the fd. The fd is marked
// close-on-exec so it does not leak into processes spawned later.
// Throws std::runtime_error if forkpty() fails.
std::pair<pid_t, int> createSubprocessWithPty(int rows, int cols, const char* prog, const std::vector<std::string>& args = {}, const char* TERM = "xterm-256color")
{
    // Build argv before forking so the child does no heap allocation;
    // fork() duplicates the address space, so these pointers stay valid there.
    std::vector<char*> argv;
    argv.reserve(args.size() + 2);
    argv.push_back(const_cast<char*>(prog));
    for (const auto& arg : args) argv.push_back(const_cast<char*>(arg.c_str()));
    argv.push_back(nullptr);

    int fd = -1;
    struct winsize win = { static_cast<unsigned short>(rows), static_cast<unsigned short>(cols), 0, 0 };
    const pid_t pid = forkpty(&fd, NULL, NULL, &win);
    if (pid < 0) throw std::runtime_error("forkpty failed");

    if (pid == 0) {
        // Child: restore default dispositions for every signal.
        struct sigaction sig_action;
        memset(&sig_action, 0, sizeof(sig_action));
        sig_action.sa_handler = SIG_DFL;
        sig_action.sa_flags = 0;
        sigemptyset(&sig_action.sa_mask);
        for (int i = 1; i < NSIG; i++) {
            sigaction(i, &sig_action, NULL);  // SIGKILL/SIGSTOP fail harmlessly
        }
        // The signal mask is inherited across fork() and execve(). The daemon
        // blocks signals for its signalfd, so without this the program would
        // start with e.g. SIGTERM blocked and could not be stopped.
        sigset_t empty_mask;
        sigemptyset(&empty_mask);
        sigprocmask(SIG_SETMASK, &empty_mask, NULL);
        setenv("TERM", TERM, 1);
        execvp(prog, argv.data());
        // Only reached if exec failed. _exit() avoids flushing stdio buffers
        // copied from the parent and running its atexit handlers.
        _exit(127);
    }

    // Parent.
    fcntl(fd, F_SETFD, FD_CLOEXEC);
    return { pid, fd };
}
// Returns true if `name` is usable as a single path component under
// socket_base_dir: non-empty, not "." or "..", and without '/'. The name comes
// from an unprivileged D-Bus caller. Unchecked, "/etc/passwd" or "../x" would
// make method_start delete and replace arbitrary files.
static bool is_valid_vm_name(const std::string& name)
{
    return !name.empty() && name != "." && name != ".." && name.find('/') == std::string::npos;
}

// Creates a non-blocking, close-on-exec AF_UNIX stream socket listening on
// `path`, first removing any stale file left there.
// Returns the socket fd, or a negative errno value on failure.
static int open_listening_socket(const std::filesystem::path& path)
{
    struct sockaddr_un sockaddr;
    memset(&sockaddr, 0, sizeof(sockaddr));
    sockaddr.sun_family = AF_UNIX;
    const std::string& native = path.native();
    // sun_path is a fixed-size array; it must hold the path plus its NUL.
    if (native.size() >= sizeof(sockaddr.sun_path)) return -ENAMETOOLONG;
    memcpy(sockaddr.sun_path, native.c_str(), native.size() + 1);

    std::error_code ec;
    std::filesystem::create_directories(path.parent_path(), ec);
    if (ec) return -ec.value();
    // Stale socket from a previous run; if removal fails, bind() reports it.
    std::filesystem::remove(path, ec);

    const int sock = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
    if (sock < 0) return -errno;
    if (bind(sock, reinterpret_cast<const struct sockaddr*>(&sockaddr), sizeof(sockaddr)) < 0
        || listen(sock, 10) < 0) {
        const int err = errno;
        close(sock);
        return -err;
    }
    return sock;
}

// D-Bus "Start" handler: runs `bash` on a new pty and exposes it through an
// AF_UNIX listening socket at socket_base_dir/<name>.
// Input: s (VM name). Output: u (PID of the spawned process).
// Errors: org.freedesktop.DBus.Error.InvalidArgs for an unusable name,
//         net.poettering.AlreadyRunning if a live process already owns the name,
//         org.freedesktop.DBus.Error.Failed if the socket or process cannot be set up.
// Never throws. It runs inside sd_bus_process(), so an exception would unwind
// through libsystemd's C frames into main(), which then exits the daemon.
static int method_start(sd_bus_message *m, void *userdata, sd_bus_error *ret_error) {
    const char* name_arg;
    /* Read the parameters */
    auto r = sd_bus_message_read(m, "s", &name_arg);
    if (r < 0) {
        fprintf(stderr, "Failed to parse parameters: %s\n", strerror(-r));
        return r;
    }
    const std::string name(name_arg);
    if (!is_valid_vm_name(name)) {
        sd_bus_error_set_const(ret_error, "org.freedesktop.DBus.Error.InvalidArgs", "Invalid VM name.");
        return -EINVAL;
    }

    auto existing = vms.find(name);
    if (existing != vms.end()) {
        if (kill(existing->second.pid, 0) == 0) {
            sd_bus_error_set_const(ret_error, "net.poettering.AlreadyRunning", "VM is already running.");
            return -EINVAL;
        }
        // Stale entry whose process no longer exists: release its descriptors
        // instead of overwriting (and leaking) them.
        Vm& stale = existing->second;
        if (stale.peer_socket >= 0) close(stale.peer_socket);
        close(stale.listening_socket);
        close(stale.fd);
        vms.erase(existing);
    }

    // Set up the socket first so a failure here leaves no orphaned child.
    const std::filesystem::path path = socket_base_dir / name;
    const int sock = open_listening_socket(path);
    if (sock < 0) {
        fprintf(stderr, "Failed to set up socket %s: %s\n", path.c_str(), strerror(-sock));
        sd_bus_error_set_const(ret_error, "org.freedesktop.DBus.Error.Failed", "Unable to create VM socket.");
        return sock;
    }

    std::pair<pid_t, int> vm;
    try {
        vm = createSubprocessWithPty(24, 80, "bash");
    } catch (const std::exception& ex) {
        fprintf(stderr, "Failed to start %s: %s\n", name.c_str(), ex.what());
        close(sock);
        std::error_code ec;
        std::filesystem::remove(path, ec);
        sd_bus_error_set_const(ret_error, "org.freedesktop.DBus.Error.Failed", "Unable to start VM process.");
        return -EIO;
    }
    make_nonblocking(vm.second);
    vms.insert_or_assign(name, Vm { name, vm.first, vm.second, sock });
    std::cout << name << " started. PID=" << vm.first << std::endl;
    /* Reply with the response */
    return sd_bus_reply_method_return(m, "u", vm.first);
}
// D-Bus "Stop" handler: sends SIGTERM to the named VM's process.
// Input: s (VM name). Output: u (PID that was signalled).
// Errors: net.poettering.NotRunning if the name is unknown;
//         net.poettering.Vanished if its process no longer exists, in which
//         case the entry's descriptors and socket file are released.
static int method_stop(sd_bus_message *m, void *userdata, sd_bus_error *ret_error) {
    const char* name;
    /* Read the parameters */
    auto r = sd_bus_message_read(m, "s", &name);
    if (r < 0) {
        fprintf(stderr, "Failed to parse parameters: %s\n", strerror(-r));
        return r;
    }
    auto it = vms.find(name);
    if (it == vms.end()) {
        sd_bus_error_set_const(ret_error, "net.poettering.NotRunning", "VM is not running.");
        return -EINVAL;
    }
    Vm& vm = it->second;
    if (kill(vm.pid, 0) < 0 && errno == ESRCH) {
        // Release everything the entry owns, then stop here. Falling through
        // would look up the erased entry and throw out of this C callback.
        if (vm.peer_socket >= 0) close(vm.peer_socket);
        close(vm.listening_socket);
        close(vm.fd);
        std::error_code ec;
        std::filesystem::remove(socket_base_dir / it->first, ec);
        vms.erase(it);
        sd_bus_error_set_const(ret_error, "net.poettering.Vanished", "VM is vanished.");
        return -ESRCH;
    }
    std::cout << "Stopping " << name << std::endl;
    kill(vm.pid, SIGTERM);
    /* Reply with the response */
    return sd_bus_reply_method_return(m, "u", vm.pid);
}
// D-Bus "Sendstring" handler: queues a string for the named VM's pty by
// appending it to the VM's outbuf, which process_io() writes out.
// Input: ss (VM name, text). Output: b (always true on success).
// Errors: net.poettering.NotRunning if the name is unknown;
//         net.poettering.Vanished if its process no longer exists, in which
//         case the entry's descriptors and socket file are released.
static int method_sendstring(sd_bus_message *m, void *userdata, sd_bus_error *ret_error) {
    /* Read the parameters */
    const char* name;
    const char* str;
    auto r = sd_bus_message_read(m, "ss", &name, &str);
    if (r < 0) {
        fprintf(stderr, "Failed to parse parameters: %s\n", strerror(-r));
        return r;
    }
    auto it = vms.find(name);
    if (it == vms.end()) {
        sd_bus_error_set_const(ret_error, "net.poettering.NotRunning", "VM is not running.");
        return -EINVAL;
    }
    Vm& vm = it->second;
    if (kill(vm.pid, 0) < 0 && errno == ESRCH) {
        // Release the entry and stop. Falling through would look up the erased
        // entry and throw out of this C callback.
        if (vm.peer_socket >= 0) close(vm.peer_socket);
        close(vm.listening_socket);
        close(vm.fd);
        std::error_code ec;
        std::filesystem::remove(socket_base_dir / it->first, ec);
        vms.erase(it);
        sd_bus_error_set_const(ret_error, "net.poettering.Vanished", "VM is vanished.");
        return -ESRCH;
    }
    std::cout << "Sending '" << str << '\'' << std::endl;
    // One bulk append; the old loop re-ran strlen() on every iteration.
    vm.outbuf.insert(vm.outbuf.end(), str, str + strlen(str));
    /* Reply with the response */
    return sd_bus_reply_method_return(m, "b", 1);
}
// Method table installed by main() via sd_bus_add_object_vtable().
// Signatures use D-Bus type codes: s = string, u = uint32, b = boolean.
// Every method carries SD_BUS_VTABLE_UNPRIVILEGED. Per the sd-bus docs that
// lets unprivileged clients call it, so treat all arguments as untrusted.
static const sd_bus_vtable vtable[] = {
SD_BUS_VTABLE_START(0),
// Start(s name) -> u pid
SD_BUS_METHOD("Start", "s", "u", method_start, SD_BUS_VTABLE_UNPRIVILEGED),
// Stop(s name) -> u pid
SD_BUS_METHOD("Stop", "s", "u", method_stop, SD_BUS_VTABLE_UNPRIVILEGED),
// Sendstring(s name, s text) -> b ok
SD_BUS_METHOD("Sendstring", "ss", "b", method_sendstring, SD_BUS_VTABLE_UNPRIVILEGED),
SD_BUS_VTABLE_END
};
// Writes as much of `buf` to `fd` as the descriptor accepts without blocking,
// then erases the written prefix from `buf`.
// Sockets go through send(MSG_NOSIGNAL): a disconnected peer then yields EPIPE
// instead of a SIGPIPE that would terminate the daemon.
// Returns the number of bytes written, or -1 on an error other than EAGAIN/EINTR.
static ssize_t flush_pending_bytes(int fd, std::vector<char>& buf, bool is_socket)
{
    size_t done = 0;
    bool failed = false;
    while (done < buf.size()) {
        const char* data = buf.data() + done;
        const size_t len = buf.size() - done;
        const ssize_t n = is_socket ? send(fd, data, len, MSG_NOSIGNAL) : write(fd, data, len);
        if (n > 0) {
            done += static_cast<size_t>(n);
            continue;
        }
        if (n < 0 && errno == EINTR) continue;
        if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) failed = true;
        break;  // would block, wrote nothing, or hard error
    }
    buf.erase(buf.begin(), buf.begin() + static_cast<std::ptrdiff_t>(done));
    return failed ? -1 : static_cast<ssize_t>(done);
}

// Moves data for one VM without blocking. It does these steps in order:
//  1. queued input (outbuf) -> pty master
//  2. pty master -> inbuf (bytes are kept only while a peer is connected)
//  3. accept one pending connection; a second simultaneous peer is rejected
//  4. inbuf -> peer socket, then peer socket -> outbuf (sent on the next call)
// A peer that reaches EOF or hits a hard socket error is closed and detached.
void process_io(Vm& vm)
{
    // process output to VM; a failed write leaves the bytes queued for the next call
    flush_pending_bytes(vm.fd, vm.outbuf, false);

    // process input from VM
    char buf[4096];
    for (;;) {
        const ssize_t r = read(vm.fd, buf, sizeof(buf));
        if (r < 0 && errno == EINTR) continue;
        if (r <= 0) break;  // EOF, EAGAIN, or an error such as EIO
        if (vm.peer_socket >= 0) { // keep bytes only when a peer is there
            vm.inbuf.insert(vm.inbuf.end(), buf, buf + r);
        }
        std::cout << "read " << r << " bytes from " << vm.name << std::endl;
    }

    // process listening socket; close-on-exec keeps the peer out of spawned children
    const int sock = accept4(vm.listening_socket, NULL, NULL, SOCK_NONBLOCK | SOCK_CLOEXEC);
    if (sock >= 0) {
        if (vm.peer_socket < 0) {
            vm.peer_socket = sock;
            std::cout << "Peer accepted." << std::endl;
        } else {
            const char* msg = "Simultaneous connections are not allowed\n";
            send(sock, msg, strlen(msg), MSG_NOSIGNAL);
            close(sock);
        }
    }

    if (vm.peer_socket < 0) return;

    // process output to peer (was input from VM)
    const size_t pending = vm.inbuf.size();
    const ssize_t written = flush_pending_bytes(vm.peer_socket, vm.inbuf, true);
    if (written < 0) {
        std::cout << "Write to peer failed. Cleaning up." << std::endl;
        close(vm.peer_socket);
        vm.peer_socket = -1;
        return;
    }
    if (pending > 0) std::cout << "wrote " << written << " bytes to peer" << std::endl;

    // process input from peer (will be output to VM)
    for (;;) {
        const ssize_t r = read(vm.peer_socket, buf, sizeof(buf));
        if (r > 0) {
            vm.outbuf.insert(vm.outbuf.end(), buf, buf + r);
            continue;
        }
        if (r < 0 && errno == EINTR) continue;
        if (r < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) break;
        // EOF or a hard error (e.g. ECONNRESET): the peer is gone either way.
        if (r == 0) {
            std::cout << "EOF from peer. Cleaning up." << std::endl;
        } else {
            std::cout << "Read from peer failed. Cleaning up." << std::endl;
        }
        close(vm.peer_socket);
        vm.peer_socket = -1;
        break;
    }
}
int main(int argc, char *argv[])
{
sd_bus_slot *slot = NULL;
sd_bus *bus = NULL;
const char* service_name = "net.poettering.Calculator";
const char* object_path = "/net/poettering/Calculator";
const char* interface_name = service_name;
int status = 0;
try {
/* Connect to the user bus this time */
auto r = sd_bus_open_system(&bus);
if (r < 0) {
throw std::runtime_error(std::string("Failed to connect to system bus: ") + strerror(-r));
}
/* Install the object */
r = sd_bus_add_object_vtable(bus,
&slot,
object_path, /* object path */
interface_name, /* interface name */
vtable,
NULL);
if (r < 0) {
throw std::runtime_error(std::string("Failed to issue method call: ") + strerror(-r));
}
/* Take a well-known service name so that clients can find us */
r = sd_bus_request_name(bus, service_name, 0);
if (r < 0) {
throw std::runtime_error(std::string("Failed to acquire service name: ") + strerror(-r));
}
sigset_t mask;
sigemptyset (&mask);
sigaddset (&mask, SIGINT);
sigaddset (&mask, SIGTERM);
sigaddset (&mask, SIGCHLD);
sigprocmask(SIG_SETMASK, &mask, NULL);
auto sigfd = signalfd (-1, &mask, SFD_NONBLOCK|SFD_CLOEXEC);
std::cout << getpid() << std::endl;
bool exit_flag = false;
for (;;) {
/* Process requests */
auto r = sd_bus_process(bus, NULL);
if (r < 0) {
throw std::runtime_error(std::string("Failed to process bus: ") + strerror(-r));
}
if (r > 0) /* we processed a request, try to process another one, right-away */
continue;
std::vector<std::pair<int,short> > pollfds;
pollfds.push_back({sigfd, POLLIN}); // 0
pollfds.push_back({sd_bus_get_fd(bus), sd_bus_get_events(bus)}); // 1
for (auto i = vms.begin(); i != vms.end(); i++) {
Vm& vm = i->second;
pollfds.push_back({vm.fd, vm.outbuf.size() > 0 ? POLLIN : (POLLIN|POLLOUT)});
pollfds.push_back({vm.listening_socket, POLLIN});
if (vm.peer_socket >= 0) pollfds.push_back({vm.peer_socket, vm.inbuf.size() > 0? POLLIN : (POLLIN|POLLOUT)});
}
struct pollfd c_pollfds[pollfds.size()];
for (int i = 0; i < pollfds.size(); i++) {
c_pollfds[i].fd = pollfds[i].first;
c_pollfds[i].events = pollfds[i].second;
}
if (poll(c_pollfds, pollfds.size(), 1000) == 0) {
std::cout << "poll() timeout" << std::endl;
}
if (c_pollfds[0].revents & POLLIN) { // signal received
struct signalfd_siginfo info;
read(c_pollfds[0].fd, &info, sizeof(info));
std::cout << "Signal received: signo=" << info.ssi_signo << ", code=" << info.ssi_code << ", pid=" << info.ssi_pid << std::endl;
if (info.ssi_signo == SIGTERM || info.ssi_signo == SIGINT) {
for (auto i = vms.cbegin(); i != vms.cend(); i++) {
std::cout << "Shutting down " << i->first << std::endl;
kill(i->second.pid, SIGTERM);
}
exit_flag = true;
}
}
// cleanup exited child processes
pid_t pid;
int status;
while ((pid = waitpid(-1, &status, WNOHANG)) > 0) {
std::cout << "PID " << pid << " exited with status " << status << "." << std::endl;
for (auto i = vms.begin(), next_i = i; i != vms.end(); i = next_i) {
Vm& vm = i->second;
++next_i;
if (vm.pid == pid/*exited process*/ || kill(vm.pid, 0) < 0/*stale process*/) {
// cleanup peer socket
if (vm.peer_socket >= 0) close(vm.peer_socket);
// cleanup listening socket
close(vm.listening_socket);
// remove socket file
std::filesystem::path path = socket_base_dir / i->first;
if (std::filesystem::exists(path)) std::filesystem::remove(path);
vms.erase(i);
std::cout << "VM erased. remain=" << vms.size() << std::endl;
}
}
}
if (exit_flag && vms.size() == 0) break;
// process each vm's i/o
for (auto i = vms.begin(); i != vms.end(); i++) {
process_io(i->second);
}
if (c_pollfds[1].revents != 0) {
std::cout << "Message received" << std::endl;
}
/* Wait for the next request to process */
r = sd_bus_wait(bus, (uint64_t) 0);
if (r < 0) {
throw std::runtime_error(std::string("Failed to wait on bus: %s\n") + strerror(-r));
}
}
}
catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
status = 1;
}
sd_bus_slot_unref(slot);
sd_bus_unref(bus);
return status;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment