forked from mirrors/gecko-dev
		
	
		
			
				
	
	
		
			1002 lines
		
	
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1002 lines
		
	
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # vim: set ts=4 et sw=4 tw=80
 | |
| # This Source Code Form is subject to the terms of the Mozilla Public
 | |
| # License, v. 2.0. If a copy of the MPL was not distributed with this
 | |
| # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | |
| 
 | |
| from __future__ import absolute_import, print_function
 | |
| import ipaddr
 | |
| import socket
 | |
| import hmac
 | |
| import hashlib
 | |
| import passlib.utils  # for saslprep
 | |
| import copy
 | |
| import random
 | |
| import operator
 | |
| import os
 | |
| import platform
 | |
| import six
 | |
| import string
 | |
| import time
 | |
| from functools import reduce
 | |
| from string import Template
 | |
| from twisted.internet import reactor, protocol
 | |
| from twisted.internet.task import LoopingCall
 | |
| from twisted.internet.address import IPv4Address
 | |
| from twisted.internet.address import IPv6Address
 | |
| 
 | |
| MAGIC_COOKIE = 0x2112A442
 | |
| 
 | |
| REQUEST = 0
 | |
| INDICATION = 1
 | |
| SUCCESS_RESPONSE = 2
 | |
| ERROR_RESPONSE = 3
 | |
| 
 | |
| BINDING = 0x001
 | |
| ALLOCATE = 0x003
 | |
| REFRESH = 0x004
 | |
| SEND = 0x006
 | |
| DATA_MSG = 0x007
 | |
| CREATE_PERMISSION = 0x008
 | |
| CHANNEL_BIND = 0x009
 | |
| 
 | |
| # STUN spec chose silly values for these
 | |
| STUN_IPV4 = 1
 | |
| STUN_IPV6 = 2
 | |
| 
 | |
| MAPPED_ADDRESS = 0x0001
 | |
| USERNAME = 0x0006
 | |
| MESSAGE_INTEGRITY = 0x0008
 | |
| ERROR_CODE = 0x0009
 | |
| UNKNOWN_ATTRIBUTES = 0x000A
 | |
| LIFETIME = 0x000D
 | |
| DATA_ATTR = 0x0013
 | |
| XOR_PEER_ADDRESS = 0x0012
 | |
| REALM = 0x0014
 | |
| NONCE = 0x0015
 | |
| XOR_RELAYED_ADDRESS = 0x0016
 | |
| REQUESTED_TRANSPORT = 0x0019
 | |
| DONT_FRAGMENT = 0x001A
 | |
| XOR_MAPPED_ADDRESS = 0x0020
 | |
| SOFTWARE = 0x8022
 | |
| ALTERNATE_SERVER = 0x8023
 | |
| FINGERPRINT = 0x8028
 | |
| 
 | |
| STUN_PORT = 3478
 | |
| STUNS_PORT = 5349
 | |
| 
 | |
| TURN_REDIRECT_PORT = 3479
 | |
| TURNS_REDIRECT_PORT = 5350
 | |
| 
 | |
| 
 | |
| def unpack_uint(bytes_buf):
 | |
|     result = 0
 | |
|     for byte in bytes_buf:
 | |
|         result = (result << 8) + byte
 | |
|     return result
 | |
| 
 | |
| 
 | |
| def pack_uint(value, width):
 | |
|     if value < 0:
 | |
|         raise ValueError("Invalid value: {}".format(value))
 | |
|     buf = bytearray([0] * width)
 | |
|     for i in range(0, width):
 | |
|         buf[i] = (value >> (8 * (width - i - 1))) & 0xFF
 | |
| 
 | |
|     return buf
 | |
| 
 | |
| 
 | |
| def unpack(bytes_buf, format_array):
 | |
|     results = ()
 | |
|     for width in format_array:
 | |
|         results = results + (unpack_uint(bytes_buf[0:width]),)
 | |
|         bytes_buf = bytes_buf[width:]
 | |
|     return results
 | |
| 
 | |
| 
 | |
| def pack(values, format_array):
 | |
|     if len(values) != len(format_array):
 | |
|         raise ValueError()
 | |
|     buf = bytearray()
 | |
|     for i in range(0, len(values)):
 | |
|         buf.extend(pack_uint(values[i], format_array[i]))
 | |
|     return buf
 | |
| 
 | |
| 
 | |
| def bitwise_pack(source, dest, start_bit, num_bits):
 | |
|     if num_bits <= 0 or num_bits > start_bit + 1:
 | |
|         raise ValueError(
 | |
|             "Invalid num_bits: {}, start_bit = {}".format(num_bits, start_bit)
 | |
|         )
 | |
|     last_bit = start_bit - num_bits + 1
 | |
|     source = source >> last_bit
 | |
|     dest = dest << num_bits
 | |
|     mask = (1 << num_bits) - 1
 | |
|     dest += source & mask
 | |
|     return dest
 | |
| 
 | |
| 
 | |
| def to_ipaddress(protocol, host, port):
 | |
|     if ":" not in host:
 | |
|         return IPv4Address(protocol, host, port)
 | |
| 
 | |
|     return IPv6Address(protocol, host, port)
 | |
| 
 | |
| 
 | |
| class StunAttribute(object):
 | |
