dinit-client.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  #include <cerrno>
  #include <cstddef>
  #include <cstdint>
  #include <cstdlib>
  #include <cstring>
  #include <string>
  #include <utility>

  #include <fcntl.h>
  #include <pwd.h>
  #include <sys/socket.h>
  #include <sys/types.h>
  #include <sys/un.h>
  #include <unistd.h>

  #include <mconfig.h>
  #include <cpbuffer.h>
  #include <control-cmds.h>
  8. // Client library for Dinit clients
  9. using handle_t = uint32_t;
  10. using cpbuffer_t = cpbuffer<1024>;
  // Thrown when reading from the control connection fails. 'errcode' carries the errno value
  // of the failure, or 0 when there is no errno to report (e.g. the stream ended).
  class cp_read_exception
  {
      public:
      cp_read_exception(int err) : errcode{err} { }

      int errcode;
  };
  // Thrown when writing to the control connection fails. 'errcode' carries the errno value
  // that was current when the failure was detected.
  class cp_write_exception
  {
      public:
      cp_write_exception(int err) : errcode{err} { }

      int errcode;
  };
  // Thrown by check_protocol_version when the server's minimum compatible protocol version is
  // newer than the highest version this client speaks. Carries no data.
  class cp_old_client_exception
  {
  };
  // Thrown by check_protocol_version when the server's protocol version is older than the
  // minimum version this client can speak. Carries no data.
  class cp_old_server_exception
  {
  };
  // General error, describing a failed action together with the related errno value.
  // The accessors are const-qualified so they can be used on an exception caught by const
  // reference; a non-const get_arg() is retained for existing callers.
  class general_error
  {
      int err;             // related errno (or 0)
      const char *action;  // may be nullptr if err != 0
      std::string arg;     // may be empty, must be empty if action == nullptr

      public:
      // Error described by an errno value only (no action, no argument).
      general_error(int err) : err(err), action(nullptr), arg()
      {
      }

      // Error with an errno value, a description of the action that failed, and optionally
      // the argument (e.g. a path) that the action was applied to.
      general_error(int err, const char *action, std::string arg = {})
          : err(err), action(action), arg(std::move(arg))
      {
      }

      // Related errno value (or 0).
      int get_err() const { return err; }

      // Description of the failed action; may be nullptr.
      const char *get_action() const { return action; }

      // Argument of the failed action; may be empty.
      std::string &get_arg() { return arg; }
      const std::string &get_arg() const { return arg; }
  };
  // static_membuf: a buffer of a fixed size (N bytes) holding the raw bytes of the values
  // that have been appended so far. Don't use this directly, construct via membuf.
  //
  // append() and output() only read the buffer and are const-qualified, so they can be used
  // on const buffers as well as on temporaries.
  template <int N> class static_membuf
  {
      public:
      // Total size of the buffer contents, in bytes.
      static constexpr int size() { return N; }

      private:
      char buf[N];

      public:
      // Create a buffer holding a byte-wise copy of 'val'. sizeof(T) must be exactly N.
      template <typename T>
      static_membuf(const T &val)
      {
          static_assert(sizeof(T) == N, "must initialise with object of correct size");
          memcpy(buf, &val, N);
      }

      // Create a buffer holding the M bytes of 'prevbuf' followed by a byte-wise copy of
      // 'val'. M + sizeof(T) must be exactly N.
      template <int M, typename T>
      static_membuf(const char (&prevbuf)[M], const T &val)
      {
          static_assert(M + sizeof(T) == N, "size is not correct");
          memcpy(buf, prevbuf, M);
          memcpy(buf + M, &val, sizeof(val));
      }

      // Pointer to the N bytes of buffer contents.
      const char *data() const { return buf; }

      // Return a new (larger) buffer holding this buffer's contents followed by the bytes of 'u'.
      template <typename U> static_membuf<N+sizeof(U)> append(const U &u) const
      {
          return static_membuf<N+sizeof(U)>{buf, u};
      }

      // Copy the buffer contents to 'out', which must have room for size() bytes.
      void output(char *out) const
      {
          memcpy(out, buf, size());
      }
  };
  79. // "membuf" class provides a compile-time allocated buffer that we can add items to one-by-one. This is
  80. // much safer than working with raw buffers and calculating offsets and sizes by hand (and with a decent
  81. // compiler the end result is just as efficient).
  82. //
  83. // To use:
  84. // auto m = membuf().append(value1).append(value2).append(value3);
  85. // Then:
  86. // m.size() - returns total size of the buffer (sizeof(value1)+...)
  87. // m.data() - returns a 'const char *' to the buffer contents
  88. class membuf
  89. {
  90. public:
  91. template <typename U> static_membuf<sizeof(U)> append(const U &u)
  92. {
  93. return static_membuf<sizeof(U)>(u);
  94. }
  95. };
  96. // Fill a circular buffer from a file descriptor, until it contains at least _rlength_ bytes.
  97. // Throws cp_read_exception if the requested number of bytes cannot be read, with:
  98. // errcode = 0 if end of stream (remote end closed)
  99. // errcode = errno if another error occurred
  100. // Note that EINTR is ignored (i.e. the read will be re-tried).
  101. inline void fill_buffer_to(cpbuffer_t &buf, int fd, int rlength)
  102. {
  103. do {
  104. int r = buf.fill_to(fd, rlength);
  105. if (r == -1) {
  106. if (errno != EINTR) {
  107. throw cp_read_exception(errno);
  108. }
  109. }
  110. else if (r == 0) {
  111. throw cp_read_exception(0);
  112. }
  113. else {
  114. return;
  115. }
  116. }
  117. while (true);
  118. }
  119. // Fill a circular buffer from a file descriptor, until it contains at least some more data
  120. // then it did. Throws cp_read_exception if no more bytes can be read, with:
  121. // errcode = 0 if end of stream (remote end closed)
  122. // errcode = errno if another error occurred
  123. // Note that EINTR is ignored (i.e. the read will be re-tried).
  124. inline void fill_some(cpbuffer_t &buf, int fd)
  125. {
  126. while(true) {
  127. int r = buf.fill(fd);
  128. if (r == 0) {
  129. throw cp_read_exception(0);
  130. }
  131. else if (r > 0) {
  132. return;
  133. }
  134. if (errno != EINTR) {
  135. throw cp_read_exception(errno);
  136. }
  137. // if EINTR, just try again
  138. }
  139. }
  140. // Wait for a reply packet, skipping over any information packets that are received in the meantime.
  141. // Note the reply packet may be partially read. Caller is responsible for reading full packet.
  142. inline void wait_for_reply(cpbuffer_t &rbuffer, int fd)
  143. {
  144. fill_buffer_to(rbuffer, fd, 1);
  145. while (rbuffer[0] >= 100) {
  146. // Information packet; discard.
  147. fill_buffer_to(rbuffer, fd, 2);
  148. int pktlen = (unsigned char) rbuffer[1];
  149. rbuffer.consume(1); // Consume one byte so we'll read one byte of the next packet
  150. fill_buffer_to(rbuffer, fd, pktlen);
  151. rbuffer.consume(pktlen - 1);
  152. }
  153. }
  154. // Wait for an info packet. If any other reply packet comes, throw a cp_read_exception.
  155. inline void wait_for_info(cpbuffer_t &rbuffer, int fd)
  156. {
  157. fill_buffer_to(rbuffer, fd, 2);
  158. if (rbuffer[0] < 100) {
  159. throw cp_read_exception(0);
  160. }
  161. int pktlen = (unsigned char) rbuffer[1];
  162. fill_buffer_to(rbuffer, fd, pktlen);
  163. }
  // Write *all* of the requested buffer, re-trying as necessary until the whole buffer is
  // written or an unrecoverable error occurs. Writes interrupted by a signal (EINTR) are
  // retried.
  //
  //   fd    - file descriptor to write to
  //   buf   - start of the data to write
  //   count - number of bytes to write. Note: count is int
  //
  // Returns the number of bytes written (equal to count; 0 if count is not positive), or -1
  // if an error occurred, in which case errno is set by the failed write. On error, part of
  // the buffer may already have been written.
  inline int write_all(int fd, const void *buf, int count)
  {
      const char *cbuf = static_cast<const char *>(buf);
      int w = 0;
      while (count > 0) {
          const ssize_t r = write(fd, cbuf, count);
          if (r == -1) {
              if (errno == EINTR) continue;
              // Unrecoverable: must be reported as -1, which is what callers test for.
              return -1;
          }
          w += r;
          cbuf += r;
          count -= r;
      }
      return w;
  }
  183. // Write all the requested buffer, and throw an exception on failure.
  184. // Note: count is int
  185. inline void write_all_x(int fd, const void *buf, int count)
  186. {
  187. if (write_all(fd, buf, count) == -1) {
  188. throw cp_write_exception(errno);
  189. }
  190. }
  // Write the whole contents of a buffer object (anything providing data() and size(), e.g.
  // a membuf-built buffer); throws an exception on failure, as the pointer/length overload does.
  template <typename Buf> inline void write_all_x(int fd, const Buf &b)
  {
      const void *bytes = b.data();
      const auto nbytes = b.size();
      write_all_x(fd, bytes, nbytes);
  }
  196. // Check the protocol version is compatible with the client.
  197. // minversion - minimum protocol version that client can speak
  198. // version - maximum protocol version that client can speak
  199. // rbuffer, fd - communication buffer and socket
  200. // returns: the actual protocol version
  201. // throws an exception on protocol mismatch or error.
  202. inline uint16_t check_protocol_version(int minversion, int version, cpbuffer_t &rbuffer, int fd)
  203. {
  204. constexpr int bufsize = 1;
  205. char buf[bufsize] = { DINIT_CP_QUERYVERSION };
  206. write_all_x(fd, buf, bufsize);
  207. wait_for_reply(rbuffer, fd);
  208. if (rbuffer[0] != DINIT_RP_CPVERSION) {
  209. throw cp_read_exception{0};
  210. }
  211. // DINIT_RP_CVERSION, (2 byte) minimum compatible version, (2 byte) actual version
  212. constexpr int rbufsize = 1 + 2 * sizeof(uint16_t);
  213. fill_buffer_to(rbuffer, fd, rbufsize);
  214. uint16_t rminversion;
  215. uint16_t cpversion;
  216. rbuffer.extract(reinterpret_cast<char *>(&rminversion), 1, sizeof(uint16_t));
  217. rbuffer.extract(reinterpret_cast<char *>(&cpversion), 1 + sizeof(uint16_t), sizeof(uint16_t));
  218. rbuffer.consume(rbufsize);
  219. if (rminversion > version) {
  220. // We are too old
  221. throw cp_old_client_exception();
  222. }
  223. if (cpversion < minversion) {
  224. // Server is too old
  225. throw cp_old_server_exception();
  226. }
  227. return cpversion;
  228. }
  229. // Get the default socket path (i.e. the path to use if no path is explicitly specified).
  230. // 'control_socket_str' *may* be used for storage for the returned path.
  231. // 'user_dinit' should be true if getuid() == 0.
  232. inline const char *get_default_socket_path(std::string &control_socket_str, bool user_dinit)
  233. {
  234. const char *control_socket_path;
  235. const char *sockpath = getenv("DINIT_SOCKET_PATH");
  236. if (sockpath) {
  237. control_socket_str = sockpath;
  238. control_socket_path = control_socket_str.c_str();
  239. }
  240. else if (user_dinit) {
  241. const char * rundir = getenv("XDG_RUNTIME_DIR");
  242. const char * sockname = "dinitctl";
  243. if (rundir == nullptr) {
  244. sockname = ".dinitctl";
  245. rundir = getenv("HOME");
  246. if (rundir == nullptr) {
  247. struct passwd * pwuid_p = getpwuid(getuid());
  248. if (pwuid_p != nullptr) {
  249. rundir = pwuid_p->pw_dir;
  250. }
  251. }
  252. }
  253. if (rundir != nullptr) {
  254. control_socket_str = rundir;
  255. control_socket_str.push_back('/');
  256. control_socket_str += sockname;
  257. control_socket_path = control_socket_str.c_str();
  258. }
  259. else {
  260. return nullptr;
  261. }
  262. }
  263. else {
  264. control_socket_path = SYSCONTROLSOCKET; // default to system
  265. }
  266. return control_socket_path;
  267. }
  268. // Connect to the dinit daemon, return the connected socket fd.
  269. // Throws general_error on error.
  270. inline int connect_to_daemon(const char *control_socket_path)
  271. {
  272. int socknum = socket(AF_UNIX, SOCK_STREAM, 0);
  273. if (socknum == -1) {
  274. throw general_error(errno, "opening socket");
  275. }
  276. struct sockaddr_un * name;
  277. uint sockaddr_size = offsetof(struct sockaddr_un, sun_path) + strlen(control_socket_path) + 1;
  278. name = (struct sockaddr_un *) malloc(sockaddr_size);
  279. if (name == nullptr) {
  280. throw general_error(ENOMEM);
  281. }
  282. name->sun_family = AF_UNIX;
  283. strcpy(name->sun_path, control_socket_path);
  284. int connr = connect(socknum, (struct sockaddr *) name, sockaddr_size);
  285. free(name);
  286. if (connr == -1) {
  287. throw general_error(errno, "connecting to socket", control_socket_path);
  288. }
  289. return socknum;
  290. }
  // Get the file descriptor for the control socket connection as passed to us from the parent
  // process, via the DINIT_CS_FD environment variable (a decimal file descriptor number).
  // The descriptor is switched to blocking mode.
  // Returns -1 if unsuccessful: the variable is unset, does not begin with a number, holds a
  // value that is not a valid 'int' descriptor number, or does not name an open descriptor.
  inline int get_passed_cfd()
  {
      const char *dinit_cs_fd_env = getenv("DINIT_CS_FD");
      if (dinit_cs_fd_env == nullptr) {
          return -1;
      }

      char *endptr;
      const long int cfdnum = strtol(dinit_cs_fd_env, &endptr, 10);
      if (endptr == dinit_cs_fd_env) {
          // no digits were found
          return -1;
      }

      // Reject negative values and values that do not survive conversion to int.
      const int socknum = static_cast<int>(cfdnum);
      if (socknum < 0 || static_cast<long int>(socknum) != cfdnum) {
          return -1;
      }

      // Set blocking mode (and validate file descriptor):
      errno = 0;
      const int sock_flags = fcntl(socknum, F_GETFL, 0);
      if (sock_flags == -1 && errno != 0) {
          // Not an open descriptor. Must report failure as -1 (0 would name stdin).
          return -1;
      }

      fcntl(socknum, F_SETFL, sock_flags & ~O_NONBLOCK);
      return socknum;
  }