smbserver.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Project ___| | | | _ \| |
  5. # / __| | | | |_) | |
  6. # | (__| |_| | _ <| |___
  7. # \___|\___/|_| \_\_____|
  8. #
  9. # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
  10. #
  11. # This software is licensed as described in the file COPYING, which
  12. # you should have received as part of this distribution. The terms
  13. # are also available at https://curl.se/docs/copyright.html.
  14. #
  15. # You may opt to use, copy, modify, merge, publish, distribute and/or sell
  16. # copies of the Software, and permit persons to whom the Software is
  17. # furnished to do so, under the terms of the COPYING file.
  18. #
  19. # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
  20. # KIND, either express or implied.
  21. #
  22. # SPDX-License-Identifier: curl
  23. #
  24. """Server for testing SMB."""
  25. from __future__ import (absolute_import, division, print_function,
  26. unicode_literals)
  27. import argparse
  28. import logging
  29. import os
  30. import signal
  31. import sys
  32. import tempfile
  33. import threading
  34. # Import our curl test data helper
  35. from util import ClosingFileHandler, TestData
  36. if sys.version_info.major >= 3:
  37. import configparser
  38. else:
  39. import ConfigParser as configparser
  40. # impacket needs to be installed in the Python environment
  41. try:
  42. import impacket # noqa: F401
  43. except ImportError:
  44. sys.stderr.write(
  45. 'Warning: Python package impacket is required for smb testing; '
  46. 'use pip or your package manager to install it\n')
  47. sys.exit(1)
  48. from impacket import smb as imp_smb
  49. from impacket import smbserver as imp_smbserver
  50. from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
  51. STATUS_SUCCESS)
# Module-level logger. Its output is configured on the root logger by
# setup_logging().
log = logging.getLogger(__name__)

# Placeholder values used as the "path" of the SERVER and TESTS shares.
# create_and_x() compares the resolved share path against them to decide
# whether the requested file has to be generated on the fly.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"