|     """
 | |
|     Represents a STUN attribute in a raw format, according to the following:
 | |
| 
 | |
|      0                   1                   2                   3
 | |
|      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 | |
|     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     |   StunAttribute.attr_type     |  Length (derived as needed)   |
 | |
|     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     |           StunAttribute.data (variable length)             ....
 | |
|     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     """
 | |
| 
 | |
|     __attr_header_fmt = [2, 2]
 | |
|     __attr_header_size = reduce(operator.add, __attr_header_fmt)
 | |
| 
 | |
|     def __init__(self, attr_type=0, buf=bytearray()):
 | |
|         self.attr_type = attr_type
 | |
|         self.data = buf
 | |
| 
 | |
|     def build(self):
 | |
|         buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
 | |
|         buf.extend(self.data)
 | |
|         # add padding if necessary
 | |
|         if len(buf) % 4:
 | |
|             buf.extend([0] * (4 - (len(buf) % 4)))
 | |
|         return buf
 | |
| 
 | |
|     def parse(self, buf):
 | |
|         if self.__attr_header_size > len(buf):
 | |
|             raise Exception("truncated at attribute: incomplete header")
 | |
| 
 | |
|         self.attr_type, length = unpack(buf, self.__attr_header_fmt)
 | |
|         length += self.__attr_header_size
 | |
| 
 | |
|         if length > len(buf):
 | |
|             raise Exception("truncated at attribute: incomplete contents")
 | |
| 
 | |
|         self.data = buf[self.__attr_header_size : length]
 | |
| 
 | |
|         # verify padding
 | |
|         while length % 4:
 | |
|             if buf[length]:
 | |
|                 raise ValueError("Non-zero padding")
 | |
|             length += 1
 | |
| 
 | |
|         return length
 | |
| 
 | |
| 
 | |
| class StunMessage(object):
 | |
|     """
 | |
|     Represents a STUN message. Contains a method, msg_class, cookie,
 | |
|     transaction_id, and attributes (as an array of StunAttribute).
 | |
| 
 | |
|     Has various functions for getting/adding attributes.
 | |
|     """
 | |
| 
 | |
|     def __init__(self):
 | |
|         self.method = 0
 | |
|         self.msg_class = 0
 | |
|         self.cookie = MAGIC_COOKIE
 | |
|         self.transaction_id = 0
 | |
|         self.attributes = []
 | |
| 
 | |
|     #      0                   1                   2                   3
 | |
|     #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 | |
|     #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     #     |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
 | |
|     #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     #     |                         Magic Cookie                          |
 | |
|     #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     #     |                                                               |
 | |
|     #     |                     Transaction ID (96 bits)                  |
 | |
|     #     |                                                               |
 | |
|     #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     __header_fmt = [2, 2, 4, 12]
 | |
|     __header_size = reduce(operator.add, __header_fmt)
 | |
| 
 | |
|     # Returns how many bytes were parsed if buf was large enough, or how many
 | |
|     # bytes we would have needed if not. Throws if buf is malformed.
 | |
|     def parse(self, buf):
 | |
|         min_buf_size = self.__header_size
 | |
|         if len(buf) < min_buf_size:
 | |
|             return min_buf_size
 | |
| 
 | |
|         message_type, length, cookie, self.transaction_id = unpack(
 | |
|             buf, self.__header_fmt
 | |
|         )
 | |
|         min_buf_size += length
 | |
|         if len(buf) < min_buf_size:
 | |
|             return min_buf_size
 | |
| 
 | |
|         # Avert your eyes...
 | |
|         self.method = bitwise_pack(message_type, 0, 13, 5)
 | |
|         self.msg_class = bitwise_pack(message_type, 0, 8, 1)
 | |
|         self.method = bitwise_pack(message_type, self.method, 7, 3)
 | |
|         self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
 | |
|         self.method = bitwise_pack(message_type, self.method, 3, 4)
 | |
| 
 | |
|         if cookie != self.cookie:
 | |
|             raise Exception("Invalid cookie: {}".format(cookie))
 | |
| 
 | |
|         buf = buf[self.__header_size : min_buf_size]
 | |
|         while len(buf):
 | |
|             attr = StunAttribute()
 | |
|             length = attr.parse(buf)
 | |
|             buf = buf[length:]
 | |
|             self.attributes.append(attr)
 | |
| 
 | |
|         return min_buf_size
 | |
| 
 | |
|     # stop_after_attr_type is useful for calculating MESSAGE-DIGEST
 | |
|     def build(self, stop_after_attr_type=0):
 | |
|         attrs = bytearray()
 | |
|         for attr in self.attributes:
 | |
|             attrs.extend(attr.build())
 | |
|             if attr.attr_type == stop_after_attr_type:
 | |
|                 break
 | |
| 
 | |
|         message_type = bitwise_pack(self.method, 0, 11, 5)
 | |
|         message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
 | |
|         message_type = bitwise_pack(self.method, message_type, 6, 3)
 | |
|         message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
 | |
|         message_type = bitwise_pack(self.method, message_type, 3, 4)
 | |
| 
 | |
|         message = pack(
 | |
|             (message_type, len(attrs), self.cookie, self.transaction_id),
 | |
|             self.__header_fmt,
 | |
|         )
 | |
|         message.extend(attrs)
 | |
| 
 | |
|         return message
 | |
| 
 | |
|     def add_error_code(self, code, phrase=None):
 | |
