123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- #
- # Project ___| | | | _ \| |
- # / __| | | | |_) | |
- # | (__| |_| | _ <| |___
- # \___|\___/|_| \_\_____|
- #
- # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
- #
- # This software is licensed as described in the file COPYING, which
- # you should have received as part of this distribution. The terms
- # are also available at https://curl.se/docs/copyright.html.
- #
- # You may opt to use, copy, modify, merge, publish, distribute and/or sell
- # copies of the Software, and permit persons to whom the Software is
- # furnished to do so, under the terms of the COPYING file.
- #
- # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
- # KIND, either express or implied.
- #
- # SPDX-License-Identifier: curl
- #
- """Server for testing SMB."""
- from __future__ import (absolute_import, division, print_function,
- unicode_literals)
- import argparse
- import logging
- import os
- import signal
- import sys
- import tempfile
- import threading
- # Import our curl test data helper
- from util import ClosingFileHandler, TestData
- if sys.version_info.major >= 3:
- import configparser
- else:
- import ConfigParser as configparser
- # impacket needs to be installed in the Python environment
- try:
- import impacket # noqa: F401
- except ImportError:
- sys.stderr.write(
- 'Warning: Python package impacket is required for smb testing; '
- 'use pip or your package manager to install it\n')
- sys.exit(1)
- from impacket import smb as imp_smb
- from impacket import smbserver as imp_smbserver
- from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
- STATUS_SUCCESS)
- log = logging.getLogger(__name__)
- SERVER_MAGIC = "SERVER_MAGIC"
- TESTS_MAGIC = "TESTS_MAGIC"
- VERIFIED_REQ = "verifiedserver"
- VERIFIED_RSP = "WE ROOLZ: {pid}\n"
class ShutdownHandler(threading.Thread):
    """
    Cleanly shut down the SMB server.

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    The wrapped server object must provide a ``shutdown()`` method and a
    ``tmpfiles`` list of file paths; those files are deleted on exit.
    """

    def __init__(self, server):
        """Remember *server* and create the event that triggers shutdown."""
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()

    def __enter__(self):
        """Start the waiter thread and install SIGINT/SIGTERM handlers.

        Returns the handler itself so ``with ShutdownHandler(s) as h`` binds
        a usable object.
        """
        self.start()
        signal.signal(signal.SIGINT, self._sighandler)
        signal.signal(signal.SIGTERM, self._sighandler)
        return self

    def __exit__(self, *_):
        """Stop the server, restore default signal handling and clean up."""
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Uninstall our signal handlers
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        # Delete any temporary files created by the server during its run
        log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
        for tmp_path in self.server.tmpfiles:
            # One file that cannot be removed (already gone, still locked)
            # must not prevent the remaining ones from being cleaned up.
            try:
                os.unlink(tmp_path)
            except OSError:
                log.warning("Could not delete temporary file %s", tmp_path,
                            exc_info=True)

    def _sighandler(self, _signum, _frame):
        """Signal handler: only sets the event, the thread does the work."""
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        """Thread body: block until shutdown is requested, then stop server."""
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    Writes the PID file when ``options.pidfile`` is set, builds the in-memory
    server configuration, and then blocks in ``serve_forever()`` until the
    shutdown handler stops the server.

    Raises ScriptError when ``options.srcdir`` is unset or not a directory.
    Returns 0 after the server has been shut down.
    """
    if options.pidfile:
        server_pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            server_pid += 65536
        with open(options.pidfile, "w") as pid_fh:
            pid_fh.write(str(server_pid))

    # Mini config for the server, as (section, ((option, value), ...)).
    # SERVER lets callers check that the server is running; TESTS serves
    # files that are autogenerated from the test input.
    config_layout = (
        ("global", (
            ("server_name", "SERVICE"),
            ("server_os", "UNIX"),
            ("server_domain", "WORKGROUP"),
            ("log_file", "None"),
            ("credentials_file", ""),
        )),
        ("SERVER", (
            ("comment", "server function"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", SERVER_MAGIC),
        )),
        ("TESTS", (
            ("comment", "tests"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", TESTS_MAGIC),
        )),
    )
    smb_config = configparser.ConfigParser()
    for section_name, settings in config_layout:
        smb_config.add_section(section_name)
        for option_name, option_value in settings:
            smb_config.set(section_name, option_name, option_value)

    srcdir = options.srcdir
    if not (srcdir and os.path.isdir(srcdir)):
        raise ScriptError("--srcdir is mandatory")

    smb_server = TestSmbServer(
        (options.host, options.port),
        config_parser=smb_config,
        test_data_directory=os.path.join(srcdir, "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # Start a thread that cleanly shuts down the server on a signal
    with ShutdownHandler(smb_server):
        # This will block until smb_server.shutdown() is called
        smb_server.serve_forever()

    return 0
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    NT_CREATE_ANDX requests are hooked so that opens on the SERVER_MAGIC and
    TESTS_MAGIC share paths are answered with temporary files generated on
    the fly. The paths of those files are collected in ``self.tmpfiles``.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        """Create the server and install the NT_CREATE_ANDX hook.

        address -- (host, port) tuple handed to the impacket SMBSERVER
        config_parser -- configuration object handed to the SMBSERVER
        test_data_directory -- directory handed to the TestData helper
        """
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of every temporary file handed out to a client.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Requests for non-magic paths are forwarded to impacket's own
        handler. Returns ``([response command], None, error_code)``.
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbError
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbError(STATUS_ACCESS_DENIED,
                               "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in [SERVER_MAGIC, TESTS_MAGIC]:
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(
                flags=flags2, data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                assert path == TESTS_MAGIC
                fid, full_path = self.get_test_path(requested_file)

            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Simple way to generate a fid: one above the highest fid in
            # use. dict.keys() is a view on Python 3 and cannot be indexed,
            # so do not rely on keys()[-1].
            opened_files = conn_data["OpenedFiles"]
            if opened_files:
                fakefid = max(opened_files) + 1
            else:
                fakefid = 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            if os.path.isdir(path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                raise SmbError(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection
            opened_files[fakefid] = {}
            opened_files[fakefid]["FileHandle"] = fid
            opened_files[fakefid]["FileName"] = path
            opened_files[fakefid]["DeleteOnClose"] = False

        except SmbError as s:
            log.debug("[SMB] SmbError hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """Return the path a create request refers to.

        With a positive *root_fid* the path is the "FileName" stored for
        that opened file; otherwise it is the "path" of the share connected
        as *tid*.

        Raises SmbError when *tid* is not a connected share or the share
        has no path.
        """
        conn_shares = conn_data["ConnectedShares"]
        if tid not in conn_shares:
            raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
                           "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
        elif "path" in conn_shares[tid]:
            path = conn_shares[tid]["path"]
        else:
            raise SmbError(STATUS_ACCESS_DENIED,
                           "Connection share had no path")

        return path

    def get_server_path(self, requested_filename):
        """Create the temp file backing a request on the SERVER share.

        Only VERIFIED_REQ is known; any other name raises SmbError with
        STATUS_NO_SUCH_FILE. Returns ``(fd, path)`` of the temp file, with
        the descriptor rewound to the start.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename not in [VERIFIED_REQ]:
            raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        # os.write() needs bytes, so the empty placeholder is bytes as well.
        contents = b""

        if requested_filename == VERIFIED_REQ:
            log.debug("[SMB] Verifying server is alive")
            pid = os.getpid()
            # see tests/server/util.c function write_pidfile
            if os.name == "nt":
                pid += 65536
            contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        self.write_to_fid(fid, contents)
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write *contents* (bytes) to descriptor *fid* and rewind it."""
        # Write the contents to file descriptor
        os.write(fid, contents)
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """Create the temp file backing a request on the TESTS share.

        The content comes from the TestData helper, UTF-8 encoded. Returns
        ``(fd, path)`` of the temp file. Any failure is logged and turned
        into SmbError with STATUS_NO_SUCH_FILE.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(
                requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
            return fid, filename

        except Exception:
            log.exception("Failed to make test file")
            # The caller never sees this temp file, so nothing else would
            # close the descriptor or delete the file.
            try:
                os.close(fid)
                os.unlink(filename)
            except OSError:
                log.debug("[SMB] Could not clean up %s", filename,
                          exc_info=True)
            raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")
class SmbError(Exception):
    """Exception carrying the SMB status code to report to the client."""

    def __init__(self, error_code, error_message):
        # The message becomes the exception text; the code rides along as
        # an attribute.
        Exception.__init__(self, error_message)
        self.error_code = error_code
class ScriptRC(object):
    """Enum for script return codes."""

    # Consecutive integer codes: 0, 1 and 2 respectively.
    SUCCESS, FAILURE, EXCEPTION = range(3)
class ScriptError(Exception):
    """Error raised by this script itself, e.g. for unusable options."""
def get_options():
    """Parse the process command line and return the argparse namespace."""
    # (flag, add_argument keyword arguments), registered in this order.
    arg_specs = (
        ("--port", dict(action="store", default=9017, type=int,
                        help="port to listen on")),
        ("--host", dict(action="store", default="127.0.0.1",
                        help="host to listen on")),
        ("--verbose", dict(action="store", type=int, default=0,
                           help="verbose output")),
        ("--pidfile", dict(action="store", help="file name for the PID")),
        ("--logfile", dict(action="store", help="file name for the log")),
        ("--srcdir", dict(action="store", help="test directory")),
        ("--id", dict(action="store", help="server ID")),
        ("--ipv4", dict(action="store_true", default=0, help="IPv4 flag")),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in arg_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def setup_logging(options):
    """Set up logging from the command line options.

    A file handler is attached when ``options.logfile`` is set. A stdout
    handler is attached when no logfile was given or when
    ``options.verbose`` is truthy. The root level is DEBUG in verbose mode
    and WARNING otherwise.
    """
    root = logging.getLogger()
    line_format = logging.Formatter(
        "%(asctime)s %(levelname)-5.5s %(message)s")

    def attach(handler):
        # Every handler shares the formatter and passes everything through;
        # filtering is done by the root logger's level.
        handler.setFormatter(line_format)
        handler.setLevel(logging.DEBUG)
        root.addHandler(handler)

    # Write out to a logfile
    if options.logfile:
        attach(ClosingFileHandler(options.logfile))

    root.setLevel(logging.DEBUG if options.verbose else logging.WARNING)

    # Log to stdout when there is no logfile, and additionally in verbose
    # mode.
    if options.verbose or not options.logfile:
        attach(logging.StreamHandler(sys.stdout))
if __name__ == '__main__':
    # Command line first: the logging setup depends on the parsed options.
    options = get_options()
    setup_logging(options)

    # Run the server; any error is logged and mapped to the EXCEPTION
    # return code instead of producing a traceback exit.
    try:
        rc = smbserver(options)
    except Exception:
        log.exception('Error in SMB server')
        rc = ScriptRC.EXCEPTION

    # Remove the PID file, if one was requested and is present.
    if options.pidfile:
        if os.path.isfile(options.pidfile):
            os.unlink(options.pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
|