# File name accepted on the SERVER share, and the template of the contents
# returned for it (formatted with the PID of this process).
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
  57. class ShutdownHandler(threading.Thread):
  58. """
  59. Cleanly shut down the SMB server.
  60. This can only be done from another thread while the server is in
  61. serve_forever(), so a thread is spawned here that waits for a shutdown
  62. signal before doing its thing. Use in a with statement around the
  63. serve_forever() call.
  64. """
  65. def __init__(self, server):
  66. super(ShutdownHandler, self).__init__()
  67. self.server = server
  68. self.shutdown_event = threading.Event()
  69. def __enter__(self):
  70. self.start()
  71. signal.signal(signal.SIGINT, self._sighandler)
  72. signal.signal(signal.SIGTERM, self._sighandler)
  73. def __exit__(self, *_):
  74. # Call for shutdown just in case it wasn't done already
  75. self.shutdown_event.set()
  76. # Wait for thread, and therefore also the server, to finish
  77. self.join()
  78. # Uninstall our signal handlers
  79. signal.signal(signal.SIGINT, signal.SIG_DFL)
  80. signal.signal(signal.SIGTERM, signal.SIG_DFL)
  81. # Delete any temporary files created by the server during its run
  82. log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
  83. for f in self.server.tmpfiles:
  84. os.unlink(f)
  85. def _sighandler(self, _signum, _frame):
  86. # Wake up the cleanup task
  87. self.shutdown_event.set()
  88. def run(self):
  89. # Wait for shutdown signal
  90. self.shutdown_event.wait()
  91. # Notify the server to shut down
  92. self.server.shutdown()
  93. def smbserver(options):
  94. """Start up a TCP SMB server that serves forever."""
  95. if options.pidfile:
  96. pid = os.getpid()
  97. # see tests/server/util.c function write_pidfile
  98. if os.name == "nt":
  99. pid += 65536
  100. with open(options.pidfile, "w") as f:
  101. f.write(str(pid))
  102. # Here we write a mini config for the server
  103. smb_config = configparser.ConfigParser()
  104. smb_config.add_section("global")
  105. smb_config.set("global", "server_name", "SERVICE")
  106. smb_config.set("global", "server_os", "UNIX")
  107. smb_config.set("global", "server_domain", "WORKGROUP")
  108. smb_config.set("global", "log_file", "None")
  109. smb_config.set("global", "credentials_file", "")
  110. # We need a share which allows us to test that the server is running
  111. smb_config.add_section("SERVER")
  112. smb_config.set("SERVER", "comment", "server function")
  113. smb_config.set("SERVER", "read only", "yes")
  114. smb_config.set("SERVER", "share type", "0")
  115. smb_config.set("SERVER", "path", SERVER_MAGIC)
  116. # Have a share for tests. These files will be autogenerated from the
  117. # test input.
  118. smb_config.add_section("TESTS")
  119. smb_config.set("TESTS", "comment", "tests")
  120. smb_config.set("TESTS", "read only", "yes")
  121. smb_config.set("TESTS", "share type", "0")
  122. smb_config.set("TESTS", "path", TESTS_MAGIC)
  123. if not options.srcdir or not os.path.isdir(options.srcdir):
  124. raise ScriptError("--srcdir is mandatory")
  125. test_data_dir = os.path.join(options.srcdir, "data")
  126. smb_server = TestSmbServer((options.host, options.port),
  127. config_parser=smb_config,
  128. test_data_directory=test_data_dir)
  129. log.info("[SMB] setting up SMB server on port %s", options.port)
  130. smb_server.processConfigFile()
  131. # Start a thread that cleanly shuts down the server on a signal
  132. with ShutdownHandler(smb_server):
  133. # This will block until smb_server.shutdown() is called
  134. smb_server.serve_forever()
  135. return 0
  136. class TestSmbServer(imp_smbserver.SMBSERVER):
  137. """
  138. Test server for SMB which subclasses the impacket SMBSERVER and provides
  139. test functionality.
  140. """
  141. def __init__(self,
  142. address,
  143. config_parser=None,
  144. test_data_directory=None):
  145. imp_smbserver.SMBSERVER.__init__(self,
  146. address,
  147. config_parser=config_parser)
  148. self.tmpfiles = []
  149. # Set up a test data object so we can get test data later.
  150. self.ctd = TestData(test_data_directory)
  151. # Override smbComNtCreateAndX so we can pretend to have files which
  152. # don't exist.
  153. self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
  154. self.create_and_x)
  155. def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
  156. """
  157. Our version of smbComNtCreateAndX looks for special test files and
  158. fools the rest of the framework into opening them as if they were
  159. normal files.
  160. """
  161. conn_data = smb_server.getConnectionData(conn_id)
  162. # Wrap processing in a try block which allows us to throw SmbError
  163. # to control the flow.
  164. try:
  165. ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
  166. smb_command["Parameters"])
  167. path = self.get_share_path(conn_data,
  168. ncax_parms["RootFid"],
  169. recv_packet["Tid"])
  170. log.info("[SMB] Requested share path: %s", path)
  171. disposition = ncax_parms["Disposition"]
  172. log.debug("[SMB] Requested disposition: %s", disposition)
  173. # Currently we only support reading files.
  174. if disposition != imp_smb.FILE_OPEN:
  175. raise SmbError(STATUS_ACCESS_DENIED,
  176. "Only support reading files")
  177. # Check to see if the path we were given is actually a
  178. # magic path which needs generating on the fly.
  179. if path not in [SERVER_MAGIC, TESTS_MAGIC]:
  180. # Pass the command onto the original handler.
  181. return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
  182. smb_server,
  183. smb_command,
  184. recv_packet)
  185. flags2 = recv_packet["Flags2"]
  186. ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
  187. data=smb_command[
  188. "Data"])
  189. requested_file = imp_smbserver.decodeSMBString(
  190. flags2,
  191. ncax_data["FileName"])
  192. log.debug("[SMB] User requested file '%s'", requested_file)
  193. if path == SERVER_MAGIC:
  194. fid, full_path = self.get_server_path(requested_file)
  195. else:
  196. assert path == TESTS_MAGIC
  197. fid, full_path = self.get_test_path(requested_file)
  198. self.tmpfiles.append(full_path)
  199. resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
  200. resp_data = ""
  201. # Simple way to generate a fid
  202. if len(conn_data["OpenedFiles"]) == 0:
  203. fakefid = 1
  204. else:
  205. fakefid = conn_data["OpenedFiles"].keys()[-1] + 1
  206. resp_parms["Fid"] = fakefid
  207. resp_parms["CreateAction"] = disposition
  208. if os.path.isdir(path):
  209. resp_parms[
  210. "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
  211. resp_parms["IsDirectory"] = 1
  212. else:
  213. resp_parms["IsDirectory"] = 0
  214. resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]
  215. # Get this file's information
  216. resp_info, error_code = imp_smbserver.queryPathInformation(
  217. os.path.dirname(full_path), os.path.basename(full_path),
  218. level=imp_smb.SMB_QUERY_FILE_ALL_INFO)
  219. if error_code != STATUS_SUCCESS:
  220. raise SmbError(error_code, "Failed to query path info")
  221. resp_parms["CreateTime"] = resp_info["CreationTime"]
  222. resp_parms["LastAccessTime"] = resp_info[
  223. "LastAccessTime"]
  224. resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
  225. resp_parms["LastChangeTime"] = resp_info[
  226. "LastChangeTime"]
  227. resp_parms["FileAttributes"] = resp_info[
  228. "ExtFileAttributes"]
  229. resp_parms["AllocationSize"] = resp_info[
  230. "AllocationSize"]
  231. resp_parms["EndOfFile"] = resp_info["EndOfFile"]
  232. # Let's store the fid for the connection
  233. # smbServer.log("Create file %s, mode:0x%x" % (pathName, mode))
  234. conn_data["OpenedFiles"][fakefid] = {}
  235. conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid
  236. conn_data["OpenedFiles"][fakefid]["FileName"] = path
  237. conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False
  238. except SmbError as s:
  239. log.debug("[SMB] SmbError hit: %s", s)
  240. error_code = s.error_code
  241. resp_parms = ""
  242. resp_data = ""
  243. resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
  244. resp_cmd["Parameters"] = resp_parms
  245. resp_cmd["Data"] = resp_data
  246. smb_server.setConnectionData(conn_id, conn_data)
  247. return [resp_cmd], None, error_code
  248. def get_share_path(self, conn_data, root_fid, tid):
  249. conn_shares = conn_data["ConnectedShares"]
  250. if tid in conn_shares:
  251. if root_fid > 0:
  252. # If we have a rootFid, the path is relative to that fid
  253. path = conn_data["OpenedFiles"][root_fid]["FileName"]
  254. log.debug("RootFid present %s!" % path)
  255. else:
  256. if "path" in conn_shares[tid]:
  257. path = conn_shares[tid]["path"]
  258. else:
  259. raise SmbError(STATUS_ACCESS_DENIED,
  260. "Connection share had no path")
  261. else:
  262. raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
  263. "TID was invalid")
  264. return path
  265. def get_server_path(self, requested_filename):
  266. log.debug("[SMB] Get server path '%s'", requested_filename)
  267. if requested_filename not in [VERIFIED_REQ]:
  268. raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")
  269. fid, filename = tempfile.mkstemp()
  270. log.debug("[SMB] Created %s (%d) for storing '%s'",
  271. filename, fid, requested_filename)
  272. contents = ""
  273. if requested_filename == VERIFIED_REQ:
  274. log.debug("[SMB] Verifying server is alive")
  275. pid = os.getpid()
  276. # see tests/server/util.c function write_pidfile
  277. if os.name == "nt":
  278. pid += 65536
  279. contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')
  280. self.write_to_fid(fid, contents)
  281. return fid, filename
  282. def write_to_fid(self, fid, contents):
  283. # Write the contents to file descriptor
  284. os.write(fid, contents)
  285. os.fsync(fid)
  286. # Rewind the file to the beginning so a read gets us the contents
  287. os.lseek(fid, 0, os.SEEK_SET)
  288. def get_test_path(self, requested_filename):
  289. log.info("[SMB] Get reply data from 'test%s'", requested_filename)
  290. fid, filename = tempfile.mkstemp()
  291. log.debug("[SMB] Created %s (%d) for storing test '%s'",
  292. filename, fid, requested_filename)
  293. try:
  294. contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
  295. self.write_to_fid(fid, contents)
  296. return fid, filename
  297. except Exception:
  298. log.exception("Failed to make test file")
  299. raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")
  300. class SmbError(Exception):
  301. def __init__(self, error_code, error_message):
  302. super(SmbError, self).__init__(error_message)
  303. self.error_code = error_code
  304. class ScriptRC(object):
  305. """Enum for script return codes."""
  306. SUCCESS = 0
  307. FAILURE = 1
  308. EXCEPTION = 2
  309. class ScriptError(Exception):
  310. pass
  311. def get_options():
  312. parser = argparse.ArgumentParser()
  313. parser.add_argument("--port", action="store", default=9017,
  314. type=int, help="port to listen on")
  315. parser.add_argument("--host", action="store", default="127.0.0.1",
  316. help="host to listen on")
  317. parser.add_argument("--verbose", action="store", type=int, default=0,
  318. help="verbose output")
  319. parser.add_argument("--pidfile", action="store",
  320. help="file name for the PID")
  321. parser.add_argument("--logfile", action="store",
  322. help="file name for the log")
  323. parser.add_argument("--srcdir", action="store", help="test directory")
  324. parser.add_argument("--id", action="store", help="server ID")
  325. parser.add_argument("--ipv4", action="store_true", default=0,
  326. help="IPv4 flag")
  327. return parser.parse_args()
  328. def setup_logging(options):
  329. """Set up logging from the command line options."""
  330. root_logger = logging.getLogger()
  331. add_stdout = False
  332. formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
  333. # Write out to a logfile
  334. if options.logfile:
  335. handler = ClosingFileHandler(options.logfile)
  336. handler.setFormatter(formatter)
  337. handler.setLevel(logging.DEBUG)
  338. root_logger.addHandler(handler)
  339. else:
  340. # The logfile wasn't specified. Add a stdout logger.
  341. add_stdout = True
  342. if options.verbose:
  343. # Add a stdout logger as well in verbose mode
  344. root_logger.setLevel(logging.DEBUG)
  345. add_stdout = True
  346. else:
  347. root_logger.setLevel(logging.WARNING)
  348. if add_stdout:
  349. stdout_handler = logging.StreamHandler(sys.stdout)
  350. stdout_handler.setFormatter(formatter)
  351. stdout_handler.setLevel(logging.DEBUG)
  352. root_logger.addHandler(stdout_handler)
# Script entry point: parse options, configure logging, run the server and
# exit with its return code.
if __name__ == '__main__':
    # Get the options from the user.
    options = get_options()

    # Setup logging using the user options
    setup_logging(options)

    # Run main script.
    try:
        rc = smbserver(options)
    except Exception:
        # Log the traceback and report the failure through the exit status
        # instead of letting the exception propagate.
        log.exception('Error in SMB server')
        rc = ScriptRC.EXCEPTION

    # Remove the PID file (smbserver() writes it when --pidfile is given).
    if options.pidfile and os.path.isfile(options.pidfile):
        os.unlink(options.pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)