|         #      0                   1                   2                   3
 | |
|         #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 | |
|         #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|         #     |           Reserved, should be 0         |Class|     Number    |
 | |
|         #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|         #     |      Reason Phrase (variable)                                ..
 | |
|         #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|         error_code_fmt = [3, 1]
 | |
|         error_code = pack((code // 100, code % 100), error_code_fmt)
 | |
|         if phrase != None:
 | |
|             error_code.extend(bytearray(phrase, "utf-8"))
 | |
|         self.attributes.append(StunAttribute(ERROR_CODE, error_code))
 | |
| 
 | |
|     #     0                   1                   2                   3
 | |
|     #     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 | |
|     #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     #    |x x x x x x x x|    Family     |         X-Port                |
 | |
|     #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     #    |                X-Address (Variable)
 | |
|     #    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 | |
|     __v4addr_fmt = [1, 1, 2, 4]
 | |
|     __v6addr_fmt = [1, 1, 2, 16]
 | |
|     __v4addr_size = reduce(operator.add, __v4addr_fmt)
 | |
|     __v6addr_size = reduce(operator.add, __v6addr_fmt)
 | |
| 
 | |
|     def add_address(self, ip_address, version, port, attr_type):
 | |
|         if version == STUN_IPV4:
 | |
|             address = pack((0, STUN_IPV4, port, ip_address), self.__v4addr_fmt)
 | |
|         elif version == STUN_IPV6:
 | |
|             address = pack((0, STUN_IPV6, port, ip_address), self.__v6addr_fmt)
 | |
|         else:
 | |
|             raise ValueError("Invalid ip version: {}".format(version))
 | |
|         self.attributes.append(StunAttribute(attr_type, address))
 | |
| 
 | |
|     def get_xaddr(self, ip_addr, version):
 | |
|         if version == STUN_IPV4:
 | |
|             return self.cookie ^ ip_addr
 | |
|         elif version == STUN_IPV6:
 | |
|             return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
 | |
|         else:
 | |
|             raise ValueError("Invalid family: {}".format(version))
 | |
| 
 | |
|     def get_xport(self, port):
 | |
|         return (self.cookie >> 16) ^ port
 | |
| 
 | |
|     def add_xor_address(self, addr_port, attr_type):
 | |
|         ip_address = ipaddr.IPAddress(addr_port.host)
 | |
|         version = STUN_IPV6 if ip_address.version == 6 else STUN_IPV4
 | |
|         xaddr = self.get_xaddr(int(ip_address), version)
 | |
|         xport = self.get_xport(addr_port.port)
 | |
|         self.add_address(xaddr, version, xport, attr_type)
 | |
| 
 | |
|     def add_data(self, buf):
 | |
|         self.attributes.append(StunAttribute(DATA_ATTR, buf))
 | |
| 
 | |
|     def find(self, attr_type):
 | |
|         for attr in self.attributes:
 | |
|             if attr.attr_type == attr_type:
 | |
|                 return attr
 | |
|         return None
 | |
| 
 | |
|     def get_xor_address(self, attr_type):
 | |
|         addr_attr = self.find(attr_type)
 | |
|         if not addr_attr:
 | |
|             return None
 | |
| 
 | |
|         padding, family, xport, xaddr = unpack(addr_attr.data, self.__v4addr_fmt)
 | |
|         addr_ctor = IPv4Address
 | |
|         if family == STUN_IPV6:
 | |
|             padding, family, xport, xaddr = unpack(addr_attr.data, self.__v6addr_fmt)
 | |
|             addr_ctor = IPv6Address
 | |
|         elif family != STUN_IPV4:
 | |
|             raise ValueError("Invalid family: {}".format(family))
 | |
| 
 | |
|         return addr_ctor(
 | |
|             "UDP",
 | |
|             str(ipaddr.IPAddress(self.get_xaddr(xaddr, family))),
 | |
|             self.get_xport(xport),
 | |
|         )
 | |
| 
 | |
|     def add_nonce(self, nonce):
 | |
|         self.attributes.append(StunAttribute(NONCE, bytearray(nonce, "utf-8")))
 | |
| 
 | |
|     def add_realm(self, realm):
 | |
|         self.attributes.append(StunAttribute(REALM, bytearray(realm, "utf-8")))
 | |
| 
 | |
|     def calculate_message_digest(self, username, realm, password):
 | |
|         digest_buf = self.build(MESSAGE_INTEGRITY)
 | |
|         # Trim off the MESSAGE-INTEGRITY attr
 | |
|         digest_buf = digest_buf[: len(digest_buf) - 24]
 | |
|         password = passlib.utils.saslprep(six.text_type(password))
 | |
|         key_string = "{}:{}:{}".format(username, realm, password)
 | |
|         md5 = hashlib.md5()
 | |
|         md5.update(bytearray(key_string, "utf-8"))
 | |
|         key = md5.digest()
 | |
|         return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())
 | |
| 
 | |
|     def add_lifetime(self, lifetime):
 | |
|         self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))
 | |
| 
 | |
|     def get_lifetime(self):
 | |
|         lifetime_attr = self.find(LIFETIME)
 | |
|         if not lifetime_attr:
 | |
|             return None
 | |
|         return unpack_uint(lifetime_attr.data[0:4])
 | |
| 
 | |
