index.js 43 KB


  1. // @ts-check
  2. import fs from 'node:fs';
  3. import http from 'node:http';
  4. import path from 'node:path';
  5. import url from 'node:url';
  6. import cors from 'cors';
  7. import dotenv from 'dotenv';
  8. import express from 'express';
  9. import { JSDOM } from 'jsdom';
  10. import { WebSocketServer } from 'ws';
  11. import * as Database from './database.js';
  12. import { AuthenticationError, RequestError, extractStatusAndMessage as extractErrorStatusAndMessage } from './errors.js';
  13. import { logger, httpLogger, initializeLogLevel, attachWebsocketHttpLogger, createWebsocketLogger } from './logging.js';
  14. import { setupMetrics } from './metrics.js';
  15. import * as Redis from './redis.js';
  16. import { isTruthy, normalizeHashtag, firstParam } from './utils.js';
  const environment = process.env.NODE_ENV || 'development';

  // Production reads `.env.production`, every other environment reads `.env`.
  // The file is resolved relative to this module's URL, one directory up:
  const dotenvFile = environment === 'production' ? '.env.production' : '.env';
  const dotenvFilePath = path.resolve(
    url.fileURLToPath(new URL(path.join('..', dotenvFile), import.meta.url))
  );

  dotenv.config({ path: dotenvFilePath });

  initializeLogLevel(process.env, environment);
  29. /**
  30. * Declares the result type for accountFromToken / accountFromRequest.
  31. *
  32. * Note: This is here because jsdoc doesn't like importing types that
  33. * are nested in functions
  34. * @typedef ResolvedAccount
  35. * @property {string} accessTokenId
  36. * @property {string[]} scopes
  37. * @property {string} accountId
  38. * @property {string[]} chosenLanguages
  39. */
  /**
   * Parses a string as JSON without throwing: malformed input is logged and
   * `null` is returned instead. It is used both for messages received from
   * redis and for messages received from a client over a websocket
   * connection, which is why it accepts an optional `req` argument to log
   * against.
   * @param {string} json
   * @param {any?} req
   * @returns {Object.<string, any>|null}
   */
  const parseJSON = (json, req) => {
    let parsed = null;

    try {
      parsed = JSON.parse(json);
    } catch (err) {
      /* FIXME: This logging isn't great, and should probably be done at the
       * call-site of parseJSON, not in the method, but this would require
       * changing the signature of parseJSON to return something akin to a
       * Result type: [Error|null, null|Object<string,any>], and then handling
       * the error scenarios.
       */
      if (!req) {
        logger.error({ err }, `Error parsing message from redis`);
      } else if (req.accountId) {
        req.log.error({ err }, `Error parsing message from user ${req.accountId}`);
      } else {
        req.log.error({ err }, `Error parsing message from ${req.remoteAddress}`);
      }
    }

    return parsed;
  };
  // Channel names that checkScopes lets through without any OAuth scope.
  const PUBLIC_CHANNELS = [
    'public',
    'public:media',
    'public:local',
    'public:local:media',
    'public:remote',
    'public:remote:media',
    'hashtag',
    'hashtag:local',
  ];

  // Every known channel name; used for priming the counters/gauges of the
  // per-channel metrics.
  const CHANNEL_NAMES = [
    'system',
    'user',
    'user:notification',
    'list',
    'direct',
  ].concat(PUBLIC_CHANNELS);
  90. const startServer = async () => {
  91. const pgConfig = Database.configFromEnv(process.env, environment);
  92. const pgPool = Database.getPool(pgConfig, environment, logger);
  93. const metrics = setupMetrics(CHANNEL_NAMES, pgPool);
  94. const redisConfig = Redis.configFromEnv(process.env);
  95. const redisClient = Redis.createClient(redisConfig, logger);
  96. const server = http.createServer();
  97. const wss = new WebSocketServer({ noServer: true });
  /**
   * Prefixes a Redis key or channel name with the configured namespace and a
   * colon. Without a configured namespace the input is returned unchanged.
   * Fixes: https://github.com/redis/ioredis/issues/1910
   * @param {string} keyOrChannel
   * @returns {string}
   */
  function redisNamespaced(keyOrChannel) {
    const { namespace } = redisConfig;

    return namespace ? `${namespace}:${keyOrChannel}` : keyOrChannel;
  }
  /**
   * Removes the redis namespace from a channel name; the inverse of
   * redisNamespaced.
   *
   * The "is a namespace configured" test deliberately matches the one used by
   * redisNamespaced (truthiness): an empty-string namespace adds no prefix
   * when subscribing, so nothing must be stripped here either. A channel that
   * does not carry the expected `<namespace>:` prefix is returned unchanged
   * rather than being truncated.
   * @param {string} channel
   * @returns {string}
   */
  function redisUnnamespaced(channel) {
    const { namespace } = redisConfig;

    if (!namespace) {
      return channel;
    }

    // The prefix is the configured namespace plus the colon that separates it
    // from the channel name:
    const prefix = `${namespace}:`;

    return channel.startsWith(prefix) ? channel.slice(prefix.length) : channel;
  }
  // Add the X-Request-Id header to the WebSocket handshake response:
  wss.on("headers", function onHeaders(headers, req) {
    headers.push(`X-Request-Id: ${req.id}`);
  });

  const app = express();

  // TRUSTED_PROXY_IP may list several proxies separated by commas and/or
  // whitespace; without it, loopback and unique-local addresses are trusted.
  const trustedProxies = process.env.TRUSTED_PROXY_IP
    ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/)
    : 'loopback,uniquelocal';

  app.set('trust proxy', trustedProxies);

  app.use(httpLogger);
  app.use(cors());

  // Handle eventsource & other http requests:
  server.on('request', app);
  // Handle upgrade requests: authenticate the request first, and only hand
  // the socket over to the WebSocket server when that succeeds.
  server.on('upgrade', async function handleUpgrade(request, socket, head) {
    // Setup the HTTP logger, since websocket upgrades don't get the usual http
    // logger. This decorates the `request` object.
    attachWebsocketHttpLogger(request);

    request.log.info("HTTP Upgrade Requested");

    /** @param {Error} err */
    const onSocketError = (err) => {
      // Logged under the `err` key, like every other error log in this file,
      // so the error is reported the same way everywhere.
      request.log.error({ err }, err.message);
    };

    socket.on('error', onSocketError);

    /** @type {ResolvedAccount} */
    let resolvedAccount;

    try {
      resolvedAccount = await accountFromRequest(request);
    } catch (err) {
      // Unfortunately for using the on('upgrade') setup, we need to manually
      // write a HTTP Response to the Socket to close the connection upgrade
      // attempt, so the following code is to handle all of that.
      const {statusCode, errorMessage } = extractErrorStatusAndMessage(err);

      /** @type {Record<string, string | number | import('pino-http').ReqId>} */
      const headers = {
        'Connection': 'close',
        'Content-Type': 'text/plain',
        'Content-Length': 0,
        'X-Request-Id': request.id,
        'X-Error-Message': errorMessage
      };

      // Ensure the socket is closed once we've finished writing to it:
      socket.once('finish', () => {
        socket.destroy();
      });

      // Write the HTTP response manually:
      socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);

      // Finally, log the error:
      request.log.error({
        err,
        res: {
          statusCode,
          headers
        }
      }, errorMessage);

      return;
    }

    // Remove the error handler, wss.handleUpgrade has its own:
    socket.removeListener('error', onSocketError);

    wss.handleUpgrade(request, socket, head, function done(ws) {
      request.log.info("Authenticated request & upgraded to WebSocket connection");

      const wsLogger = createWebsocketLogger(request, resolvedAccount);

      // Start the connection:
      wss.emit('connection', ws, request, wsLogger);
    });
  });
  /**
   * Registered listeners, keyed by (un-namespaced) redis channel name.
   * @type {Object.<string, Array.<function(Object<string, any>): void>>}
   */
  const subs = {};

  const redisSubscribeClient = Redis.createClient(redisConfig, logger);

  // Browsers request a favicon when the metrics are viewed directly; answer
  // it here so the request does not fall through to the API Router, which
  // would error for this endpoint:
  app.get('/favicon.ico', (_req, res) => {
    res.status(404).end();
  });

  app.get('/api/v1/streaming/health', (_req, res) => {
    const headers = { 'Content-Type': 'text/plain', 'Cache-Control': 'private, no-store' };

    res.writeHead(200, headers);
    res.end('OK');
  });

  app.get('/metrics', metrics.requestHandler);
  /**
   * Marks the given channels as subscribed in redis and keeps refreshing that
   * marker until the returned function is called.
   *
   * For each channel the key `subscribed:<channel>` (namespaced) is set to
   * '1' with an expiry of three heartbeat intervals, once immediately and
   * then again every interval (6 minutes).
   *
   * A failed write is logged instead of being left as a floating promise, so
   * a redis error cannot surface as an unhandled rejection.
   * @param {string[]} channels
   * @returns {function(): void} stops the periodic refresh
   */
  const subscriptionHeartbeat = channels => {
    const interval = 6 * 60;

    const tellSubscribed = () => {
      channels.forEach(channel => {
        // Promise.resolve tolerates a client whose `set` does not return a
        // promise; either way a rejection is handled here.
        Promise.resolve(
          redisClient.set(redisNamespaced(`subscribed:${channel}`), '1', 'EX', interval * 3)
        ).catch(err => {
          logger.error({ err }, `Error refreshing subscription heartbeat for ${channel}`);
        });
      });
    };

    tellSubscribed();

    const heartbeat = setInterval(tellSubscribed, interval * 1000);

    return () => {
      clearInterval(heartbeat);
    };
  };
  /**
   * Dispatches a message received on a redis channel to every listener
   * registered for that channel. Messages for channels without listeners, and
   * messages that cannot be parsed as JSON, are dropped.
   * @param {string} channel
   * @param {string} message
   */
  const onRedisMessage = (channel, message) => {
    metrics.redisMessagesReceived.inc();
    logger.debug(`New message on channel ${channel}`);

    const listeners = subs[redisUnnamespaced(channel)];

    if (listeners) {
      const decoded = parseJSON(message, null);

      if (decoded) {
        listeners.forEach(notify => notify(decoded));
      }
    }
  };
  233. redisSubscribeClient.on("message", onRedisMessage);
  /**
   * @callback SubscriptionListener
   * @param {ReturnType<parseJSON>} json of the message
   * @returns void
   */

  /**
   * Registers a listener for a channel. The redis subscription itself is only
   * issued when the channel has no listeners yet; further listeners share it.
   * @param {string} channel
   * @param {SubscriptionListener} callback
   */
  const subscribe = (channel, callback) => {
    logger.debug(`Adding listener for ${channel}`);

    if (!subs[channel]) {
      subs[channel] = [];
    }

    const listeners = subs[channel];

    if (listeners.length === 0) {
      logger.debug(`Subscribe ${channel}`);

      const onSubscribed = (err, count) => {
        if (err) {
          logger.error(`Error subscribing to ${channel}`);
          return;
        }

        // Keep the subscriptions gauge in step with the count redis reports:
        if (typeof count === 'number') {
          metrics.redisSubscriptions.set(count);
        }
      };

      redisSubscribeClient.subscribe(redisNamespaced(channel), onSubscribed);
    }

    listeners.push(callback);
  };
  /**
   * Removes a listener from a channel. When the last listener for the channel
   * is removed, the redis subscription is dropped and the channel's entry is
   * deleted from `subs`.
   * @param {string} channel
   * @param {SubscriptionListener} callback
   */
  const unsubscribe = (channel, callback) => {
    logger.debug(`Removing listener for ${channel}`);

    const current = subs[channel];

    if (!current) {
      return;
    }

    const remaining = current.filter(item => item !== callback);
    subs[channel] = remaining;

    if (remaining.length > 0) {
      return;
    }

    logger.debug(`Unsubscribe ${channel}`);

    const onUnsubscribed = (err, count) => {
      if (err) {
        logger.error(`Error unsubscribing to ${channel}`);
        return;
      }

      if (typeof count === 'number') {
        metrics.redisSubscriptions.set(count);
      }
    };

    // FIXME: https://github.com/redis/ioredis/issues/1910
    redisSubscribeClient.unsubscribe(redisNamespaced(channel), onUnsubscribed);

    delete subs[channel];
  };
  /**
   * Tells whether the request's token holds at least one of the given scopes.
   * @param {http.IncomingMessage & ResolvedAccount} req
   * @param {string[]} necessaryScopes
   * @returns {boolean}
   */
  const isInScope = (req, necessaryScopes) => {
    const grantedScopes = req.scopes;

    return grantedScopes.some(granted => necessaryScopes.includes(granted));
  };
  /**
   * Looks up a non-revoked OAuth access token and resolves the account behind
   * it. The resolved fields are also written onto `req`.
   * Rejects with an AuthenticationError when no matching token row exists.
   * @param {string} token
   * @param {any} req
   * @returns {Promise<ResolvedAccount>}
   */
  const accountFromToken = async (token, req) => {
    const { rows } = await pgPool.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token]);

    if (rows.length === 0) {
      throw new AuthenticationError('Invalid access token');
    }

    const {
      id: accessTokenId,
      scopes: scopeList,
      account_id: accountId,
      chosen_languages: chosenLanguages,
    } = rows[0];

    // The scopes column is a space-separated string. It is split separately
    // for `req` and for the return value so the two arrays are never shared.
    req.accessTokenId = accessTokenId;
    req.scopes = scopeList.split(' ');
    req.accountId = accountId;
    req.chosenLanguages = chosenLanguages;

    return {
      accessTokenId,
      scopes: scopeList.split(' '),
      accountId,
      chosenLanguages,
    };
  };
  /**
   * Resolves the account for a request. The `Authorization` header takes
   * precedence (with a leading `Bearer ` stripped); otherwise the token comes
   * from the `access_token` query parameter or, failing that, from the
   * `Sec-WebSocket-Protocol` header.
   * Rejects with an AuthenticationError when none of them is present.
   * @param {any} req
   * @returns {Promise<ResolvedAccount>}
   */
  const accountFromRequest = async (req) => {
    const { authorization } = req.headers;
    const { query } = url.parse(req.url, true);
    const accessToken = query.access_token || req.headers['sec-websocket-protocol'];

    if (!authorization && !accessToken) {
      throw new AuthenticationError('Missing access token');
    }

    const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;

    return accountFromToken(token, req);
  };
  /**
   * Maps the request path of a streaming endpoint to its channel name. The
   * three `public` endpoints switch to their `:media` variant when the
   * `only_media` query parameter is truthy.
   * @param {any} req
   * @returns {string|undefined} undefined when the path is not a known endpoint
   */
  const channelNameFromPath = req => {
    const { path: requestPath, query } = req;
    const onlyMedia = isTruthy(query.only_media);

    /** @param {string} channel */
    const withMedia = channel => (onlyMedia ? `${channel}:media` : channel);

    const channelsByPath = new Map([
      ['/api/v1/streaming/user', 'user'],
      ['/api/v1/streaming/user/notification', 'user:notification'],
      ['/api/v1/streaming/public', withMedia('public')],
      ['/api/v1/streaming/public/local', withMedia('public:local')],
      ['/api/v1/streaming/public/remote', withMedia('public:remote')],
      ['/api/v1/streaming/hashtag', 'hashtag'],
      ['/api/v1/streaming/hashtag/local', 'hashtag:local'],
      ['/api/v1/streaming/direct', 'direct'],
      ['/api/v1/streaming/list', 'list'],
    ]);

    return channelsByPath.get(requestPath);
  };
  /**
   * Verifies that the request's token may access the given channel; rejects
   * with an AuthenticationError when it may not.
   * @param {http.IncomingMessage & ResolvedAccount} req
   * @param {import('pino').Logger} logger
   * @param {string|undefined} channelName
   * @returns {Promise.<void>}
   */
  const checkScopes = async (req, logger, channelName) => {
    logger.debug(`Checking OAuth scopes for ${channelName}`);

    // When accessing public channels, no scopes are needed
    if (channelName && PUBLIC_CHANNELS.includes(channelName)) {
      return;
    }

    // The `read` scope has the highest priority: a token that has it can
    // access all streams. Besides that, the notifications stream accepts
    // read:notifications, while every other stream accepts read:statuses.
    // Mind that the user stream will not contain notifications unless the
    // token has either read or read:notifications scope as well, this is
    // handled separately.
    const requiredScopes = [
      'read',
      channelName === 'user:notification' ? 'read:notifications' : 'read:statuses',
    ];

    const isAllowed = req.scopes && requiredScopes.some(requiredScope => req.scopes.includes(requiredScope));

    if (!isAllowed) {
      throw new AuthenticationError('Access token does not have the required scopes');
    }
  };
  /**
   * @typedef SystemMessageHandlers
   * @property {function(): void} onKill
   */

  /**
   * Builds the listener for a connection's system channels. A `kill` event
   * invokes `eventHandlers.onKill`; a `filters_changed` event clears the
   * filters cached on the request. Messages without an event are ignored.
   * @param {any} req
   * @param {SystemMessageHandlers} eventHandlers
   * @returns {SubscriptionListener}
   */
  const createSystemMessageListener = (req, eventHandlers) => message => {
    if (!message?.event) {
      return;
    }

    const { event } = message;

    req.log.debug(`System message for ${req.accountId}: ${event}`);

    switch (event) {
    case 'kill':
      req.log.debug(`Closing connection for ${req.accountId} due to expired access token`);
      eventHandlers.onKill();
      break;
    case 'filters_changed':
      req.log.debug(`Invalidating filters cache for ${req.accountId}`);
      req.cachedFilters = null;
      break;
    default:
      break;
    }
  };
  /**
   * Subscribes an eventsource connection to its two system channels (one per
   * access token, one per account) and removes both subscriptions again when
   * the response closes. A `kill` message ends the response.
   * @param {http.IncomingMessage & ResolvedAccount} req
   * @param {http.OutgoingMessage} res
   */
  const subscribeHttpToSystemChannel = (req, res) => {
    const systemChannelIds = [
      `timeline:access_token:${req.accessTokenId}`,
      `timeline:system:${req.accountId}`,
    ];

    const listener = createSystemMessageListener(req, {
      onKill() {
        res.end();
      },
    });

    const systemChannelGauge = () => metrics.connectedChannels.labels({ type: 'eventsource', channel: 'system' });

    res.on('close', () => {
      systemChannelIds.forEach(channelId => unsubscribe(channelId, listener));
      systemChannelGauge().dec(2);
    });

    systemChannelIds.forEach(channelId => subscribe(channelId, listener));
    systemChannelGauge().inc(2);
  };
  /**
   * Express middleware that authenticates a streaming request: it resolves
   * the account, checks the token's scopes for the requested channel and
   * subscribes the connection to its system channels before passing control
   * on. OPTIONS requests pass through untouched; any failure is forwarded to
   * `next` as an error.
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const authenticationMiddleware = (req, res, next) => {
    if (req.method === 'OPTIONS') {
      next();
      return;
    }

    const channelName = channelNameFromPath(req);

    // If no channelName can be found for the request, then we should terminate
    // the connection, as there's nothing to stream back
    if (!channelName) {
      next(new RequestError('Unknown channel requested'));
      return;
    }

    const authenticate = async () => {
      await accountFromRequest(req);
      await checkScopes(req, req.log, channelName);
      subscribeHttpToSystemChannel(req, res);
    };

    authenticate().then(() => {
      next();
    }).catch(err => {
      next(err);
    });
  };
  /**
   * Express error handler: logs the error and answers with a JSON body of the
   * form `{ error: message }`. When the response headers were already sent,
   * the error is handed on to the next handler instead.
   * @param {Error} err
   * @param {any} req
   * @param {any} res
   * @param {function(Error=): void} next
   */
  const errorMiddleware = (err, req, res, next) => {
    req.log.error({ err }, err.toString());

    if (res.headersSent) {
      next(err);
    } else {
      const { statusCode, errorMessage } = extractErrorStatusAndMessage(err);
      const body = JSON.stringify({ error: errorMessage });

      res.writeHead(statusCode, { 'Content-Type': 'application/json' });
      res.end(body);
    }
  };
  /**
   * Builds a comma-separated list of positional SQL placeholders, one per
   * array element, numbered from `$1` plus the optional shift.
   * @param {any[]} arr
   * @param {number=} shift
   * @returns {string}
   */
  // @ts-ignore
  const placeholders = (arr, shift = 0) => {
    const toPlaceholder = (_, index) => `$${index + 1 + shift}`;

    return arr.map(toPlaceholder).join(', ');
  };
  /**
   * Confirms that the list exists and belongs to the account on the request;
   * rejects with an AuthenticationError otherwise.
   * @param {string} listId
   * @param {any} req
   * @returns {Promise.<void>}
   */
  const authorizeListAccess = async (listId, req) => {
    const { accountId } = req;
    const { rows } = await pgPool.query('SELECT id, account_id FROM lists WHERE id = $1 AND account_id = $2 LIMIT 1', [listId, accountId]);

    const ownsList = rows.length !== 0;

    if (!ownsList) {
      throw new AuthenticationError('List not found');
    }
  };
  /**
   * Subscribes to the given redis channels and forwards every message
   * received on them to `output`.
   *
   * When `needsFiltering` is set, `update` and `status.update` events are
   * filtered before being forwarded: by the account's chosen languages, by
   * blocks, mutes and domain blocks (looked up in the database), and finally
   * the account's custom keyword filters are applied, whose matches are
   * attached to the payload as its `filtered` property.
   *
   * The database connection used for those lookups is released exactly once
   * per message, whether the lookups succeed or fail and whether or not the
   * subsequent filtering throws.
   * @param {string[]} channelIds
   * @param {http.IncomingMessage & ResolvedAccount} req
   * @param {import('pino').Logger} log
   * @param {function(string, string): void} output
   * @param {undefined | function(string[], SubscriptionListener): void} attachCloseHandler
   * @param {'websocket' | 'eventsource'} destinationType
   * @param {boolean=} needsFiltering
   * @returns {SubscriptionListener} the listener that was subscribed
   */
  const streamFrom = (channelIds, req, log, output, attachCloseHandler, destinationType, needsFiltering = false) => {
    log.info({ channelIds }, `Starting stream`);

    /**
     * @param {string} event
     * @param {object|string} payload
     */
    const transmit = (event, payload) => {
      // TODO: Replace "string"-based delete payloads with object payloads:
      const encodedPayload = typeof payload === 'object' ? JSON.stringify(payload) : payload;

      metrics.messagesSent.labels({ type: destinationType }).inc(1);

      log.debug({ event, payload }, `Transmitting ${event} to ${req.accountId}`);
      output(event, encodedPayload);
    };

    // The listener used to process each message off the redis subscription,
    // message here is an object with an `event` and `payload` property. Some
    // events also include a queued_at value, but this is being removed shortly.

    /** @type {SubscriptionListener} */
    const listener = message => {
      if (!message?.event || !message?.payload) {
        return;
      }

      const { event, payload } = message;

      // Streaming only needs to apply filtering to some channels and only to
      // some events. This is because majority of the filtering happens on the
      // Ruby on Rails side when producing the event for streaming.
      //
      // The only events that require filtering from the streaming server are
      // `update` and `status.update`, all other events are transmitted to the
      // client as soon as they're received (pass-through).
      //
      // The channels that need filtering are determined in the function
      // `channelNameToIds` defined below:
      if (!needsFiltering || (event !== 'update' && event !== 'status.update')) {
        transmit(event, payload);
        return;
      }

      // The rest of the logic from here on in this function is to handle
      // filtering of statuses:

      // Filter based on language:
      if (Array.isArray(req.chosenLanguages) && payload.language !== null && !req.chosenLanguages.includes(payload.language)) {
        log.debug(`Message ${payload.id} filtered by language (${payload.language})`);
        return;
      }

      // When the account is not logged in, it is not necessary to confirm the block or mute
      if (!req.accountId) {
        transmit(event, payload);
        return;
      }

      // Filter based on domain blocks, blocks, mutes, or custom filters:
      // @ts-ignore
      const targetAccountIds = [payload.account.id].concat(payload.mentions.map(item => item.id));
      const accountDomain = payload.account.acct.split('@')[1];

      // TODO: Move this logic out of the message handling loop
      pgPool.connect((err, client, releasePgConnection) => {
        if (err) {
          log.error(err);
          return;
        }

        const queries = [
          // @ts-ignore
          client.query(`SELECT 1
                        FROM blocks
                        WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
                           OR (account_id = $2 AND target_account_id = $1)
                        UNION
                        SELECT 1
                        FROM mutes
                        WHERE account_id = $1
                          AND target_account_id IN (${placeholders(targetAccountIds, 2)})`, [req.accountId, payload.account.id].concat(targetAccountIds)),
        ];

        if (accountDomain) {
          // @ts-ignore
          queries.push(client.query('SELECT 1 FROM account_domain_blocks WHERE account_id = $1 AND domain = $2', [req.accountId, accountDomain]));
        }

        // Position of the custom filters query within `queries`, or -1 when
        // it was not issued. Recording it here (instead of recomputing it
        // from `accountDomain` later) keeps the result lookup correct even if
        // `req.cachedFilters` is invalidated while the queries are in flight.
        let filtersQueryIndex = -1;

        // @ts-ignore
        if (!payload.filtered && !req.cachedFilters) {
          // @ts-ignore
          filtersQueryIndex = queries.push(client.query('SELECT filter.id AS id, filter.phrase AS title, filter.context AS context, filter.expires_at AS expires_at, filter.action AS filter_action, keyword.keyword AS keyword, keyword.whole_word AS whole_word FROM custom_filter_keywords keyword JOIN custom_filters filter ON keyword.custom_filter_id = filter.id WHERE filter.account_id = $1 AND (filter.expires_at IS NULL OR filter.expires_at > NOW())', [req.accountId])) - 1;
        }

        // Release the connection as soon as the queries have settled, and do
        // so exactly once: releasing in both the success and the error
        // handler would release it twice whenever the filtering below throws.
        Promise.all(queries).finally(() => {
          releasePgConnection();
        }).then(values => {
          // Handling blocks & mutes and domain blocks: If one of those applies,
          // then we don't transmit the payload of the event to the client
          if (values[0].rows.length > 0 || (accountDomain && values[1].rows.length > 0)) {
            return;
          }

          // If the payload already contains the `filtered` property, it means
          // that filtering has been applied on the ruby on rails side, as
          // such, we don't need to construct or apply the filters in streaming:
          if (Object.hasOwn(payload, "filtered")) {
            transmit(event, payload);
            return;
          }

          // Handling for constructing the custom filters and caching them on the request
          // TODO: Move this logic out of the message handling lifecycle
          // @ts-ignore
          if (!req.cachedFilters && filtersQueryIndex !== -1) {
            const filterRows = values[filtersQueryIndex].rows;

            // @ts-ignore
            req.cachedFilters = filterRows.reduce((cache, filter) => {
              if (cache[filter.id]) {
                cache[filter.id].keywords.push([filter.keyword, filter.whole_word]);
              } else {
                cache[filter.id] = {
                  keywords: [[filter.keyword, filter.whole_word]],
                  expires_at: filter.expires_at,
                  filter: {
                    id: filter.id,
                    title: filter.title,
                    context: filter.context,
                    expires_at: filter.expires_at,
                    // filter.filter_action is the value from the
                    // custom_filters.action database column, it is an integer
                    // representing a value in an enum defined by Ruby on Rails:
                    //
                    // enum { warn: 0, hide: 1 }
                    filter_action: ['warn', 'hide'][filter.filter_action],
                  },
                };
              }

              return cache;
            }, {});

            // Construct the regular expressions for the custom filters: This
            // needs to be done in a separate loop as the database returns one
            // filterRow per keyword, so we need all the keywords before
            // constructing the regular expression
            // @ts-ignore
            Object.keys(req.cachedFilters).forEach((key) => {
              // @ts-ignore
              req.cachedFilters[key].regexp = new RegExp(req.cachedFilters[key].keywords.map(([keyword, whole_word]) => {
                let expr = keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');

                if (whole_word) {
                  if (/^[\w]/.test(expr)) {
                    expr = `\\b${expr}`;
                  }

                  if (/[\w]$/.test(expr)) {
                    expr = `${expr}\\b`;
                  }
                }

                return expr;
              }).join('|'), 'i');
            });
          }

          // Apply cachedFilters against the payload, constructing a
          // `filter_results` array of FilterResult entities
          // @ts-ignore
          if (req.cachedFilters) {
            const status = payload;
            // TODO: Calculate searchableContent in Ruby on Rails:
            // @ts-ignore
            const searchableContent = ([status.spoiler_text || '', status.content].concat((status.poll && status.poll.options) ? status.poll.options.map(option => option.title) : [])).concat(status.media_attachments.map(att => att.description)).join('\n\n').replace(/<br\s*\/?>/g, '\n').replace(/<\/p><p>/g, '\n\n');
            const searchableTextContent = JSDOM.fragment(searchableContent).textContent;

            const now = new Date();
            // @ts-ignore
            const filter_results = Object.values(req.cachedFilters).reduce((results, cachedFilter) => {
              // Check the filter hasn't expired before applying:
              if (cachedFilter.expires_at !== null && cachedFilter.expires_at < now) {
                return results;
              }

              // Just in-case JSDOM fails to find textContent in searchableContent
              if (!searchableTextContent) {
                return results;
              }

              const keyword_matches = searchableTextContent.match(cachedFilter.regexp);
              if (keyword_matches) {
                // results is an Array of FilterResult; status_matches is always
                // null as we are only applying the keyword-based custom
                // filters, not the status-based custom filters.
                // https://docs.joinmastodon.org/entities/FilterResult/
                results.push({
                  filter: cachedFilter.filter,
                  keyword_matches,
                  status_matches: null
                });
              }

              return results;
            }, []);

            // Send the payload + the FilterResults as the `filtered` property
            // to the streaming connection. To reach this code, the `event` must
            // have been either `update` or `status.update`, meaning the
            // `payload` is a Status entity, which has a `filtered` property:
            //
            // filtered: https://docs.joinmastodon.org/entities/Status/#filtered
            transmit(event, {
              ...payload,
              filtered: filter_results
            });
          } else {
            transmit(event, payload);
          }
        }).catch(err => {
          log.error(err);
        });
      });
    };

    channelIds.forEach(id => {
      subscribe(id, listener);
    });

    if (typeof attachCloseHandler === 'function') {
      attachCloseHandler(channelIds, listener);
    }

    return listener;
  };
  /**
   * Starts an eventsource response: updates the connection metrics, writes
   * the stream headers and an initial comment line, and sends a `:thump`
   * comment every 15 seconds until the request closes.
   * @param {any} req
   * @param {any} res
   * @returns {function(string, string): void} writes one event to the response
   */
  const streamToHttp = (req, res) => {
    const channelName = channelNameFromPath(req);

    const clientsGauge = () => metrics.connectedClients.labels({ type: 'eventsource' });
    const channelGauge = () => metrics.connectedChannels.labels({ type: 'eventsource', channel: channelName });

    // In theory we'll always have a channel name, but channelNameFromPath can
    // return undefined:
    const hasChannelName = typeof channelName === 'string';

    clientsGauge().inc();

    if (hasChannelName) {
      channelGauge().inc();
    }

    const streamHeaders = [
      ['Content-Type', 'text/event-stream'],
      ['Cache-Control', 'private, no-store'],
      ['Transfer-Encoding', 'chunked'],
    ];

    streamHeaders.forEach(([headerName, headerValue]) => {
      res.setHeader(headerName, headerValue);
    });

    res.write(':)\n');

    const heartbeat = setInterval(() => res.write(':thump\n'), 15000);

    req.on('close', () => {
      req.log.info({ accountId: req.accountId }, `Ending stream`);

      // We decrement these counters here instead of in streamHttpEnd as in that
      // method we don't have knowledge of the channel names
      clientsGauge().dec();

      if (hasChannelName) {
        channelGauge().dec();
      }

      clearInterval(heartbeat);
    });

    return (event, payload) => {
      res.write(`event: ${event}\n`);
      res.write(`data: ${payload}\n\n`);
    };
  };
  /**
   * Creates the close handler attachment for an eventsource stream: once the
   * request closes, the listener is unsubscribed from every channel id and
   * the optional `closeHandler` is invoked afterwards.
   * @param {any} req
   * @param {function(): void} [closeHandler]
   * @returns {function(string[], SubscriptionListener): void}
   */
  const streamHttpEnd = (req, closeHandler = undefined) => (ids, listener) => {
    const onRequestClosed = () => {
      ids.forEach(id => {
        unsubscribe(id, listener);
      });

      if (closeHandler) {
        closeHandler();
      }
    };

    req.on('close', () => {
      onRequestClosed();
    });
  };
  /**
   * Creates the output function for a websocket stream. Each event is sent as
   * a JSON message of the form `{ stream, event, payload }`; writes to a
   * socket that is no longer open are logged and dropped.
   * @param {http.IncomingMessage} req
   * @param {import('ws').WebSocket} ws
   * @param {string[]} streamName
   * @returns {function(string, string): void}
   */
  const streamToWs = (req, ws, streamName) => (event, payload) => {
    const isOpen = ws.readyState === ws.OPEN;

    if (!isOpen) {
      req.log.error('Tried writing to closed socket');
      return;
    }

    /** @param {Error|undefined} err */
    const onSent = (err) => {
      if (err) {
        req.log.error({err}, `Failed to send to websocket`);
      }
    };

    ws.send(JSON.stringify({ stream: streamName, event, payload }), onSent);
  };
  /**
   * Answers the request with a 404 status and a JSON error body.
   * @param {http.ServerResponse} res
   */
  const httpNotFound = res => {
    const body = JSON.stringify({ error: 'Not found' });

    res.writeHead(404, { 'Content-Type': 'application/json' });
    res.end(body);
  };
  // The API router: every request is authenticated first, and failures from
  // the authentication step are answered by errorMiddleware.
  const api = express.Router();

  app.use(api);

  api.use(authenticationMiddleware);
  api.use(errorMiddleware);

  // Eventsource endpoint: resolves the channel for the request path and
  // streams its events over the HTTP response.
  api.get('/api/v1/streaming/*', (req, res) => {
    const channelName = channelNameFromPath(req);

    // FIXME: In theory we'd never actually reach here due to
    // authenticationMiddleware catching this case, however, we need to refactor
    // how those middlewares work, so I'm adding the extra check in here.
    if (!channelName) {
      httpNotFound(res);
      return;
    }

    channelNameToIds(req, channelName, req.query).then(({ channelIds, options }) => {
      const onSend = streamToHttp(req, res);
      const onEnd = streamHttpEnd(req, subscriptionHeartbeat(channelIds));

      // @ts-ignore
      streamFrom(channelIds, req, req.log, onSend, onEnd, 'eventsource', options.needsFiltering);
    }).catch(err => {
      const {statusCode, errorMessage } = extractErrorStatusAndMessage(err);

      res.log.info({ err }, 'Eventsource subscription error');

      // streamToHttp starts writing the response, so a failure after that
      // point arrives here with the headers already sent. Writing a status
      // line then would throw inside this handler; as errorMiddleware does,
      // check first and just end the stream in that case.
      if (res.headersSent) {
        res.end();
        return;
      }

      res.writeHead(statusCode, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: errorMessage }));
    });
  });
  805. /**
  806. * @typedef StreamParams
  807. * @property {string} [tag]
  808. * @property {string} [list]
  809. * @property {string} [only_media]
  810. */
  /**
   * Lists the channel ids backing the `user` stream: always the account's own
   * timeline, plus its notifications channel when the request is in the
   * `read` or `read:notifications` scope.
   * @param {any} req
   * @returns {string[]}
   */
  const channelsForUserStream = req => {
    const ownTimeline = `timeline:${req.accountId}`;

    if (!isInScope(req, ['read', 'read:notifications'])) {
      return [ownTimeline];
    }

    return [ownTimeline, `timeline:${req.accountId}:notifications`];
  };
  /**
   * Translates a public stream name (plus its parameters) into the channel ids
   * to subscribe to and a flag telling whether events need filtering.
   *
   * The returned promise rejects with a RequestError for an unknown stream
   * name or a missing `tag` / `list` parameter, and with an
   * AuthenticationError when access to the requested list is not authorized.
   * @param {any} req
   * @param {string} name
   * @param {StreamParams} params
   * @returns {Promise.<{ channelIds: string[], options: { needsFiltering: boolean } }>}
   */
  const channelNameToIds = async (req, name, params) => {
    /**
     * @param {string[]} channelIds
     * @param {boolean} needsFiltering
     */
    const describe = (channelIds, needsFiltering) => ({
      channelIds,
      options: { needsFiltering },
    });

    // Every public timeline maps onto a channel id of the same name.
    const publicTimelines = new Set([
      'public',
      'public:local',
      'public:remote',
      'public:media',
      'public:local:media',
      'public:remote:media',
    ]);

    if (name === 'user') {
      return describe(channelsForUserStream(req), false);
    }

    if (name === 'user:notification') {
      return describe([`timeline:${req.accountId}:notifications`], false);
    }

    if (publicTimelines.has(name)) {
      return describe([`timeline:${name}`], true);
    }

    if (name === 'direct') {
      return describe([`timeline:direct:${req.accountId}`], false);
    }

    if (name === 'hashtag' || name === 'hashtag:local') {
      if (!params.tag) {
        throw new RequestError('Missing tag name parameter');
      }

      const suffix = name === 'hashtag:local' ? ':local' : '';
      return describe([`timeline:hashtag:${normalizeHashtag(params.tag)}${suffix}`], true);
    }

    if (name === 'list') {
      if (!params.list) {
        throw new RequestError('Missing list name parameter');
      }

      // Start the check outside the try block so that only a rejected
      // authorization, not a failure to start it, is reported as an
      // authentication problem.
      const authorization = authorizeListAccess(params.list, req);

      try {
        await authorization;
      } catch {
        throw new AuthenticationError('Not authorized to stream this list');
      }

      return describe([`timeline:list:${params.list}`], false);
    }

    throw new RequestError('Unknown stream type');
  };
  /**
   * Builds the stream name reported back to websocket clients: the channel
   * name alone, or followed by the list id / tag for the `list`, `hashtag`
   * and `hashtag:local` channels when that parameter is present.
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {string[]}
   */
  const streamNameFromChannelName = (channelName, params) => {
    switch (channelName) {
    case 'list':
      return params.list ? [channelName, params.list] : [channelName];
    case 'hashtag':
    case 'hashtag:local':
      return params.tag ? [channelName, params.tag] : [channelName];
    default:
      return [channelName];
    }
  };
  936. /**
  937. * @typedef WebSocketSession
  938. * @property {import('ws').WebSocket & { isAlive: boolean}} websocket
  939. * @property {http.IncomingMessage & ResolvedAccount} request
  940. * @property {import('pino').Logger} logger
  941. * @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
  942. */
  /**
   * Subscribes a websocket session to a channel. After the scope check and
   * the channel lookup succeed, the stream is started and recorded in the
   * session's subscriptions under the channel ids joined with `;`. A second
   * request for the same ids is ignored. Any failure is logged and, if the
   * socket is still alive and open, reported to the client as JSON.
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {void}
   */
  const subscribeWebsocketToChannel = ({ websocket, request, logger, subscriptions }, channelName, params) => {
    const startStreaming = ({ channelIds, options }) => {
      const subscriptionKey = channelIds.join(';');

      // Already subscribed to exactly this set of channels: nothing to do.
      if (subscriptions[subscriptionKey]) {
        return;
      }

      const onSend = streamToWs(request, websocket, streamNameFromChannelName(channelName, params));
      const stopHeartbeat = subscriptionHeartbeat(channelIds);
      const listener = streamFrom(channelIds, request, logger, onSend, undefined, 'websocket', options.needsFiltering);

      metrics.connectedChannels.labels({ type: 'websocket', channel: channelName }).inc();

      subscriptions[subscriptionKey] = { channelName, listener, stopHeartbeat };
    };

    const reportFailure = err => {
      const { statusCode, errorMessage } = extractErrorStatusAndMessage(err);

      logger.error({ err }, 'Websocket subscription error');

      // Only answer the client when the socket can still carry the message.
      if (websocket.isAlive && websocket.readyState === websocket.OPEN) {
        websocket.send(JSON.stringify({
          error: errorMessage,
          status: statusCode
        }));
      }
    };

    checkScopes(request, logger, channelName)
      .then(() => channelNameToIds(request, channelName, params))
      .then(startStreaming)
      .catch(reportFailure);
  };
  /**
   * Ends the session's subscription stored under the given channel ids
   * (joined with `;`): the listener is unsubscribed from every id, the
   * connected-channels metric is decremented, the heartbeat is stopped and
   * the entry is dropped. The attempt is logged even when no such
   * subscription exists, in which case nothing else happens.
   * @param {WebSocketSession} session
   * @param {string[]} channelIds
   */
  const removeSubscription = ({ request, logger, subscriptions }, channelIds) => {
    logger.info({ channelIds, accountId: request.accountId }, `Ending stream`);

    const subscriptionKey = channelIds.join(';');
    const subscription = subscriptions[subscriptionKey];

    if (subscription) {
      for (const channelId of channelIds) {
        unsubscribe(channelId, subscription.listener);
      }

      metrics.connectedChannels.labels({ type: 'websocket', channel: subscription.channelName }).dec();
      subscription.stopHeartbeat();

      delete subscriptions[subscriptionKey];
    }
  };
  /**
   * Resolves the channel ids for the named channel and removes the matching
   * subscription from the session. When that fails, the error is logged and,
   * if the socket is still alive and open, a generic error is sent to the
   * client.
   * @param {WebSocketSession} session
   * @param {string} channelName
   * @param {StreamParams} params
   * @returns {void}
   */
  const unsubscribeWebsocketFromChannel = (session, channelName, params) => {
    const { websocket, request, logger } = session;

    const dropSubscription = ({ channelIds }) => {
      removeSubscription(session, channelIds);
    };

    const reportFailure = err => {
      logger.error({err}, 'Websocket unsubscribe error');

      // Only answer the client when the socket can still carry the message.
      const canReply = websocket.isAlive && websocket.readyState === websocket.OPEN;

      if (canReply) {
        // TODO: Use a better error response here
        websocket.send(JSON.stringify({ error: "Error unsubscribing from channel" }));
      }
    };

    channelNameToIds(request, channelName, params).then(dropSubscription).catch(reportFailure);
  };
  /**
   * Subscribes the session to its two system channels, one keyed by access
   * token id and one by account id, using a single system message listener
   * that closes the websocket when its `onKill` hook fires. Both channels are
   * recorded in the session's subscriptions with a no-op heartbeat stopper,
   * and the connected-channels metric is incremented by two.
   * @param {WebSocketSession} session
   */
  const subscribeWebsocketToSystemChannel = ({ websocket, request, subscriptions }) => {
    const systemChannelIds = [
      `timeline:access_token:${request.accessTokenId}`,
      `timeline:system:${request.accountId}`,
    ];

    const listener = createSystemMessageListener(request, {
      onKill: () => {
        websocket.close();
      },
    });

    for (const channelId of systemChannelIds) {
      subscribe(channelId, listener);
    }

    for (const channelId of systemChannelIds) {
      subscriptions[channelId] = {
        channelName: 'system',
        listener,
        // System channels have no heartbeat, so there is nothing to stop.
        stopHeartbeat: () => {
        },
      };
    }

    metrics.connectedChannels.labels({ type: 'websocket', channel: 'system' }).inc(2);
  };
  /**
   * Wires up a freshly accepted websocket connection: counts it in the
   * metrics, initialises the keep-alive flag, creates the session, installs
   * the close / error / message handlers, subscribes to the system channels
   * and, when the connection URL carries a `stream` query parameter,
   * subscribes to that stream straight away.
   * @param {import('ws').WebSocket & { isAlive: boolean }} ws
   * @param {http.IncomingMessage & ResolvedAccount} req
   * @param {import('pino').Logger} log
   */
  function onConnection(ws, req, log) {
    // Note: url.parse (used at the end of this function) could throw, which
    // would terminate the connection, so the connected clients metric is
    // incremented straight away, without waiting:
    metrics.connectedClients.labels({ type: 'websocket' }).inc();

    // Keep-alive state: every pong marks the socket as alive again.
    const markAlive = () => {
      ws.isAlive = true;
    };

    ws.isAlive = true;
    ws.on('pong', markAlive);

    /**
     * @type {WebSocketSession}
     */
    const session = {
      websocket: ws,
      request: req,
      logger: log,
      subscriptions: {},
    };

    const handleMessage = (data, isBinary) => {
      if (isBinary) {
        log.warn('Received binary data, closing connection');
        ws.close(1003, 'The mastodon streaming server does not support binary messages');
        return;
      }

      const json = parseJSON(data.toString('utf8'), session.request);

      if (!json) {
        return;
      }

      const { type, stream, ...params } = json;

      switch (type) {
      case 'subscribe':
        subscribeWebsocketToChannel(session, firstParam(stream), params);
        break;
      case 'unsubscribe':
        unsubscribeWebsocketFromChannel(session, firstParam(stream), params);
        break;
      default:
        // Unknown action type
      }
    };

    ws.on('close', function onWebsocketClose() {
      // Subscription keys are channel ids joined with ';'.
      for (const subscriptionKey of Object.keys(session.subscriptions)) {
        removeSubscription(session, subscriptionKey.split(';'));
      }

      // Decrement the metrics for connected clients:
      metrics.connectedClients.labels({ type: 'websocket' }).dec();

      // Ideally the session would be unassigned here so that the websocket,
      // the request and the logger can be garbage collected. `delete` only
      // works on object properties, and `session = undefined` would require
      // every use of `session` to check for it (a significant refactor), so
      // nothing is done for now.
    });

    // Note: immediately after the `error` event is emitted, the `close` event
    // is emitted. As such, all we need to do is log the error here.
    ws.on('error', (/** @type {Error} */ err) => {
      log.error(err);
    });

    ws.on('message', handleMessage);

    subscribeWebsocketToSystemChannel(session);

    if (!req.url) {
      return;
    }

    // Parse the URL for the connection arguments (if supplied), url.parse can throw:
    const { query } = url.parse(req.url, true);

    if (query.stream) {
      subscribeWebsocketToChannel(session, firstParam(query.stream), query);
    }
  }
  wss.on('connection', onConnection);

  // Keep-alive sweep, every 30 seconds: a client whose `isAlive` flag is still
  // false has not answered the previous ping with a pong and is terminated;
  // every other client has its flag cleared and is pinged again.
  setInterval(() => {
    for (const ws of wss.clients) {
      // @ts-ignore
      if (ws.isAlive === false) {
        ws.terminate();
        continue;
      }

      // @ts-ignore
      ws.isAlive = false;
      ws.ping('', false);
    }
  }, 30000);

  // Start listening and log the address once the server is up.
  attachServerWithConfig(server, address => {
    logger.info(`Streaming API now listening on ${address}`);
  });
  /**
   * Regular shutdown (SIGINT / SIGTERM): stop accepting connections and exit
   * with a success status.
   */
  const onExit = () => {
    server.close();
    process.exit(0);
  };

  /**
   * Crash path for uncaught exceptions: log the error, stop accepting
   * connections and exit with a failure status, so that whatever supervises
   * the process can tell a crash apart from a clean shutdown.
   * @param {Error} err
   */
  const onError = (err) => {
    logger.error(err);

    server.close();
    process.exit(1);
  };

  process.on('SIGINT', onExit);
  process.on('SIGTERM', onExit);

  // The 'exit' hook only closes the server. Calling process.exit(0) from here
  // would replace the status chosen by whoever started the exit, including the
  // failure status set by onError.
  process.on('exit', () => {
    server.close();
  });

  process.on('uncaughtException', onError);
  1141. };
  /**
   * Starts the server listening according to the environment.
   *
   * When SOCKET is set, the server listens on that path and the path reported
   * by `server.address()` is made readable and writable for everyone (mode
   * 0o666), whether or not a callback was supplied. Otherwise it listens on
   * PORT (default 4000) and BIND (default 127.0.0.1).
   *
   * @param {any} server
   * @param {function(string): void} [onSuccess] called once listening, with
   *   the socket path or an `address:port` string
   */
  const attachServerWithConfig = (server, onSuccess) => {
    if (process.env.SOCKET) {
      server.listen(process.env.SOCKET, () => {
        const socketPath = server.address();

        // The permissions are part of bringing the socket up, so they must not
        // depend on the optional callback being present.
        fs.chmodSync(socketPath, 0o666);

        if (onSuccess) {
          onSuccess(socketPath);
        }
      });

      return;
    }

    const port = +(process.env.PORT || 4000);
    let bind = process.env.BIND ?? '127.0.0.1';

    // Web uses the URI syntax for BIND, which means IPv6 addresses may
    // be wrapped in square brackets:
    if (bind.startsWith('[') && bind.endsWith(']')) {
      bind = bind.slice(1, -1);
    }

    server.listen(port, bind, () => {
      if (onSuccess) {
        onSuccess(`${server.address().address}:${server.address().port}`);
      }
    });
  };
  1169. startServer();