Source code for switchyard.lib.socket.socketemu

import sys
from queue import Queue, Empty
from threading import Lock
from subprocess import getoutput
import re
import random
from textwrap import indent
from copy import copy
from collections import namedtuple
from time import time
import socket
from socket import error as sockerr

# carefully control what we export to user code; we provide our own
# implementation for some symbols, and others simply aren't supported
# ugly, but working within limitations of importlib, etc....
implist = copy(socket.__all__)
from socket import *
dontimport = ('setdefaulttimeout', 'getdefaulttimeout', 'has_ipv6', 
    'socket', 'socketpair', 'fromfd', 'dup', 'create_connection', 'CMSG_LEN',
    'CMSG_SPACE')
for name in dontimport:
    if name in implist:
        implist.remove(name)

explist = ['socket', 'ApplicationLayer', 'getdefaulttimeout', 'setdefaulttimeout', 'has_ipv6']
explist.extend(implist)
__all__ = explist

from ...hostfirewall import Firewall
from ...pcapffi import PcapLiveDevice
from ..exceptions import NoPackets
from ..logging import log_debug, log_info, log_warn, setup_logging, red, yellow
from ..packet import IPProtocol
from ..address import ip_address

has_ipv6 = True

_lock = Lock()

def _gather_ports():
    portset = set()
    out = getoutput("netstat -an | grep ^udp")
    for x in out.split('\n'):
        fields = x.split()
        if len(fields) < 5:
            continue
        ports = fields[3].strip()
        mobj = re.search(r'[\.:](\d+|\*)$', ports)
        if mobj:
            port = (mobj.groups()[0])
            if port != '*':
                portset.add(int(port))
    return portset


def _get_ephemeral_port():
    with _lock:
        ports = _gather_ports()
        while True:
            p = random.randint(30000,60000)
            if p not in ports and p not in ApplicationLayer._emuports():
                return p

_default_timeout = None

def getdefaulttimeout():
    '''
    Get the default timeout value for a socket.  The preset default
    is None, meaning to block indefinitely.
    '''
    return _default_timeout

def setdefaulttimeout(tmo):
    '''
    Set the default timeout value for a socket to the given value.
    Calling this function does not affect any preexisting sockets.
    '''
    global _default_timeout
    with _lock:
        _default_timeout = tmo

def _normalize_addrs(addrtuple):
    return (ip_address(addrtuple[0]), int(addrtuple[1]))

def _stringify_addrs(addrtuple):
    return (str(addrtuple[0]), int(addrtuple[1]))