|     def get_username(self):
 | |
|         username = self.find(USERNAME)
 | |
|         if not username:
 | |
|             return None
 | |
|         return str(username.data)
 | |
| 
 | |
|     def add_message_integrity(self, username, realm, password):
 | |
|         dummy_value = bytearray([0] * 20)
 | |
|         self.attributes.append(StunAttribute(MESSAGE_INTEGRITY, dummy_value))
 | |
|         digest = self.calculate_message_digest(username, realm, password)
 | |
|         self.find(MESSAGE_INTEGRITY).data = digest
 | |
| 
 | |
|     def add_alternate_server(self, host, port):
 | |
|         address = ipaddr.IPAddress(host)
 | |
|         version = STUN_IPV6 if address.version == 6 else STUN_IPV4
 | |
|         self.add_address(int(address), version, port, ALTERNATE_SERVER)
 | |
| 
 | |
| 
 | |
| class Allocation(protocol.DatagramProtocol):
 | |
|     """
 | |
|     Comprises the socket for a TURN allocation, a back-reference to the
 | |
|     transport we will forward received traffic on, the allocator's address and
 | |
|     username, the set of permissions for the allocation, and the allocation's
 | |
|     expiry.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, other_transport_handler, allocator_address, username):
 | |
|         self.permissions = set()  # str, int tuples
 | |
|         # Handler to use when sending stuff that arrives on the allocation
 | |
|         self.other_transport_handler = other_transport_handler
 | |
|         self.allocator_address = allocator_address
 | |
|         self.username = username
 | |
|         self.expiry = time.time()
 | |
|         self.port = reactor.listenUDP(0, self, interface=v4_address)
 | |
| 
 | |
|     def datagramReceived(self, data, address):
 | |
|         host = address[0]
 | |
|         port = address[1]
 | |
|         if not host in self.permissions:
 | |
|             print(
 | |
|                 "Dropping packet from {}:{}, no permission on allocation {}".format(
 | |
|                     host, port, self.transport.getHost()
 | |
|                 )
 | |
|             )
 | |
|             return
 | |
| 
 | |
|         data_indication = StunMessage()
 | |
|         data_indication.method = DATA_MSG
 | |
|         data_indication.msg_class = INDICATION
 | |
|         data_indication.transaction_id = random.getrandbits(96)
 | |
| 
 | |
|         # Only handles UDP allocations. Doubtful that we need more than this.
 | |
|         data_indication.add_xor_address(
 | |
|             to_ipaddress("UDP", host, port), XOR_PEER_ADDRESS
 | |
|         )
 | |
|         data_indication.add_data(data)
 | |
| 
 | |
|         self.other_transport_handler.write(
 | |
|             data_indication.build(), self.allocator_address
 | |
|         )
 | |
| 
 | |
|     def close(self):
 | |
|         self.port.stopListening()
 | |
|         self.port = None
 | |
| 
 | |
| 
 | |
| class StunHandler(object):
 | |
|     """
 | |
|     Frames and handles STUN messages. This is the core logic of the TURN
 | |
|     server, along with Allocation.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, transport_handler):
 | |
|         self.client_address = None
 | |
|         self.data = bytearray()
 | |
|         self.transport_handler = transport_handler
 | |
| 
 | |
|     def data_received(self, data, address):
 | |
|         self.data += data
 | |
|         while True:
 | |
|             stun_message = StunMessage()
 | |
|             parsed_len = stun_message.parse(self.data)
 | |
|             if parsed_len > len(self.data):
 | |
|                 break
 | |
|             self.data = self.data[parsed_len:]
 | |
| 
 | |
|             response = self.handle_stun(stun_message, address)
 | |
|             if response:
 | |
|                 self.transport_handler.write(response, address)
 | |
| 
 | |
|     def handle_stun(self, stun_message, address):
 | |
|         self.client_address = address
 | |
|         if stun_message.msg_class == INDICATION:
 | |
|             if stun_message.method == SEND:
 | |
|                 self.handle_send_indication(stun_message)
 | |
|             else:
 | |
|                 print(
 | |
|                     "Dropping unknown indication method: {}".format(stun_message.method)
 | |
|                 )
 | |
|             return None
 | |
| 
 | |
|         if stun_message.msg_class != REQUEST:
 | |
|             print("Dropping STUN response, method: {}".format(stun_message.method))
 | |
|             return None
 | |
| 
 | |
|         if stun_message.method == BINDING:
 | |
|             return self.make_success_response(stun_message).build()
 | |
|         elif stun_message.method == ALLOCATE:
 | |
|             return self.handle_allocation(stun_message).build()
 | |
|         elif stun_message.method == REFRESH:
 | |
|             return self.handle_refresh(stun_message).build()
 | |
|         elif stun_message.method == CREATE_PERMISSION:
 | |
|             return self.handle_permission(stun_message).build()
 | |
|         else:
 | |
|             return self.make_error_response(
 | |
|                 stun_message,
 | |
|                 400,
 | |
|                 ("Unsupported STUN request, method: {}".format(stun_message.method)),
 | |
|             ).build()
 | |
| 
 | |
|     def get_allocation_tuple(self):
 | |
|         return (
 | |
|             self.client_address.host,
 | |
|             self.client_address.port,
 | |
|             self.transport_handler.transport.getHost().type,
 | |
|             self.transport_handler.transport.getHost().host,
 | |
|             self.transport_handler.transport.getHost().port,
 | |
|         )
 | |
