search_query_transformer.rb 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # frozen_string_literal: true
  # Parslet transform that turns the tree produced by the search query parser
  # into clause objects and, finally, a Query that builds a Chewy search request.
  #
  # The rules below read a `current_account` binding.
  # NOTE(review): presumably supplied by the caller via #apply — confirm.
  class SearchQueryTransformer < Parslet::Transform
    # Prefixes (the `name` in `name:value`) that become PrefixClause objects.
    # Any other prefix is folded back into plain search text.
    SUPPORTED_PREFIXES = %w(
      has
      is
      language
      from
      before
      after
      during
      in
    ).freeze

    # A fully transformed search query: its clauses plus the options
    # (notably :current_account) needed to build the request.
    class Query
      # @param clauses [Array] clause objects; nil entries (discarded junk) are ignored
      # @param options [Hash] must contain :current_account
      # @raise [ArgumentError] if options[:current_account] is nil
      def initialize(clauses, options = {})
        raise ArgumentError if options[:current_account].nil?

        @clauses = clauses
        @options = options

        flags_from_clauses!
      end

      # Builds the Chewy request.
      #
      # Every :must / :must_not clause is added to the query's must / must_not
      # lists, and every :filter clause is added as a filter on top of the
      # default visibility filter.
      #
      # @return [Chewy::Search::Request]
      def request
        search = Chewy::Search::Request.new(*indexes).filter(default_filter)

        must_clauses.each { |clause| search = search.query.must(clause.to_query) }
        must_not_clauses.each { |clause| search = search.query.must_not(clause.to_query) }
        filter_clauses.each { |clause| search = search.filter(**clause.to_query) }

        search
      end

      private

      # Groups clauses by operator symbol (:must, :must_not, :filter, :flag),
      # preserving their original order within each group.
      #
      # group_by is used instead of chunk(...).to_h. chunk only merges
      # *adjacent* elements, so `cat -dog mouse` yields two separate :must runs
      # and to_h would keep only the last one, silently dropping `cat`.
      def clauses_by_operator
        @clauses_by_operator ||= @clauses.compact.group_by(&:operator)
      end

      # Collects flag clauses (e.g. `in:`) into a prefix => term hash.
      # When a flag is repeated, the last occurrence wins.
      def flags_from_clauses!
        @flags = clauses_by_operator.fetch(:flag, []).to_h { |clause| [clause.prefix, clause.term] }
      end

      def must_clauses
        clauses_by_operator.fetch(:must, [])
      end

      def must_not_clauses
        clauses_by_operator.fetch(:must_not, [])
      end

      def filter_clauses
        clauses_by_operator.fetch(:filter, [])
      end

      # Indexes to search, chosen by the `in:` flag:
      # - `library` searches only StatusesIndex;
      # - `public` searches only PublicStatusesIndex;
      # - anything else, or no flag, searches both.
      def indexes
        case @flags['in']
        when 'library'
          [StatusesIndex]
        when 'public'
          [PublicStatusesIndex]
        else
          [PublicStatusesIndex, StatusesIndex]
        end
      end

      # Base filter that every search gets. A hit must either come from
      # PublicStatusesIndex, or come from StatusesIndex with a `searchable_by`
      # term matching the current account's id.
      def default_filter
        {
          bool: {
            should: [
              {
                term: {
                  _index: PublicStatusesIndex.index_name,
                },
              },
              {
                bool: {
                  must: [
                    {
                      term: {
                        _index: StatusesIndex.index_name,
                      },
                    },
                    {
                      term: {
                        searchable_by: @options[:current_account].id,
                      },
                    },
                  ],
                },
              },
            ],
            minimum_should_match: 1,
          },
        }
      end
    end

    # Maps the textual operator in front of a clause to its symbol.
    class Operator
      class << self
        # @param str [String, nil] '+', '-' or nil (no operator given)
        # @return [Symbol] :must for '+' or nil, :must_not for '-'
        # @raise [RuntimeError] for any other operator
        def symbol(str)
          case str
          when '+', nil
            :must
          when '-'
            :must_not
          else
            raise "Unknown operator: #{str}"
          end
        end
      end
    end

    # A single bare word. Words starting with '#' are matched against tags;
    # anything else is matched against the text fields.
    class TermClause
      attr_reader :operator, :term

      def initialize(operator, term)
        @operator = Operator.symbol(operator)
        @term = term
      end

      def to_query
        if @term.start_with?('#')
          { match: { tags: { query: @term, operator: 'and' } } }
        else
          { multi_match: { type: 'most_fields', query: @term, fields: ['text', 'text.stemmed'], operator: 'and' } }
        end
      end
    end

    # A quoted phrase, matched against the `text` field as a phrase.
    class PhraseClause
      attr_reader :operator, :phrase

      def initialize(operator, phrase)
        @operator = Operator.symbol(operator)
        @phrase = phrase
      end

      def to_query
        { match_phrase: { text: { query: @phrase } } }
      end
    end

    # A `prefix:term` clause.
    #
    # Most prefixes become :filter clauses. `in:` becomes a :flag that selects
    # indexes (see Query#indexes) and never produces a query of its own.
    # A leading '-' negates the filter.
    class PrefixClause
      attr_reader :operator, :prefix, :term

      # @param prefix [String] one of SUPPORTED_PREFIXES (already downcased)
      # @param operator [String, nil] '-' negates the filter; anything else does not
      # @param term [String] the value after the colon
      # @param options [Hash] may contain :current_account
      # @raise [RuntimeError] for an unknown prefix
      # @raise [Mastodon::FilterValidationError] for an invalid date term
      def initialize(prefix, operator, term, options = {})
        @prefix = prefix
        @negated = operator == '-'
        @options = options
        @operator = :filter

        case prefix
        when 'has', 'is'
          @filter = :properties
          @type = :term
          @term = term
        when 'language'
          @filter = :language
          @type = :term
          @term = language_code_from_term(term)
        when 'from'
          @filter = :account_id
          @type = :term
          @term = account_id_from_term(term)
        when 'before'
          @filter = :created_at
          @type = :range
          @term = { lt: TermValidator.validate_date!(term), time_zone: time_zone }
        when 'after'
          @filter = :created_at
          @type = :range
          @term = { gt: TermValidator.validate_date!(term), time_zone: time_zone }
        when 'during'
          @filter = :created_at
          @type = :range
          date = TermValidator.validate_date!(term)
          @term = { gte: date, lte: date, time_zone: time_zone }
        when 'in'
          @operator = :flag
          @term = term
        else
          raise "Unknown prefix: #{prefix}"
        end
      end

      def to_query
        if @negated
          { bool: { must_not: { @type => { @filter => @term } } } }
        else
          { @type => { @filter => @term } }
        end
      end

      private

      # Time zone for date range filters: the current account's
      # user_time_zone if present, otherwise 'UTC'.
      def time_zone
        @options[:current_account]&.user_time_zone.presence || 'UTC'
      end

      # Resolves `me`, `user` or `user@domain` (optionally with a leading '@')
      # to an account id.
      def account_id_from_term(term)
        return @options[:current_account]&.id || -1 if term == 'me'

        username, domain = term.delete_prefix('@').split('@')
        domain = nil if TagManager.instance.local_domain?(domain)
        account = Account.find_remote(username, domain)

        # If the account is not found, we want to return empty results, so return
        # an ID that does not exist
        account&.id || -1
      end

      # Normalizes a language term to a key of LanguagesHelper::SUPPORTED_LOCALES.
      #
      # Tries, in order: the term as given, downcased, then only its base
      # language before any '_' or '-' (e.g. "pt_BR" -> "pt"). Falls back to
      # the raw term.
      #
      # The base-language candidate can be nil when the term is empty or made
      # only of separators, because split returns []. It is dropped instead of
      # crashing on nil.downcase.
      def language_code_from_term(term)
        candidates = [term, term.downcase, term.split(/[_-]/).first&.downcase].compact
        candidates.find { |code| LanguagesHelper::SUPPORTED_LOCALES.key?(code.to_sym) } || term
      end
    end

    # Validates date terms before they reach a range query.
    class TermValidator
      STRICT_DATE_REGEX = /\A\d{4}-\d{2}-\d{2}\z/ # yyyy-MM-dd
      EPOCH_MILLIS_REGEX = /\A\d{1,19}\z/

      # @param value [String] a yyyy-MM-dd date, or a string of 1-19 digits
      #   (presumably epoch milliseconds, per the constant name)
      # @return [String] value, unchanged
      # @raise [Mastodon::FilterValidationError] if value matches neither format
      def self.validate_date!(value)
        return value if value.match?(STRICT_DATE_REGEX) || value.match?(EPOCH_MILLIS_REGEX)

        raise Mastodon::FilterValidationError, "Invalid date #{value}"
      end
    end

    # Builds the clause object for one parsed clause. A clause has an optional
    # operator and either a single term or a phrase (a list of terms, joined
    # with spaces). It may also carry a prefix.
    rule(clause: subtree(:clause)) do
      prefix = clause[:prefix][:term].to_s.downcase if clause[:prefix]
      operator = clause[:operator]&.to_s
      term = clause[:phrase] ? clause[:phrase].map { |part| part[:term].to_s }.join(' ') : clause[:term].to_s

      if clause[:prefix] && SUPPORTED_PREFIXES.include?(prefix)
        PrefixClause.new(prefix, operator, term, current_account: current_account)
      elsif clause[:prefix]
        # Unsupported prefix: search for the prefix and term as plain text.
        TermClause.new(operator, "#{prefix} #{term}")
      elsif clause[:term]
        TermClause.new(operator, term)
      elsif clause[:phrase]
        PhraseClause.new(operator, term)
      else
        raise "Unexpected clause type: #{clause}"
      end
    end

    # Junk becomes nil and is later removed by Query's `compact`.
    rule(junk: subtree(:junk)) do
      nil
    end

    rule(query: sequence(:clauses)) do
      Query.new(clauses, current_account: current_account)
    end
  end