[docs]class ApplicationLayer(object): _isinit = False _to_app = None _from_app = None def __init__(self): ''' Don't try to create an instance of this class. Switchyard internally handles initialization. Users should only ever call the recv_from_app() and send_to_app() static methods. ''' raise RuntimeError("Ouch. Please don't try to create an instance " "of {}. Use the static init() method " "instead.".format(self.__class__.__name__)) @staticmethod def _init(): ''' Internal switchyard static initialization method. ''' if ApplicationLayer._isinit: return ApplicationLayer._isinit = True ApplicationLayer._to_app = {} ApplicationLayer._from_app = Queue() @staticmethod def _emuports(): s = set() for sockid,_ in ApplicationLayer._to_app.items(): s.add(sockid[-1]) return s
[docs] @staticmethod def recv_from_app(timeout=_default_timeout): ''' Called by a network stack implementer to receive application-layer data for sending on to a remote location. Can optionally take a timeout value. If no data are available, raises NoPackets exception. Returns a 2-tuple: flowaddr and data. The flowaddr consists of 5 items: protocol, localaddr, localport, remoteaddr, remoteport. ''' try: return ApplicationLayer._from_app.get(timeout=timeout) except Empty: pass raise NoPackets()
[docs] @staticmethod def send_to_app(proto, local_addr, remote_addr, data): ''' Called by a network stack implementer to push application-layer data "up" from the stack. Arguments are protocol number, local_addr (a 2-tuple of IP address and port), remote_addr (a 2-tuple of IP address and port), and the message. Returns True if a socket was found to which to deliver the message, and False otherwise. When False is returned, a log warning is also emitted. ''' proto = IPProtocol(proto) local_addr = _normalize_addrs(local_addr) remote_addr = _normalize_addrs(remote_addr) xtup = (proto, local_addr[0], local_addr[1]) with _lock: sockqueue = ApplicationLayer._to_app.get(xtup, None) if sockqueue is not None: sockqueue.put((local_addr,remote_addr,data)) return True # no dice, try local IP addr of 0.0.0.0 local2 = _normalize_addrs(("0.0.0.0", local_addr[1])) xtup = (proto, local2[0], local2[1]) with _lock: sockqueue = ApplicationLayer._to_app.get(xtup, None) if sockqueue is not None: sockqueue.put((local_addr,remote_addr,data)) return True log_warn("No socket queue found for local proto/address: {}".format(xtup)) return False
@staticmethod def _register_socket(s): ''' Internal method used by socket emulation layer to create a new "upward" queue for an app-layer socket and to register the socket object. Returns two queues: "downward" (fromapp) and "upward" (toapp). ''' queue_to_app = Queue() with _lock: ApplicationLayer._to_app[s._sockid()] = queue_to_app return ApplicationLayer._from_app, queue_to_app @staticmethod def _registry_update(s, oldid): ''' Internal method used to update an existing socket registry when the socket is re-bound to a different local port number. Requires the socket object and old sockid. Returns None. ''' with _lock: sock_queue = ApplicationLayer._to_app.pop(oldid) ApplicationLayer._to_app[s._sockid()] = sock_queue @staticmethod def _unregister_socket(s): ''' Internal method used to remove the socket from AppLayer registry. Warns if the "upward" socket queue has any left-over data. ''' with _lock: sock_queue = ApplicationLayer._to_app.pop(s._sockid()) if not sock_queue.empty(): log_warn("Socket being destroyed still has data enqueued for application layer.")
[docs]class socket(object): ''' A socket object, emulated by Switchyard. ''' __slots__ = ('_family','_socktype','_protoname','_proto', '_timeout','_block','_remote_addr','_local_addr', '_socket_queue_app_to_stack','_socket_queue_stack_to_app') def __init__(self, family, socktype, proto=0, fileno=0): if not ApplicationLayer._isinit: raise RuntimeError("ApplicationLayer isn't initialized; this socket class can only be used within a Switchyard program.") family = AddressFamily(family) if family not in (AddressFamily.AF_INET, AddressFamily.AF_INET6): raise NotImplementedError( "socket for family {} not implemented".format(family)) # only UDP is supported right now... if socktype not in (SOCK_DGRAM,): raise NotImplementedError( "socket type {} not implemented".format(socktype)) self._family = family self._socktype = socktype self._protoname = 'udp' self._proto = IPProtocol.UDP if proto != 0: self._proto = proto self._timeout = _default_timeout self._block = True self._remote_addr = (None,None) self._local_addr = (ip_address('0.0.0.0'),_get_ephemeral_port()) self.__set_fw_rules() self._socket_queue_app_to_stack, self._socket_queue_stack_to_app = \ ApplicationLayer._register_socket(self) def __set_fw_rules(self): hostrule = "{}:{}".format(self._protoname, self._local_addr[1]) pcaprule = "{} dst port {} or icmp or icmp6".format(self._protoname, self._local_addr[1]) log_debug("Preventing host from receiving traffic on {}".format(hostrule)) log_debug("Selecting only '{}' for receiving on pcap devices".format(pcaprule)) try: # prevent host networking stack from receiving traffic on port # that we're using. Firewall.add_rule(hostrule) # only receive packets with destination port of local port, # or any icmp packets PcapLiveDevice.set_bpf_filter_on_all_devices(pcaprule) except Exception as e: with yellow(): print ("Unable to complete socket emulation setup (failed on " "firewall/bpf filter installation). Did you start the " " program via switchyard?") import traceback print ("Here is the raw exception information:") with red(): print(indent(traceback.format_exc(), ' ')) raise e @property def family(self): ''' Get the address family of the socket. ''' return self._family @property def type(self): ''' Get the type of the socket. ''' return self._socktype @property def proto(self): ''' Get the protocol of the socket. ''' return self._proto def _sockid(self): return (IPProtocol(self._proto), self._local_addr[0], self._local_addr[1]) def _flowaddr(self): return (self._proto, self._local_addr[0], self._local_addr[1], self._remote_addr[0], self._remote_addr[1])
[docs] def accept(self): ''' Not implemented. ''' raise NotImplementedError()
[docs] def close(self): ''' Close the socket. ''' try: ApplicationLayer._unregister_socket(self) except: # ignore any errors (e.g., double-close) pass return 0
[docs] def bind(self, address): ''' Alter the local address with which this socket is associated. The address parameter is a 2-tuple consisting of an IP address and port number. NB: this method fails and returns -1 if the requested port to bind to is already in use but does *not* check that the address is valid. ''' portset = _gather_ports().union(ApplicationLayer._emuports()) if address[1] in portset: log_warn("Port is already in use.") return -1 oldid = self._sockid() # block firewall port # set stack to only allow packets through for addr/port self._local_addr = _normalize_addrs(address) # update firewall and pcap filters self.__set_fw_rules() ApplicationLayer._registry_update(self, oldid) return 0
[docs] def connect(self, address): ''' Set the remote address (IP address and port) with which this socket is used to communicate. ''' self._remote_addr = _normalize_addrs(address) return 0
[docs] def connect_ex(self, address): ''' Set the remote address (IP address and port) with which this socket is used to communicate. ''' self._remote_addr = _normalize_addrs(address) return 0
[docs] def getpeername(self): ''' Return a 2-tuple containing the remote IP address and port associated with the socket, if any. ''' return _stringify_addrs(self._remote_addr)
[docs] def getsockname(self): ''' Return a 2-tuple containing the local IP address and port associated with the socket. ''' return _stringify_addrs(self._local_addr)
[docs] def getsockopt(self, level, option, buffersize=0): ''' Not implemented. ''' raise NotImplementedError()
[docs] def gettimeout(self): ''' Obtain the currently set timeout value. ''' return self._timeout
@property def timeout(self): ''' Obtain the currently set timeout value. ''' return self._timeout
[docs] def listen(self, backlog): ''' Not implemented. ''' raise NotImplementedError()
[docs] def recv(self, buffersize, flags=0): ''' Receive data on the socket. The buffersize and flags arguments are currently ignored. Only returns the data. ''' _,_,data = self._recv(buffersize) return data
[docs] def recv_into(self, *args): ''' Not implemented. ''' raise NotImplementedError("*_into calls aren't implemented")
[docs] def recvfrom(self, buffersize, flags=0): ''' Receive data on the socket. The buffersize and flags arguments are currently ignored. Returns the data and an address tuple (IP address and port) of the remote host. ''' _,remoteaddr,data = self._recv(buffersize) return data,remoteaddr
[docs] def recvfrom_into(self, *args): ''' Not implemented. ''' raise NotImplementedError("*_into calls aren't implemented")
def _recv(self, nbytes): try: localaddr,remoteaddr,data = self._socket_queue_stack_to_app.get( block=self._block, timeout=self._timeout) return _stringify_addrs(localaddr),_stringify_addrs(remoteaddr),data except Empty as e: pass raise timeout("timed out")
[docs] def send(self, data, flags=0): ''' Send data on the socket. A call to connect() must have been previously made for this call to succeed. Flags is currently ignored. ''' if self._remote_addr == (None,None): raise sockerr("ENOTCONN: socket not connected") return self._send(data, self._flowaddr())
[docs] def sendto(self, data, *args): ''' Send data on the socket. Accepts the same parameters as the built-in socket sendto: data[, flags], address where address is a 2-tuple of IP address and port. Any flags are currently ignored. ''' remoteaddr = args[-1] remoteaddr = _normalize_addrs(remoteaddr) return self._send(data, (self._proto, self._local_addr[0], self._local_addr[1], remoteaddr[0], remoteaddr[1]))
def _send(self, data, flowaddr): self._socket_queue_app_to_stack.put( (flowaddr, data) ) return len(data)
[docs] def sendall(self, *args): ''' Not implemented. ''' raise NotImplementedError("sendall isn't implemented")
[docs] def sendmsg(self, *args): ''' Not implemented. ''' raise NotImplementedError("*msg calls aren't implemented")
[docs] def recvmsg(self, *args): ''' Not implemented. ''' raise NotImplementedError("*msg calls aren't implemented")
[docs] def setblocking(self, flags): ''' Set whether this socket should block on a call to recv*. ''' self._block = bool(flags)
[docs] def setsockopt(self, *args): ''' Not implemented. ''' raise NotImplementedError("set/get sockopt calls aren't implemented")
[docs] def settimeout(self, timeout): ''' Set the timeout value for this socket. ''' if timeout is None: self._block = True elif float(timeout) == 0.0: self._block = False else: self._timeout = float(timeout) self._block = True
[docs] def shutdown(self, flag): ''' Shut down the socket. This is currently implemented by calling close(). ''' return self.close()