| 
 | |
|     def handle_allocation(self, request):
 | |
|         allocate_response = self.check_long_term_auth(request)
 | |
|         if allocate_response.msg_class == SUCCESS_RESPONSE:
 | |
|             if self.get_allocation_tuple() in allocations:
 | |
|                 return self.make_error_response(
 | |
|                     request,
 | |
|                     437,
 | |
|                     (
 | |
|                         "Duplicate allocation request for tuple {}".format(
 | |
|                             self.get_allocation_tuple()
 | |
|                         )
 | |
|                     ),
 | |
|                 )
 | |
| 
 | |
|             allocation = Allocation(
 | |
|                 self.transport_handler, self.client_address, request.get_username()
 | |
|             )
 | |
| 
 | |
|             allocate_response.add_xor_address(
 | |
|                 allocation.transport.getHost(), XOR_RELAYED_ADDRESS
 | |
|             )
 | |
| 
 | |
|             lifetime = request.get_lifetime()
 | |
|             if lifetime == None:
 | |
|                 return self.make_error_response(
 | |
|                     request, 400, "Missing lifetime attribute in allocation request"
 | |
|                 )
 | |
| 
 | |
|             lifetime = min(lifetime, 3600)
 | |
|             allocate_response.add_lifetime(lifetime)
 | |
|             allocation.expiry = time.time() + lifetime
 | |
| 
 | |
|             allocate_response.add_message_integrity(turn_user, turn_realm, turn_pass)
 | |
|             allocations[self.get_allocation_tuple()] = allocation
 | |
|         return allocate_response
 | |
| 
 | |
|     def handle_refresh(self, request):
 | |
|         refresh_response = self.check_long_term_auth(request)
 | |
|         if refresh_response.msg_class == SUCCESS_RESPONSE:
 | |
|             try:
 | |
|                 allocation = allocations[self.get_allocation_tuple()]
 | |
|             except KeyError:
 | |
|                 return self.make_error_response(
 | |
|                     request,
 | |
|                     437,
 | |
|                     (
 | |
|                         "Refresh request for non-existing allocation, tuple {}".format(
 | |
|                             self.get_allocation_tuple()
 | |
|                         )
 | |
|                     ),
 | |
|                 )
 | |
| 
 | |
|             if allocation.username != request.get_username():
 | |
|                 return self.make_error_response(
 | |
|                     request,
 | |
|                     441,
 | |
|                     (
 | |
|                         "Refresh request with wrong user, exp {}, got {}".format(
 | |
|                             allocation.username, request.get_username()
 | |
|                         )
 | |
|                     ),
 | |
|                 )
 | |
| 
 | |
|             lifetime = request.get_lifetime()
 | |
|             if lifetime == None:
 | |
|                 return self.make_error_response(
 | |
|                     request, 400, "Missing lifetime attribute in allocation request"
 | |
|                 )
 | |
| 
 | |
|             lifetime = min(lifetime, 3600)
 | |
|             refresh_response.add_lifetime(lifetime)
 | |
|             allocation.expiry = time.time() + lifetime
 | |
| 
 | |
|             refresh_response.add_message_integrity(turn_user, turn_realm, turn_pass)
 | |
|         return refresh_response
 | |
| 
 | |
|     def handle_permission(self, request):
 | |
|         permission_response = self.check_long_term_auth(request)
 | |
|         if permission_response.msg_class == SUCCESS_RESPONSE:
 | |
|             try:
 | |
|                 allocation = allocations[self.get_allocation_tuple()]
 | |
|             except KeyError:
 | |
|                 return self.make_error_response(
 | |
|                     request,
 | |
|                     437,
 | |
|                     (
 | |
|                         "No such allocation for permission request, tuple {}".format(
 | |
|                             self.get_allocation_tuple()
 | |
|                         )
 | |
|                     ),
 | |
|                 )
 | |
| 
 | |
|             if allocation.username != request.get_username():
 | |
|                 return self.make_error_response(
 | |
|                     request,
 | |
|                     441,
 | |
|                     (
 | |
|                         "Permission request with wrong user, exp {}, got {}".format(
 | |
|                             allocation.username, request.get_username()
 | |
|                         )
 | |
|                     ),
 | |
|                 )
 | |
| 
 | |
|             # TODO: Handle multiple XOR-PEER-ADDRESS
 | |
|             peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
 | |
|             if not peer_address:
 | |
|                 return self.make_error_response(
 | |
|                     request, 400, "Missing XOR-PEER-ADDRESS on permission request"
 | |
|                 )
 | |
| 
 | |
|             permission_response.add_message_integrity(turn_user, turn_realm, turn_pass)
 | |
|             allocation.permissions.add(peer_address.host)
 | |
| 
 | |
|         return permission_response
 | |
| 
 | |
|     def handle_send_indication(self, indication):
 | |
|         try:
 | |
|             allocation = allocations[self.get_allocation_tuple()]
 | |
|         except KeyError:
 | |
|             print(
 | |
|                 "Dropping send indication; no allocation for tuple {}".format(
 | |
|                     self.get_allocation_tuple()
 | |
|                 )
 | |
|             )
 | |
|             return
 | |
| 
 | |
|         peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
 | |
|         if not peer_address:
 | |
|             print("Dropping send indication, missing XOR-PEER-ADDRESS")
 | |
|             return
 | |
| 
 | |
|         data_attr = indication.find(DATA_ATTR)
 | |
|         if not data_attr:
 | |
|             print("Dropping send indication, missing DATA")
 | |
|             return
 | |
| 
 | |
|         if indication.find(DONT_FRAGMENT):
 | |
|             print("Dropping send indication, DONT-FRAGMENT set")
 | |
|             return
 | |
| 
 | |
|         if not peer_address.host in allocation.permissions:
 | |
|             print(
 | |
|                 "Dropping send indication, no permission for {} on tuple {}".format(
 | |
|                     peer_address.host, self.get_allocation_tuple()
 | |
|                 )
 | |
|             )
 | |
|             return
 | |
| 
 | |
|         allocation.transport.write(
 | |
|             data_attr.data, (peer_address.host, peer_address.port)
 | |
|         )
 | |
| 
 | |
|     def make_success_response(self, request):
 | |
|         response = copy.deepcopy(request)
 | |
|         response.attributes = []
 | |
|         response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
 | |
|         response.msg_class = SUCCESS_RESPONSE
 | |
|         return response
 | |
| 
 | |
|     def make_error_response(self, request, code, reason=None):
 | |
|         if reason:
 | |
|             print("{}: rejecting with {}".format(reason, code))
 | |
|         response = copy.deepcopy(request)
 | |
|         response.attributes = []
 | |
|         response.add_error_code(code, reason)
 | |
|         response.msg_class = ERROR_RESPONSE
 | |
|         return response
 | |
| 
 | |
|     def make_challenge_response(self, request, reason=None):
 | |
|         response = self.make_error_response(request, 401, reason)
 | |
|         # 65 means the hex encoding will need padding half the time
 | |
|         response.add_nonce("{:x}".format(random.getrandbits(65)))
 | |
|         response.add_realm(turn_realm)
 | |
|         return response
 | |
| 
 | |
|     def check_long_term_auth(self, request):
 | |
|         message_integrity = request.find(MESSAGE_INTEGRITY)
 | |
|         if not message_integrity:
 | |
|             return self.make_challenge_response(request)
 | |
| 
 | |
|         username = request.find(USERNAME)
 | |
|         realm = request.find(REALM)
 | |
|         nonce = request.find(NONCE)
 | |
|         if not username or not realm or not nonce:
 | |
|             return self.make_error_response(
 | |
|                 request, 400, "Missing either USERNAME, NONCE, or REALM"
 | |
|             )
 | |
| 
 | |
|         if username.data.decode("utf-8") != turn_user:
 | |
|             return self.make_challenge_response(
 | |
|                 request, "Wrong user {}, exp {}".format(username.data, turn_user)
 | |
|             )
 | |
| 
 | |
|         expected_message_digest = request.calculate_message_digest(
 | |
|             turn_user, turn_realm, turn_pass
 | |
|         )
 | |
|         if message_integrity.data != expected_message_digest:
 | |
|             return self.make_challenge_response(request, "Incorrect message disgest")
 | |
| 
 | |
|         return self.make_success_response(request)
 | |
| 
 | |
| 
 | |
| class StunRedirectHandler(StunHandler):
 | |
|     """
 | |
|     Frames and handles STUN messages by redirecting to the "real" server port.
 | |
|     Performs the redirect with auth, so does a 401 to unauthed requests.
 | |
|     Can be used to test port-based redirect handling.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, transport_handler):
 | |
|         super(StunRedirectHandler, self).__init__(transport_handler)
 | |
| 
 | |
|     def handle_stun(self, stun_message, address):
 | |
|         self.client_address = address
 | |
|         if stun_message.msg_class == REQUEST:
 | |
|             challenge_response = self.check_long_term_auth(stun_message)
 | |
| 
 | |
|             if challenge_response.msg_class == SUCCESS_RESPONSE:
 | |
|                 return self.make_redirect_response(stun_message).build()
 | |
| 
 | |
|             return challenge_response.build()
 | |
| 
 | |
|     def make_redirect_response(self, request):
 | |
|         response = self.make_error_response(request, 300, "Try alternate")
 | |
|         port = STUN_PORT
 | |
|         if self.transport_handler.transport.getHost().port == TURNS_REDIRECT_PORT:
 | |
|             port = STUNS_PORT
 | |
| 
 | |
|         response.add_alternate_server(
 | |
|             self.transport_handler.transport.getHost().host, port
 | |
|         )
 | |
| 
 | |
|         response.add_message_integrity(turn_user, turn_realm, turn_pass)
 | |
|         return response
 | |
| 
 | |
| 
 | |
| class UdpStunHandler(protocol.DatagramProtocol):
 | |
|     """
 | |
|     Represents a UDP listen port for TURN.
 | |
|     """
 | |
| 
 | |
|     def datagramReceived(self, data, address):
 | |
|         stun_handler = StunHandler(self)
 | |
|         stun_handler.data_received(data, to_ipaddress("UDP", address[0], address[1]))
 | |
| 
 | |
|     def write(self, data, address):
 | |
|         self.transport.write(bytes(data), (address.host, address.port))
 | |
| 
 | |
| 
 | |
| class UdpStunRedirectHandler(protocol.DatagramProtocol):
 | |
|     """
 | |
|     Represents a UDP listen port for TURN that will redirect.
 | |
|     """
 | |
| 
 | |
|     def datagramReceived(self, data, address):
 | |
|         stun_handler = StunRedirectHandler(self)
 | |
|         stun_handler.data_received(data, to_ipaddress("UDP", address[0], address[1]))
 | |
| 
 | |
|     def write(self, data, address):
 | |
|         self.transport.write(bytes(data), (address.host, address.port))
 | |
| 
 | |
| 
 | |
| class TcpStunHandlerFactory(protocol.Factory):
 | |
|     """
 | |
|     Represents a TCP listen port for TURN.
 | |
|     """
 | |
| 
 | |
|     def buildProtocol(self, addr):
 | |
|         return TcpStunHandler(addr)
 | |
| 
 | |
| 
 | |
| class TcpStunHandler(protocol.Protocol):
 | |
|     """
 | |
|     Represents a connected TCP port for TURN.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, addr):
 | |
|         self.address = addr
 | |
|         self.stun_handler = None
 | |
| 
 | |
|     def dataReceived(self, data):
 | |
|         # This needs to persist, since it handles framing
 | |
|         if not self.stun_handler:
 | |
|             self.stun_handler = StunHandler(self)
 | |
|         self.stun_handler.data_received(data, self.address)
 | |
| 
 | |
|     def connectionLost(self, reason):
 | |
|         print("Lost connection from {}".format(self.address))
 | |
|         # Destroy allocations that this connection made
 | |
|         keys_to_delete = []
 | |
|         for key, allocation in allocations.items():
 | |
|             if allocation.other_transport_handler == self:
 | |
|                 print("Closing allocation due to dropped connection: {}".format(key))
 | |
|                 keys_to_delete.append(key)
 | |
|                 allocation.close()
 | |
| 
 | |
|         for key in keys_to_delete:
 | |
|             del allocations[key]
 | |
| 
 | |
|     def write(self, data, address):
 | |
|         self.transport.write(bytes(data))
 | |
| 
 | |
| 
 | |
| class TcpStunRedirectHandlerFactory(protocol.Factory):
 | |
|     """
 | |
|     Represents a TCP listen port for TURN that will redirect.
 | |
|     """
 | |
| 
 | |
|     def buildProtocol(self, addr):
 | |
|         return TcpStunRedirectHandler(addr)
 | |
| 
 | |
| 
 | |
| class TcpStunRedirectHandler(protocol.DatagramProtocol):
 | |
|     def __init__(self, addr):
 | |
|         self.address = addr
 | |
|         self.stun_handler = None
 | |
| 
 | |
|     def dataReceived(self, data):
 | |
|         # This needs to persist, since it handles framing. Framing matters here
 | |
|         # because we do a round of auth before redirecting.
 | |
|         if not self.stun_handler:
 | |
|             self.stun_handler = StunRedirectHandler(self)
 | |
|         self.stun_handler.data_received(data, self.address)
 | |
| 
 | |
|     def write(self, data, address):
 | |
|         self.transport.write(bytes(data))
 | |
| 
 | |
| 
 | |
| def get_default_route(family):
 | |
|     dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
 | |
|     if family is socket.AF_INET:
 | |
|         dummy_socket.connect(("8.8.8.8", 53))
 | |
|     else:
 | |
|         dummy_socket.connect(("2001:4860:4860::8888", 53))
 | |
| 
 | |
|     default_route = dummy_socket.getsockname()[0]
 | |
|     dummy_socket.close()
 | |
|     return default_route
 | |
| 
 | |
| 
 | |
| turn_user = "foo"
 | |
| turn_pass = "bar"
 | |
| turn_realm = "mozilla.invalid"
 | |
| allocations = {}
 | |
| v4_address = get_default_route(socket.AF_INET)
 | |
| try:
 | |
|     v6_address = get_default_route(socket.AF_INET6)
 | |
| except:
 | |
|     v6_address = ""
 | |
| 
 | |
| 
 | |
| def prune_allocations():
 | |
|     now = time.time()
 | |
|     keys_to_delete = []
 | |
|     for key, allocation in allocations.items():
 | |
|         if allocation.expiry < now:
 | |
|             print("Allocation expired: {}".format(key))
 | |
|             keys_to_delete.append(key)
 | |
|             allocation.close()
 | |
| 
 | |
|     for key in keys_to_delete:
 | |
|         del allocations[key]
 | |
| 
 | |
| 
 | |
| CERT_FILE = "selfsigned.crt"
 | |
| KEY_FILE = "private.key"
 | |
| 
 | |
| 
 | |
| def create_self_signed_cert(name):
 | |
|     from OpenSSL import crypto
 | |
| 
 | |
|     if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
 | |
|         return
 | |
| 
 | |
|     # create a key pair
 | |
|     k = crypto.PKey()
 | |
|     k.generate_key(crypto.TYPE_RSA, 1024)
 | |
| 
 | |
|     # create a self-signed cert
 | |
|     cert = crypto.X509()
 | |
|     cert.get_subject().C = "US"
 | |
|     cert.get_subject().ST = "TX"
 | |
|     cert.get_subject().L = "Dallas"
 | |
|     cert.get_subject().O = "Mozilla test iceserver"
 | |
|     cert.get_subject().OU = "Mozilla test iceserver"
 | |
|     cert.get_subject().CN = name
 | |
|     cert.set_serial_number(1000)
 | |
|     cert.gmtime_adj_notBefore(0)
 | |
|     cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
 | |
|     cert.set_issuer(cert.get_subject())
 | |
|     cert.set_pubkey(k)
 | |
|     cert.add_extensions(
 | |
|         [crypto.X509Extension(b"subjectAltName", False, f"DNS:{name}".encode())]
 | |
|     )
 | |
|     cert.sign(k, "sha1")
 | |
| 
 | |
|     open(CERT_FILE, "wb").write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
 | |
|     open(KEY_FILE, "wb").write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     random.seed()
 | |
| 
 | |
|     if platform.system() == "Windows":
 | |
|         # Windows is finicky about allowing real interfaces to talk to loopback.
 | |
|         interface_4 = v4_address
 | |
|         interface_6 = v6_address
 | |
|         hostname = socket.gethostname()
 | |
|     else:
 | |
|         # Our linux builders do not have a hostname that resolves to the real
 | |
|         # interface.
 | |
|         interface_4 = "127.0.0.1"
 | |
|         interface_6 = "::1"
 | |
|         hostname = "localhost"
 | |
| 
 | |
|     reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_4)
 | |
|     reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_4)
 | |
| 
 | |
|     reactor.listenUDP(
 | |
|         TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_4
 | |
|     )
 | |
|     reactor.listenTCP(
 | |
|         TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_4
 | |
|     )
 | |
| 
 | |
|     try:
 | |
|         reactor.listenUDP(STUN_PORT, UdpStunHandler(), interface=interface_6)
 | |
|         reactor.listenTCP(STUN_PORT, TcpStunHandlerFactory(), interface=interface_6)
 | |
| 
 | |
|         reactor.listenUDP(
 | |
|             TURN_REDIRECT_PORT, UdpStunRedirectHandler(), interface=interface_6
 | |
|         )
 | |
|         reactor.listenTCP(
 | |
|             TURN_REDIRECT_PORT, TcpStunRedirectHandlerFactory(), interface=interface_6
 | |
|         )
 | |
|     except:
 | |
|         pass
 | |
| 
 | |
|     try:
 | |
|         from twisted.internet import ssl
 | |
|         from OpenSSL import SSL
 | |
| 
 | |
|         create_self_signed_cert(hostname)
 | |
|         tls_context_factory = ssl.DefaultOpenSSLContextFactory(
 | |
|             KEY_FILE, CERT_FILE, SSL.TLSv1_2_METHOD
 | |
|         )
 | |
|         reactor.listenSSL(
 | |
|             STUNS_PORT,
 | |
|             TcpStunHandlerFactory(),
 | |
|             tls_context_factory,
 | |
|             interface=interface_4,
 | |
|         )
 | |
| 
 | |
|         try:
 | |
|             reactor.listenSSL(
 | |
|                 STUNS_PORT,
 | |
|                 TcpStunHandlerFactory(),
 | |
|                 tls_context_factory,
 | |
|                 interface=interface_6,
 | |
|             )
 | |
| 
 | |
|             reactor.listenSSL(
 | |
|                 TURNS_REDIRECT_PORT,
 | |
|                 TcpStunRedirectHandlerFactory(),
 | |
|                 tls_context_factory,
 | |
|                 interface=interface_6,
 | |
|             )
 | |
|         except:
 | |
|             pass
 | |
| 
 | |
|         f = open(CERT_FILE, "r")
 | |
|         lines = f.readlines()
 | |
|         lines.pop(0)  # Remove BEGIN CERTIFICATE
 | |
|         lines.pop()  # Remove END CERTIFICATE
 | |
|         # pylint --py3k: W1636 W1649
 | |
|         lines = list(map(str.strip, lines))
 | |
|         certbase64 = "".join(lines)  # pylint --py3k: W1649
 | |
| 
 | |
|         turns_url = ', "turns:' + hostname + '"'
 | |
|         cert_prop = ', "cert":"' + certbase64 + '"'
 | |
|     except:
 | |
|         turns_url = ""
 | |
|         cert_prop = ""
 | |
|         pass
 | |
| 
 | |
|     allocation_pruner = LoopingCall(prune_allocations)
 | |
|     allocation_pruner.start(1)
 | |
| 
 | |
|     template = Template(
 | |
|         '[\
 | |
| {"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
 | |
| {"username":"$user","credential":"$pwd","turn_redirect_port":"$TURN_REDIRECT_PORT","turns_redirect_port":"$TURNS_REDIRECT_PORT","urls": \
 | |
| ["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
 | |
| $cert_prop}]'  # Hack to make it easier to override cert checks
 | |
|     )
 | |
| 
 | |
|     print(
 | |
|         template.substitute(
 | |
|             user=turn_user,
 | |
|             pwd=turn_pass,
 | |
|             hostname=hostname,
 | |
|             turns_url=turns_url,
 | |
|             cert_prop=cert_prop,
 | |
|             TURN_REDIRECT_PORT=TURN_REDIRECT_PORT,
 | |
|             TURNS_REDIRECT_PORT=TURNS_REDIRECT_PORT,
 | |
|         )
 | |
|     )
 | |
| 
 | |
|     reactor.run()
 | 
