safe progress

This commit is contained in:
Istvan Ruzman
2020-08-06 18:04:24 +02:00
parent 3254bc66e0
commit fd16436c3c
53 changed files with 2167 additions and 4589 deletions

46
src/pyrad3/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
"""Python RADIUS client code.
pyrad is an implementation of a RADIUS client as described in RFC2865.
It takes care of all the details like building RADIUS packets, sending
them and decoding responses.
Here is an example of doing a authentication request::
import pyrad.packet
from pyrad.client import Client
from pyrad.dictionary import Dictionary
srv = Client(server="radius.my.domain", secret="s3cr3t",
dict = Dictionary("dicts/dictionary", "dictionary.acc"))
req = srv.CreatePacket(code=pyrad.packet.AccessRequest,
User_Name = "wichert", NAS_Identifier="localhost")
req["User-Password"] = req.PwCrypt("password")
reply = srv.SendPacket(req)
if reply.code = =pyrad.packet.AccessAccept:
print "access accepted"
else:
print "access denied"
print "Attributes returned by server:"
for key, value in reply.items():
print f'{key}: {value}')
This package contains four modules:
- client: RADIUS client code
- dictionary: RADIUS attribute dictionary
- packet: a RADIUS packet as send to/from servers
- tools: utility functions
"""
__docformat__ = 'epytext en'
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
__copyright__ = 'Copyright 2020 Istvan Ruzman'
__version__ = '0.1.0'
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']

39
src/pyrad3/bidict.py Normal file
View File

@@ -0,0 +1,39 @@
# bidict.py
#
# Bidirectional map
class BiDict:
def __init__(self):
self.forward = {}
self.backward = {}
def add(self, one, two):
self.forward[one] = two
self.backward[two] = one
def __len__(self):
return len(self.forward)
def __getitem__(self, key):
return self.get_forward(key)
def __delitem__(self, key):
try:
del self.backward[self.forward[key]]
del self.forward[key]
except KeyError:
del self.forward[self.backward[key]]
del self.backward[key]
def get_forward(self, key):
return self.forward[key]
def has_forward(self, key):
return key in self.forward
def get_backward(self, key):
return self.backward[key]
def has_backward(self, key):
return key in self.backward

131
src/pyrad3/client.py Normal file
View File

@@ -0,0 +1,131 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Implementation of a simple but extensible RADIUS Client"""
from typing import Optional, Union
from ipaddress import IPv4Address, IPv6Address
import select
import socket
import time
import pyrad3.packet as P
from pyrad3 import host
SUPPORTED_SEND_TYPES = [
P.Code.AccessRequest,
P.Code.AccountingRequest,
P.Code.CoARequest,
]
PACKET_TYPE_PORT_MAPPING = {
P.Code.AccessRequest: "authport",
P.Code.AccountingRequest: "acctport",
P.Code.CoARequest: "coaport",
}
class Timeout(Exception):
"""Exception for wait timeouts"""
class UnsupportedPacketType(Exception):
"""Exception for received packets"""
class Client(host.Host):
"""A simple and extensible RADIUS Client."""
def __init__(
self,
server: Union[str, IPv4Address, IPv6Address],
secret: bytes,
radius_dictionary: dict,
interface: Optional[str],
**kwargs,
):
super().__init__(secret, radius_dictionary, **kwargs)
self.server = server
self.interface = interface
self._socket: Optional[socket.socket] = None
self._poll: Optional[select.poll] = None
def bind(self, addr):
"""Bind the Address to some socket"""
self._socket_close()
self._socket_open()
self._socket.bind(addr)
def _socket_open(self):
"""Open a client socket"""
if self._socket is not None:
return
try:
family = socket.getaddrinfo(self.server, "www")[0][0]
except socket.gaierror:
family = socket.AF_INET
self._socket = socket.socket(family, socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if self.interface is not None:
# This will fail on non-Linux systems - that's ok, because we don't
# have any implementation yet for non-Linux systems.
# Better to fail loudly than to do something unexpected silently.
self._socket.setsockopt(
socket.SOL_SOCKET,
socket.SO_BINDTODEVICE,
self.interface,
len(self.interface),
)
self._poll = select.poll()
self._poll.register(self._socket, select.POLLIN)
def _socket_close(self):
"""Close the Client socket"""
if self._socket is not None:
self._poll.unregister(self._socket)
self._socket.close()
self._socket = None
def _select_port(self, packet: P.Packet):
"""Select the RADIUS Port depending on the RADIUS Packet type"""
try:
port_type = PACKET_TYPE_PORT_MAPPING[packet.code]
return getattr(self, port_type)
except (AttributeError, KeyError):
pass
raise UnsupportedPacketType(f"The packet type {packet.code} by Client")
def _send_packet(self, packet: P.Packet):
"""Send a RADIUS Packet and wait for the reply"""
assert self._socket is not None
assert self._poll is not None
port = self._select_port(packet)
raw_packet = packet.serialize()
for _attempt in range(self.retries):
now = time.time()
waitto = now + self.timeout
self._socket.sendto(raw_packet, (self.server, port))
while now < waitto:
if not self._poll.poll((waitto - now) * 1000):
# socket is not ready for some reason
now = time.time()
continue
rawreply = self._socket.recv(4096)
try:
return packet.verify_reply(rawreply)
except P.PacketError:
pass
# timed out: try the next attempt after increasing the acct delay time
if packet.code == packet.AccountingRequest:
packet.increase_acct_delay_time(self.timeout)
raw_packet = packet.serialize()
raise Timeout

442
src/pyrad3/client_async.py Normal file
View File

@@ -0,0 +1,442 @@
# client_async.py
#
# Copyright 2018-2020 Geaaru <geaaru<@>gmail.com>
__docformat__ = "epytext en"
from datetime import datetime
import asyncio
import logging
import random
from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
class DatagramProtocolClient(asyncio.Protocol):
def __init__(self, server, port, logger, client, retries=3, timeout=30):
self.transport = None
self.port = port
self.server = server
self.logger = logger
self.retries = retries
self.timeout = timeout
self.client = client
# Map of pending requests
self.pending_requests = {}
# Use cryptographic-safe random generator as provided by the OS.
random_generator = random.SystemRandom()
self.packet_id = random_generator.randrange(0, 256)
self.timeout_future = None
async def __timeout_handler__(self):
try:
while True:
req2delete = []
now = datetime.now()
next_weak_up = self.timeout
# noinspection PyShadowingBuiltins
for id, req in self.pending_requests.items():
secs = (req["send_date"] - now).seconds
if secs > self.timeout:
if req["retries"] == self.retries:
self.logger.debug(
"[%s:%d] For request %d execute all retries",
self.server,
self.port,
id,
)
req["future"].set_exception(
TimeoutError("Timeout on Reply")
)
req2delete.append(id)
else:
# Send again packet
req["send_date"] = now
req["retries"] += 1
self.logger.debug(
"[%s:%d] For request %d execute retry %d",
self.server,
self.port,
id,
req["retries"],
)
self.transport.sendto(req["packet"].RequestPacket())
elif next_weak_up > secs:
next_weak_up = secs
# noinspection PyShadowingBuiltins
for id in req2delete:
# Remove request for map
del self.pending_requests[id]
await asyncio.sleep(next_weak_up)
except asyncio.CancelledError:
pass
def send_packet(self, packet, future):
if packet.id in self.pending_requests:
raise Exception(f"Packet with id {packet.id} already present")
# Store packet on pending requests map
self.pending_requests[packet.id] = {
"packet": packet,
"creation_date": datetime.now(),
"retries": 0,
"future": future,
"send_date": datetime.now(),
}
# In queue packet raw on socket buffer
self.transport.sendto(packet.RequestPacket())
def connection_made(self, transport):
self.transport = transport
socket = transport.get_extra_info("socket")
self.logger.info(
"[%s:%d] Transport created with binding in %s:%d",
self.server,
self.port,
socket.getsockname()[0],
socket.getsockname()[1],
)
pre_loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop=self.client.loop)
# Start asynchronous timer handler
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
asyncio.set_event_loop(loop=pre_loop)
def error_received(self, exc):
self.logger.error(
"[%s:%d] Error received: %s", self.server, self.port, exc
)
def connection_lost(self, exc):
if exc:
self.logger.warn(
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
)
else:
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
# noinspection PyUnusedLocal
def datagram_received(self, data, addr):
try:
req = self.pending_requests[data[0]]
reply = req.VerifyPacket(data)
req["future"].set_result(reply)
# Remove request for map
del self.pending_requests[reply.id]
except KeyError:
self.logger.warn(
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
)
except PacketError as exc:
self.logger.error(
"[%s:%d] Error on decode or verify packet: %s",
self.server,
self.port,
exc,
)
async def close_transport(self):
if self.transport:
self.logger.debug(
"[%s:%d] Closing transport...", self.server, self.port
)
self.transport.close()
self.transport = None
if self.timeout_future:
self.timeout_future.cancel()
await self.timeout_future
self.timeout_future = None
def create_id(self):
self.packet_id = (self.packet_id + 1) % 256
return self.packet_id
def __str__(self):
return (
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
)
# Used as protocol_factory
def __call__(self):
return self
class ClientAsync:
"""Basic RADIUS client.
This class implements a basic RADIUS client. It can send requests
to a RADIUS server, taking care of timeouts and retries, and
validate its replies.
:ivar retries: number of times to retry sending a RADIUS request
:type retries: integer
:ivar timeout: number of seconds to wait for an answer
:type timeout: integer
"""
# noinspection PyShadowingBuiltins
def __init__(
self,
server,
auth_port=1812,
acct_port=1813,
coa_port=3799,
secret=b"",
dict=None,
loop=None,
retries=3,
timeout=30,
logger_name="pyrad",
):
"""Constructor.
:param server: hostname or IP address of RADIUS server
:type server: string
:param auth_port: port to use for authentication packets
:type auth_port: integer
:param acct_port: port to use for accounting packets
:type acct_port: integer
:param coa_port: port to use for CoA packets
:type coa_port: integer
:param secret: RADIUS secret
:type secret: string
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary
:param loop: Python loop handler
:type loop: asyncio event loop
"""
if not loop:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.logger = logging.getLogger(logger_name)
self.server = server
self.secret = secret
self.retries = retries
self.timeout = timeout
self.dict = dict
self.auth_port = auth_port
self.protocol_auth = None
self.acct_port = acct_port
self.protocol_acct = None
self.protocol_coa = None
self.coa_port = coa_port
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
local_addr=None,
local_auth_port=None,
local_acct_port=None,
local_coa_port=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception("No transports selected")
if enable_acct and not self.protocol_acct:
self.protocol_acct = DatagramProtocolClient(
self.server,
self.acct_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_acct_port:
bind_addr = (local_addr, local_acct_port)
acct_connect = self.loop.create_datagram_endpoint(
self.protocol_acct,
reuse_port=True,
remote_addr=(self.server, self.acct_port),
local_addr=bind_addr,
)
task_list.append(acct_connect)
if enable_auth and not self.protocol_auth:
self.protocol_auth = DatagramProtocolClient(
self.server,
self.auth_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_auth_port:
bind_addr = (local_addr, local_auth_port)
auth_connect = self.loop.create_datagram_endpoint(
self.protocol_auth,
reuse_port=True,
remote_addr=(self.server, self.auth_port),
local_addr=bind_addr,
)
task_list.append(auth_connect)
if enable_coa and not self.protocol_coa:
self.protocol_coa = DatagramProtocolClient(
self.server,
self.coa_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_coa_port:
bind_addr = (local_addr, local_coa_port)
coa_connect = self.loop.create_datagram_endpoint(
self.protocol_coa,
reuse_port=True,
remote_addr=(self.server, self.coa_port),
local_addr=bind_addr,
)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if self.protocol_coa and deinit_coa:
await self.protocol_coa.close_transport()
del self.protocol_coa
self.protocol_coa = None
if self.protocol_auth and deinit_auth:
await self.protocol_auth.close_transport()
del self.protocol_auth
self.protocol_auth = None
if self.protocol_acct and deinit_acct:
await self.protocol_acct.close_transport()
del self.protocol_acct
self.protocol_acct = None
# noinspection PyPep8Naming
def CreateAuthPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_auth:
raise Exception("Transport not initialized")
return AuthPacket(
dict=self.dict,
id=self.protocol_auth.create_id(),
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateAcctPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_acct:
raise Exception("Transport not initialized")
return AcctPacket(
id=self.protocol_acct.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateCoAPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_acct:
raise Exception("Transport not initialized")
return CoAPacket(
id=self.protocol_coa.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
# noinspection PyShadowingBuiltins
def CreatePacket(self, id, **args):
if not id:
raise Exception("Missing mandatory packet id")
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
# noinspection PyPep8Naming
def SendPacket(self, pkt):
"""Send a packet to a RADIUS server.
:param pkt: the packet to send
:type pkt: pyrad.packet.Packet
:return: Future related with packet to send
:rtype: asyncio.Future
"""
ans = asyncio.Future(loop=self.loop)
if isinstance(pkt, AuthPacket):
if not self.protocol_auth:
raise Exception("Transport not initialized")
self.protocol_auth.send_packet(pkt, ans)
elif isinstance(pkt, AcctPacket):
if not self.protocol_acct:
raise Exception("Transport not initialized")
self.protocol_acct.send_packet(pkt, ans)
elif isinstance(pkt, CoAPacket):
if not self.protocol_coa:
raise Exception("Transport not initialized")
self.protocol_coa.send_packet(pkt, ans)
else:
raise Exception("Unsupported packet")
return ans

483
src/pyrad3/dictionary.py Normal file
View File

@@ -0,0 +1,483 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""RADIUS Dictionary.
Classes and Types to parse and represent a RADIUS dictionary.
"""
from enum import IntEnum, Enum, auto
from dataclasses import dataclass
from os.path import dirname, isabs, join, normpath
from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union
import logging
LOG = logging.getLogger(__name__)
INTEGER_TYPES = {
"byte": (0, 255),
"short": (0, 2 ** 16 - 1),
"signed": (-(2 ** 31), 2 ** 31 - 1),
"integer": (0, 2 ** 32 - 1),
"integer64": (0, 2 ** 64 - 1),
}
class Datatype(Enum):
"""Possible Datatypes for ATTRIBUTES"""
string = auto()
octets = auto()
date = auto()
abinary = auto()
byte = auto()
short = auto()
integer = auto()
signed = auto()
integer64 = auto()
ipaddr = auto()
ipv4prefix = auto()
ipv6addr = auto()
ipv6prefix = auto()
comboip = auto()
ifid = auto()
ether = auto()
concat = auto()
tlv = auto()
extended = auto()
longextended = auto()
evs = auto()
class ParseError(Exception):
"""RADIUS Dictionary Parser Error"""
def __init__(
self,
filename: str,
msg: str = None,
line: Optional[int] = None,
**data,
):
super().__init__()
self.msg = msg
self.file = filename
self.line = line
self.data = data
def __str__(self):
line = f"({self.line}" if self.line is not None else ""
return f"{self.file}{line}: ParseError: {self.msg}"
class Encrypt(IntEnum):
"""Enum for different RADIUS Encryption types."""
NoEncrpytion = 0
RadiusCrypt = 1
SaltCrypt = 2
AscendCrypt = 3
@dataclass
class Attribute: # pylint: disable=too-many-instance-attributes
"""RADIUS Attribute definition"""
name: str
code: int
datatype: Datatype
has_tag: bool = False
encrypt: Encrypt = Encrypt(0)
is_sub_attr: bool = False
# vendor = Dictionary
values: Dict[Union[int, str], Union[int, str]] = None
@dataclass
class Vendor:
"""Representation of a vendor"""
name: str
code: int
tlength: int
llength: int
continuation: bool
attrs: Dict[Union[int, Tuple[int, ...]], Attribute]
def dict_parser(
filename: str, rad_dict: IO
) -> Generator[Tuple[int, List[str]], None, None]:
"""Tokenstream of RADIUS Dictionary files
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
and "FILE_CLOSED" tokens will be emitted.
"""
yield (-1, ["FILE_OPENED", filename])
for line_num, line in enumerate(rad_dict.readlines()):
tokens = line.split("#", 1)[0].strip().split()
if tokens:
first_tok = tokens[0] = tokens[0].upper()
if first_tok == "$INCLUDE":
try:
inner_filename = tokens[1]
except IndexError:
raise ParseError(
filename, "$INCLUDE is missing a filename", line_num,
)
if not isabs(tokens[1]):
path = dirname(filename)
inner_filename = normpath(join(path, inner_filename))
yield from dict_loader(inner_filename)
yield (line_num, tokens)
yield (-1, ["FILE_CLOSED"])
def dict_loader(filename: str) -> Generator[Tuple[int, List[str]], None, None]:
"""Tokenstream of RADIUS Dictionary files
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
and "FILE_CLOSED" tokens will be emitted.
"""
with open(filename, "r") as rad_dict:
yield from dict_parser(filename, rad_dict)
def _parse_number(num: str) -> int:
"""Parse a number from (Free)RADIUS dictionaries
Numbers can be either decimal, octal, or hexadecimal.
"""
if num.startswith("0x"):
return int(num, 16)
if num.startswith("0o"):
return int(num, 8)
return int(num)
def _parse_attribute_code(attr_code: str) -> List[int]:
"""Parse attribute codes from (Free)RADIUS dictionaries
Codes can be either decimal, octal, or hexadecimal.
TLV typed can
"""
codes = []
for code in attr_code.split("."):
codes.append(_parse_number(code))
return codes
class Dictionary:
"""(Free)RADIUS Dictionary.
#TODO: Better documentation
This dictionary can "contain" multiple dictionaries.
"""
# there must be some nicer way to unittest this...
def __init__(self, dictionary: str, __dictio: Optional[IO] = None):
self.vendor: Dict[int, Vendor] = {}
self.vendor_lookup_id_by_name: Dict[str, int] = {}
self.attrindex: Dict[Union[int, str], Attribute] = {}
self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {})
self.cur_vendor = self.rfc_vendor
if __dictio is not None:
loader = dict_parser(dictionary, __dictio)
else:
loader = dict_loader(dictionary)
self.read_dictionary(loader)
def read_dictionary(
self, reader: Generator[Tuple[int, List[str]], None, None]
):
"""Read and parse a (Free)RADIUS dictionary."""
self.filestack = []
for line_num, tokens in reader:
key = tokens[0]
if key == "ATTRIBUTE":
self._parse_attribute(tokens, line_num)
elif key == "VALUE":
self._parse_value(tokens, line_num)
elif key == "FILE_OPENED":
LOG.info("Parsing file: %s", tokens[1])
if tokens[1] in self.filestack:
raise ParseError(
self.filestack[-1], "Include recursion detected"
)
self.filestack.append(tokens[1])
elif key == "FILE_CLOSED":
filename = self.filestack.pop()
LOG.info("Finished parsing file: %s", filename)
elif key == "VENDOR":
self._parse_vendor(tokens, line_num)
elif key == "BEGIN-VENDOR":
self._parse_begin_vendor(tokens, line_num)
elif key == "END-VENDOR":
self._parse_end_vendor(tokens, line_num)
elif key == "BEGIN-TLV":
raise NotImplementedError(
"BEGIN-TLV is deprecated and not supported by pyrad3"
)
elif key == "END-TLV":
raise NotImplementedError(
"END-TLV is deprecated and not supported by pyrad3"
)
else:
raise ParseError(
self.filestack[-1], f"Invalid Token key {key}", line_num
)
def _parse_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the vendor definition"""
filename = self.filestack[-1]
if len(tokens) not in {3, 4}:
raise ParseError(
filename, "Incorrect number of tokens for vendor statement"
)
vendor_name = tokens[1]
vendor_id = int(tokens[2], 0)
continuation = False
# Parse optional vendor specification
try:
vendor_format = tokens[3].split("=")
if vendor_format[0] != "format":
raise ParseError(
filename,
f"Unknown option {vendor_format[0]} for vendor definition",
line_num,
)
try:
vendor_format = vendor_format[1].split(",")
t_len, l_len = (int(a) for a in vendor_format[:2])
if t_len not in {1, 2, 4}:
raise ParseError(
filename,
f'Invalid type length definition "{t_len}" for vendor {vendor_name}',
line_num,
)
if l_len not in {0, 1, 2}:
raise ParseError(
filename,
f'Invalid length definition "{l_len}" for vendor {vendor_name}',
line_num,
)
try:
if vendor_format[2] == "c":
if not vendor_name.upper() == "WIMAX":
# Not sure why, but FreeRADIUS has this limit,
# so we just do the same cause they know better than me
raise ParseError(
filename,
"continuation-bit is only supported for WiMAX",
line_num,
)
continuation = True
except IndexError:
pass
except ValueError:
raise ParseError(
filename,
f"Syntax error in specification for vendor {vendor_name}",
line_num,
)
except IndexError:
# no format definition
t_len, l_len = 1, 1
vendor = Vendor(vendor_name, vendor_id, t_len, l_len, continuation, {})
self.vendor_lookup_id_by_name[vendor_name] = vendor_id
self.vendor[vendor_id] = vendor
def _parse_begin_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the BEGIN-VENDOR line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if self.cur_vendor != self.rfc_vendor:
raise ParseError(
filename,
"vendor-begin sections are not allowed to be nested",
line_num,
)
if len(tokens) != 2:
raise ParseError(
filename,
"Incorrect number of tokens for begin-vendor statement",
line_num,
)
try:
vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
self.cur_vendor = self.vendor[vendor_id]
except KeyError:
raise ParseError(
filename,
f"Unknown vendor {tokens[1]} in begin-vendor statement",
line_num,
)
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if len(tokens) != 2:
raise ParseError(
filename,
"Incorrect number of tokens for end-vendor statement",
line_num,
)
if self.cur_vendor.name != tokens[1]:
raise ParseError(
filename,
f"Closing non-opened vendor {tokens[1]} in end-vendor statement",
line_num,
)
self.cur_vendor = self.rfc_vendor
def _parse_attribute_flags(
self, tokens: Sequence[str], line_num: int
) -> Tuple[bool, Encrypt]:
"""Parse Attribute flags of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
has_tag = False
encrypt = Encrypt.NoEncrpytion
try:
flags = [flag.split("=") for flag in tokens[4].split(",")]
except IndexError:
return False, Encrypt.NoEncrpytion
for flag in flags:
flag_len = len(flag)
if flag == 1:
value = None
elif flag_len == 2:
value = flag[1]
else:
raise ParseError(
filename, f"Incorrect attribute flag {flag}", line_num
)
key = flag[0]
if key == "has_tag":
has_tag = True
elif key == "encrypt":
try:
encrypt = Encrypt(int(value)) # type: ignore
except (ValueError, TypeError):
raise ParseError(
filename,
f"Illegal attribute encryption {value}",
line_num,
)
else:
raise ParseError(
filename, "Unknown attribute flag {key}", line_num
)
return has_tag, encrypt
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if not len(tokens) in {4, 5}:
raise ParseError(
filename,
"Incorrect number of tokens for attribute definition",
line_num,
)
has_tag, encrypt = self._parse_attribute_flags(tokens, line_num)
name, code, datatype = tokens[1:4]
if datatype == "concat" and self.cur_vendor != self.rfc_vendor:
raise ParseError(
filename,
'vendor attributes are not allowed to have the datatype "concat"',
line_num,
)
try:
codes = _parse_attribute_code(code)
except ValueError:
raise ParseError(
filename, f'invalid attribute code {code}""', line_num
)
for code in codes:
tlength = self.cur_vendor.tlength
if 2 ** (8 * tlength) <= code:
raise ParseError(
filename,
f"attribute code is too big, must be smaller than 2**{tlength}",
line_num,
)
if code < 0:
raise ParseError(
filename,
"negative attribute codes are not allowed",
line_num,
)
# TODO: Do we some explicit handling of tlvs?
# if len(codes) > 1:
# self._parse_attribute_tlv(codes, line_num)
# else:
# pass
base_datatype = datatype.split("[")[0].replace("-", "")
try:
attribute_type = Datatype[base_datatype]
except KeyError:
raise ParseError(filename, f"Illegal type: {datatype}", line_num)
attribute = Attribute(
name,
codes[-1],
attribute_type,
has_tag,
encrypt,
len(codes) > 1,
{},
)
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
self.cur_vendor.attrs[attrcode] = attribute
if self.cur_vendor != self.rfc_vendor:
codes = tuple([26] + codes)
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
self.attrindex[attrcode] = attribute
self.attrindex[name] = attribute
def _parse_value(self, tokens: Sequence[str], line_num: int):
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if len(tokens) != 4:
raise ParseError(
filename,
"Incorrect number of tokens for VALUE definition",
line_num,
)
(attr_name, key, value) = tokens[1:]
value = _parse_number(value)
attribute = self.attrindex[attr_name]
try:
datatype = str(attribute.datatype).split(".")[1]
lmin, lmax = INTEGER_TYPES[datatype]
if value < lmin or value > lmax:
raise ParseError(
filename,
f"VALUE {key}({value}) is not in the limit of type {datatype}",
line_num,
)
except KeyError:
raise ParseError(
filename,
f"only attributes with integer typed datatypes can have"
f"value definitions {attribute.datatype}",
line_num,
)
attribute.values[value] = key
attribute.values[key] = value

46
src/pyrad3/host.py Normal file
View File

@@ -0,0 +1,46 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Interface Class for RADIUS Clients and Servers"""
from pyrad3.dictionary import Dictionary
from pyrad3 import packet
class Host: # pylint: disable=too-many-arguments
"""Interface Class for RADIUS Clients and Servers"""
def __init__(
self,
secret: bytes,
radius_dict: Dictionary,
authport: int = 1812,
acctport: int = 1813,
coaport: int = 3799,
timeout: float = 30,
retries: int = 3,
):
self.secret = secret
self.dictionary = radius_dict
self.authport = authport
self.acctport = acctport
self.coaport = coaport
self.timeout = timeout
self.retries = retries
def create_packet(self, **kwargs):
"""Create a generic RADIUS Packet"""
return packet.Packet(self, **kwargs)
def create_auth_packet(self, **kwargs):
"""Create an Authentictaion packet (request per default)"""
return packet.AuthPacket(self, **kwargs)
def create_acct_packet(self, **kwargs):
"""Create an Accounting packet (request per default)"""
return packet.AcctPacket(self, **kwargs)
def create_coa_packet(self, **kwargs):
"""Create an CoA packet (requset per default)"""
return packet.CoAPacket(self, **kwargs)

305
src/pyrad3/packet.py Normal file
View File

@@ -0,0 +1,305 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
from collections import OrderedDict
from enum import IntEnum
from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence
import hashlib
import hmac
from pyrad3.host import Host
from pyrad3.utils import (
PacketError,
Attribute,
parse_header,
parse_attributes,
calculate_authenticator,
validate_pap_password,
validate_chap_password,
)
HMAC = hmac.new
# Packet codes
class Code(IntEnum):
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
class AuthError(Exception):
pass
class Packet(OrderedDict):
def __init__(
self,
host: Host,
code: Code,
radius_id: int,
*,
request: "Packet" = None,
**attributes
):
super().__init__(**attributes)
self.code = code
self.id = radius_id
self.host = host
self.request = request
self.ordered_attributes: Sequence[Attribute] = []
self.raw_packet: Optional[bytearray] = None
self.authenticator: Optional[bytes] = None
@staticmethod
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
# Can we do better than Any with type hinting?
attrs: Dict[str, Any] = {}
for attr in ordered_attrs:
try:
attrs[attr.name].append(attr.value)
except KeyError:
attrs[attr.name] = [attr.value]
parsed_packet = Packet(host, code, radius_id, **attrs)
parsed_packet.authenticator = authenticator
parsed_packet.raw_packet = raw_packet
parsed_packet.ordered_attributes = ordered_attrs
return parsed_packet
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
self.verify_reply(raw_packet)
reply = Packet.from_raw(self.host, raw_packet)
reply.request = self
try:
if not reply.validate_message_authenticator():
raise PacketError("Packet has a wrong message authenticator")
except KeyError:
if "EAP-Message" in reply:
raise PacketError("Packet is missing a message authenticator")
return reply
def send(self):
"""Send the packet to the Client/Server.
"""
self.host._send_packet(self)
def verify_reply(self, raw_reply: bytes):
"""Verify the reply to this packet.
"""
if self.id != raw_reply[1]:
raise PacketError("Response has a wrong id")
# self.authenticator MUST be set, this packet got send so by definitation
# self.authenticator will not be non, but bytes
radius_hash = calculate_authenticator(
self.host.secret,
self.authenticator, # type: ignore
raw_reply,
)
if radius_hash != raw_reply[4:20]:
raise PacketError("Reply Packet has a wrong authenticator")
def validate_message_authenticator(self):
message_authenticator = self["Message-Authenticator"]
if isinstance(list, message_authenticator):
# There are multiple Message Authenticators, but a packet MUST NOT have
# more than one
return False
ma_attribute = self.find_first_attribute("Message-Authenticator")
generated = self._generate_message_authenticator(ma_attribute)
return message_authenticator == generated
def _generate_message_authenticator(self, ma_attr: Attribute):
assert self.authenticator is not None
assert self.request is not None
assert self.request.authenticator is not None
assert self.raw_packet is not None
# The message authenticator must be treated as 16 * \00
start_pos = ma_attr.pos + 2
end_pos = start_pos + 16
original_ma: bytes = ma_attr.value
self.raw_packet[start_pos:end_pos] = 16 * b"\00"
hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5)
hmac_builder.update(self.raw_packet)
if self.code in (Code.AccessRequest, Code.StatusServer):
hmac_builder.update(self.authenticator)
elif self.code in (
Code.AccessAccept,
Code.AccessChallenge,
Code.AccessReject,
):
hmac_builder.update(self.request.authenticator)
else:
hmac_builder.update(16 * b"\00")
hmac_builder.update(self.raw_packet[20:])
self.raw_packet[start_pos:end_pos] = original_ma
return hmac_builder.digest()
def add_message_authenticator(self):
self._encode_packet()
self._generate_message_authenticator(self)
try:
# quick lookup before we iterate over the whole packet
_ = self["Message-Authenticator"]
attr = self.find_first_attribute("Message-Authenticator")
except KeyError:
self["Message-Authenticator"] = 16 * b"\00"
attr = self.ordered_attributes[-1]
generated = self._generate_message_authenticator(attr)
self[attr.pos + 2 :] = generated
def refresh_message_authenticator(self):
self.add_message_authenticator()
def find_first_attribute(self, attr_type_name: str) -> Attribute:
for attr in self.ordered_attributes:
if attr.type == attr_type_name:
return attr.type
raise KeyError
def _encode_packet(self):
self.raw_packet = None
class AuthPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
auth_type,
*,
code: Code = Code.AccessRequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
self.auth_type = auth_type
if code == Code.AccessRequest:
self.authenticator = token_bytes(16)
def create_accept(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessAccept,
**attributes
)
def create_reject(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessReject,
**attributes
)
def create_challange(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessChallenge,
**attributes
)
def validate_password(self, password: bytes) -> bool:
try:
return self.validate_pap(password)
except KeyError:
pass
# Will throw KeyError if no chap password exists
return self.validate_chap(password)
def validate_pap(self, password: bytes) -> bool:
packet_password = self["User-Password"]
return validate_pap_password(
self.host.secret,
self.authenticator, # type: ignore
packet_password,
password,
)
def validate_chap(self, password: bytes) -> bool:
packet_password = self["Chap-Password"]
chap_id = packet_password[:1]
chap_password = packet_password[1:]
try:
challenge = self["Chap-Challenge"]
except KeyError:
challenge = self.authenticator
return validate_chap_password(
chap_id, challenge, chap_password, password, # type: ignore
)
class AcctPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
*,
code: Code = Code.AccountingRequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_response(self, **attributes):
return AcctPacket(
self.host,
self.id,
code=Code.AccountingResponse,
request=self,
**attributes
)
class CoAPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
*,
code: Code = Code.CoARequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_ack(self, **attributes):
return CoAPacket(
self.host, self.id, code=Code.CoAACK, request=self, **attributes
)
def create_nack(self, **attributes):
return CoAPacket(
self.host, self.id, code=Code.CoANAK, request=self, **attributes
)

73
src/pyrad3/proxy.py Normal file
View File

@@ -0,0 +1,73 @@
# proxy.py
#
# Copyright 2005,2007 Wichert Akkerman <wichert@wiggy.net>
#
# A RADIUS proxy as defined in RFC 2138
import select
import socket
from pyrad.server import Server, ServerPacketError
from pyrad import packet
class Proxy(Server):
"""Base class for RADIUS proxies.
This class extends tha RADIUS server class with the capability to
handle communication with other RADIUS servers as well.
:ivar _proxyfd: network socket used to communicate with other servers
:type _proxyfd: socket class instance
"""
def _prepare_sockets(self):
Server._prepare_sockets(self)
self._proxyfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
self._poll.register(
self._proxyfd.fileno(),
(select.POLLIN | select.POLLPRI | select.POLLERR),
)
def _handle_proxy_packet(self, pkt):
"""Process a packet received on the reply socket.
If this packet should be dropped instead of processed a
:obj:`ServerPacketError` exception should be raised. The main loop
will drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
if pkt.source[0] not in self.hosts:
raise ServerPacketError("Received packet from unknown host")
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code not in [
packet.AccessAccept,
packet.AccessReject,
packet.AccountingResponse,
]:
raise ServerPacketError("Received non-response on proxy socket")
def _process_input(self, fd):
"""Process available data.
If this packet should be dropped instead of processed a
`ServerPacketError` exception should be raised. The main loop
will drop the packet and log the reason.
This function calls either :obj:`HandleAuthPacket`,
:obj:`HandleAcctPacket` or :obj:`_handle_proxy_packet` depending on
which socket is being processed.
:param fd: socket to read packet from
:type fd: socket class instance
:param pkt: packet to process
:type pkt: Packet class instance
"""
if fd.fileno() == self._proxyfd.fileno():
pkt = self._grab_packet(
lambda data, s=self: s.CreatePacket(packet=data), fd
)
self._handle_proxy_packet(pkt)
else:
Server._process_input(self, fd)

367
src/pyrad3/server.py Normal file
View File

@@ -0,0 +1,367 @@
# server.py
#
# Copyright 2003-2004,2007,2016 Wichert Akkerman <wichert@wiggy.net>
import logging
import select
import socket
from pyrad import host
from pyrad import packet
LOGGER = logging.getLogger("pyrad")
class RemoteHost:
"""Remote RADIUS capable host we can talk to."""
def __init__(
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
):
"""Constructor.
:param address: IP address
:type address: string
:param secret: RADIUS secret
:type secret: string
:param name: short name (used for logging only)
:type name: string
:param authport: port used for authentication packets
:type authport: integer
:param acctport: port used for accounting packets
:type acctport: integer
:param coaport: port used for CoA packets
:type coaport: integer
"""
self.address = address
self.secret = secret
self.authport = authport
self.acctport = acctport
self.coaport = coaport
self.name = name
class ServerPacketError(Exception):
"""Exception class for bogus packets.
ServerPacketError exceptions are only used inside the Server class to
abort processing of a packet.
"""
class Server(host.Host):
"""Basic RADIUS server.
This class implements the basics of a RADIUS server. It takes care
of the details of receiving and decoding requests; processing of
the requests should be done by overloading the appropriate methods
in derived classes.
:ivar hosts: hosts who are allowed to talk to us
:type hosts: dictionary of Host class instances
:ivar _poll: poll object for network sockets
:type _poll: select.poll class instance
:ivar _fdmap: map of filedescriptors to network sockets
:type _fdmap: dictionary
:cvar MaxPacketSize: maximum size of a RADIUS packet
:type MaxPacketSize: integer
"""
MaxPacketSize = 4096
def __init__(
self,
addresses=[],
authport=1812,
acctport=1813,
coaport=3799,
hosts=None,
dict=None,
auth_enabled=True,
acct_enabled=True,
coa_enabled=False,
):
"""Constructor.
:param addresses: IP addresses to listen on
:type addresses: sequence of strings
:param authport: port to listen on for authentication packets
:type authport: integer
:param acctport: port to listen on for accounting packets
:type acctport: integer
:param coaport: port to listen on for CoA packets
:type coaport: integer
:param hosts: hosts who we can talk to
:type hosts: dictionary mapping IP to RemoteHost class instances
:param dict: RADIUS dictionary to use
:type dict: Dictionary class instance
:param auth_enabled: enable auth server (default True)
:type auth_enabled: bool
:param acct_enabled: enable accounting server (default True)
:type acct_enabled: bool
:param coa_enabled: enable coa server (default False)
:type coa_enabled: bool
"""
host.Host.__init__(self, authport, acctport, coaport, dict)
if hosts is None:
self.hosts = {}
else:
self.hosts = hosts
self.auth_enabled = auth_enabled
self.authfds = []
self.acct_enabled = acct_enabled
self.acctfds = []
self.coa_enabled = coa_enabled
self.coafds = []
for addr in addresses:
self.BindToAddress(addr)
def _get_addr_info(self, addr):
"""Use getaddrinfo to lookup all addresses for each address.
Returns a list of tuples or an empty list:
[(family, address)]
:param addr: IP address to lookup
:type addr: string
"""
results = set()
try:
tmp = socket.getaddrinfo(addr, "www")
except socket.gaierror:
return []
for el in tmp:
results.add((el[0], el[4][0]))
return results
def BindToAddress(self, addr):
"""Add an address to listen to.
An empty string indicated you want to listen on all addresses.
:param addr: IP address to listen on
:type addr: string
"""
addrFamily = self._get_addr_info(addr)
for (family, address) in addrFamily:
if self.auth_enabled:
authfd = socket.socket(family, socket.SOCK_DGRAM)
authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
authfd.bind((address, self.authport))
self.authfds.append(authfd)
if self.acct_enabled:
acctfd = socket.socket(family, socket.SOCK_DGRAM)
acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
acctfd.bind((address, self.acctport))
self.acctfds.append(acctfd)
if self.coa_enabled:
coafd = socket.socket(family, socket.SOCK_DGRAM)
coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
coafd.bind((address, self.coaport))
self.coafds.append(coafd)
def HandleAuthPacket(self, pkt):
"""Authentication packet handler.
This is an empty function that is called when a valid
authentication packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleAcctPacket(self, pkt):
"""Accounting packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleCoaPacket(self, pkt):
"""CoA packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleDisconnectPacket(self, pkt):
"""CoA packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def _add_secret(self, pkt):
"""Add secret to packets received and raise ServerPacketError
for unknown hosts.
:param pkt: packet to process
:type pkt: Packet class instance
"""
if pkt.source[0] in self.hosts:
pkt.secret = self.hosts[pkt.source[0]].secret
elif "0.0.0.0" in self.hosts:
pkt.secret = self.hosts["0.0.0.0"].secret
else:
raise ServerPacketError("Received packet from unknown host")
def _handle_auth_packet(self, pkt):
"""Process a packet received on the authentication port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
if pkt.code != packet.AccessRequest:
raise ServerPacketError(
"Received non-authentication packet on authentication port"
)
self.HandleAuthPacket(pkt)
def _handle_acct_packet(self, pkt):
"""Process a packet received on the accounting port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
if pkt.code not in [
packet.AccountingRequest,
packet.AccountingResponse,
]:
raise ServerPacketError(
"Received non-accounting packet on accounting port"
)
self.HandleAcctPacket(pkt)
def _handle_coa_packet(self, pkt):
"""Process a packet received on the coa port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code == packet.CoARequest:
self.HandleCoaPacket(pkt)
elif pkt.code == packet.DisconnectRequest:
self.HandleDisconnectPacket(pkt)
else:
raise ServerPacketError("Received non-coa packet on coa port")
def _grab_packet(self, pktgen, fd):
"""Read a packet from a network connection.
This method assumes there is data waiting for to be read.
:param fd: socket to read packet from
:type fd: socket class instance
:return: RADIUS packet
:rtype: Packet class instance
"""
(data, source) = fd.recvfrom(self.MaxPacketSize)
pkt = pktgen(data)
pkt.source = source
pkt.fd = fd
return pkt
def _prepare_sockets(self):
"""Prepare all sockets to receive packets.
"""
for fd in self.authfds + self.acctfds + self.coafds:
self._fdmap[fd.fileno()] = fd
self._poll.register(
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
)
if self.auth_enabled:
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
if self.acct_enabled:
self._realacctfds = list(map(lambda x: x.fileno(), self.acctfds))
if self.coa_enabled:
self._realcoafds = list(map(lambda x: x.fileno(), self.coafds))
def CreateReplyPacket(self, pkt, **attributes):
"""Create a reply packet.
Create a new packet which can be returned as a reply to a received
packet.
:param pkt: original packet
:type pkt: Packet instance
"""
reply = pkt.CreateReply(**attributes)
reply.source = pkt.source
return reply
def _process_input(self, fd):
"""Process available data.
If this packet should be dropped instead of processed a
PacketError exception should be raised. The main loop will
drop the packet and log the reason.
This function calls either HandleAuthPacket() or
HandleAcctPacket() depending on which socket is being
processed.
:param fd: socket to read packet from
:type fd: socket class instance
"""
if self.auth_enabled and fd.fileno() in self._realauthfds:
pkt = self._grab_packet(
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
)
self._handle_auth_packet(pkt)
elif self.acct_enabled and fd.fileno() in self._realacctfds:
pkt = self._grab_packet(
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
)
self._handle_acct_packet(pkt)
elif self.coa_enabled:
pkt = self._grab_packet(
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
)
self._handle_coa_packet(pkt)
else:
raise ServerPacketError("Received packet for unknown handler")
def Run(self):
"""Main loop.
This method is the main loop for a RADIUS server. It waits
for packets to arrive via the network and calls other methods
to process them.
"""
self._poll = select.poll()
self._fdmap = {}
self._prepare_sockets()
while True:
for (fd, event) in self._poll.poll():
if event == select.POLLIN:
try:
fdo = self._fdmap[fd]
self._process_input(fdo)
except ServerPacketError as err:
LOGGER.info("Dropping packet: %s", err)
except packet.PacketError as err:
LOGGER.info("Received a broken packet: %s", err)
else:
LOGGER.error("Unexpected event in server main loop")

429
src/pyrad3/server_async.py Normal file
View File

@@ -0,0 +1,429 @@
# server_async.py
#
# Copyright 2018-2019 Geaaru <geaaru@gmail.com>
import asyncio
import logging
from abc import abstractmethod, ABCMeta
from enum import Enum
from datetime import datetime
from pyrad.packet import (
Packet,
AccessAccept,
AccessReject,
AccountingRequest,
AccountingResponse,
DisconnectACK,
DisconnectNAK,
DisconnectRequest,
CoARequest,
CoAACK,
CoANAK,
AccessRequest,
AuthPacket,
AcctPacket,
CoAPacket,
PacketError,
)
from pyrad.server import ServerPacketError
class ServerType(Enum):
Auth = "Authentication"
Acct = "Accounting"
Coa = "Coa"
class DatagramProtocolServer(asyncio.Protocol):
def __init__(
self, ip, port, logger, server, server_type, hosts, request_callback
):
self.transport = None
self.ip = ip
self.port = port
self.logger = logger
self.server = server
self.hosts = hosts
self.server_type = server_type
self.request_callback = request_callback
def connection_made(self, transport):
self.transport = transport
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
def connection_lost(self, exc):
if exc:
self.logger.warn(
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
)
else:
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
def send_response(self, reply, addr):
self.transport.sendto(reply.ReplyPacket(), addr)
def datagram_received(self, data, addr):
self.logger.debug(
"[%s:%d] Received %d bytes from %s",
self.ip,
self.port,
len(data),
addr,
)
receive_date = datetime.utcnow()
if addr[0] in self.hosts:
remote_host = self.hosts[addr[0]]
elif "0.0.0.0" in self.hosts:
remote_host = self.hosts["0.0.0.0"]
else:
self.logger.warn(
"[%s:%d] Drop package from unknown source %s",
self.ip,
self.port,
addr,
)
return
try:
self.logger.debug(
"[%s:%d] Received from %s packet: %s",
self.ip,
self.port,
addr,
data.hex(),
)
req = Packet(packet=data, dict=self.server.dict)
except Exception as exc:
self.logger.error(
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
)
return
try:
if req.code in (
AccountingResponse,
AccessAccept,
AccessReject,
CoANAK,
CoAACK,
DisconnectNAK,
DisconnectACK,
):
raise ServerPacketError(f"Invalid response packet {req.code}")
elif self.server_type == ServerType.Auth:
if req.code != AccessRequest:
raise ServerPacketError(
"Received non-auth packet on auth port"
)
req = AuthPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAuthRequest():
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Coa:
if req.code != DisconnectRequest and req.code != CoARequest:
raise ServerPacketError(
"Received non-coa packet on coa port"
)
req = CoAPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyCoARequest():
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Acct:
if req.code != AccountingRequest:
raise ServerPacketError(
"Received non-acct packet on acct port"
)
req = AcctPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAcctRequest():
raise PacketError("Packet verification failed")
# Call request callback
self.request_callback(self, req, addr)
except Exception as exc:
if self.server.debug:
self.logger.exception(
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
)
else:
self.logger.error(
"[%s:%d] Error for packet from %s: %s",
self.ip,
self.port,
addr,
exc,
)
process_date = datetime.utcnow()
self.logger.debug(
"[%s:%d] Request from %s processed in %d ms",
self.ip,
self.port,
addr,
(process_date - receive_date).microseconds / 1000,
)
def error_received(self, exc):
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
async def close_transport(self):
if self.transport:
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
self.transport.close()
self.transport = None
def __str__(self):
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
# Used as protocol_factory
def __call__(self):
return self
class ServerAsync(metaclass=ABCMeta):
def __init__(
self,
auth_port=1812,
acct_port=1813,
coa_port=3799,
hosts=None,
dictionary=None,
loop=None,
logger_name="pyrad",
enable_pkt_verify=False,
debug=False,
):
if not loop:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.logger = logging.getLogger(logger_name)
if hosts is None:
self.hosts = {}
else:
self.hosts = hosts
self.auth_port = auth_port
self.auth_protocols = []
self.acct_port = acct_port
self.acct_protocols = []
self.coa_port = coa_port
self.coa_protocols = []
self.dict = dictionary
self.enable_pkt_verify = enable_pkt_verify
self.debug = debug
def __request_handler__(self, protocol, req, addr):
try:
if protocol.server_type == ServerType.Acct:
self.handle_acct_packet(protocol, req, addr)
elif protocol.server_type == ServerType.Auth:
self.handle_auth_packet(protocol, req, addr)
elif (
protocol.server_type == ServerType.Coa
and req.code == CoARequest
):
self.handle_coa_packet(protocol, req, addr)
elif (
protocol.server_type == ServerType.Coa
and req.code == DisconnectRequest
):
self.handle_disconnect_packet(protocol, req, addr)
else:
self.logger.error(
"[%s:%s] Unexpected request found",
protocol.ip,
protocol.port,
)
except Exception as exc:
if self.debug:
self.logger.exception(
"[%s:%s] Unexpected error", protocol.ip, protocol.port
)
else:
self.logger.error(
"[%s:%s] Unexpected error: %s",
protocol.ip,
protocol.port,
exc,
)
def __is_present_proto__(self, ip, port):
if port == self.auth_port:
for proto in self.auth_protocols:
if proto.ip == ip:
return True
elif port == self.acct_port:
for proto in self.acct_protocols:
if proto.ip == ip:
return True
elif port == self.coa_port:
for proto in self.coa_protocols:
if proto.ip == ip:
return True
return False
# noinspection PyPep8Naming
@staticmethod
def CreateReplyPacket(pkt, **attributes):
"""Create a reply packet.
Create a new packet which can be returned as a reply to a received
packet.
:param pkt: original packet
:type pkt: Packet instance
"""
reply = pkt.CreateReply(**attributes)
return reply
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
addresses=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception("No transports selected")
if not addresses or len(addresses) == 0:
addresses = ["127.0.0.1"]
# noinspection SpellCheckingInspection
for addr in addresses:
if enable_acct and not self.__is_present_proto__(
addr, self.acct_port
):
protocol_acct = DatagramProtocolServer(
addr,
self.acct_port,
self.logger,
self,
ServerType.Acct,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.acct_port)
acct_connect = self.loop.create_datagram_endpoint(
protocol_acct, reuse_port=True, local_addr=bind_addr
)
self.acct_protocols.append(protocol_acct)
task_list.append(acct_connect)
if enable_auth and not self.__is_present_proto__(
addr, self.auth_port
):
protocol_auth = DatagramProtocolServer(
addr,
self.auth_port,
self.logger,
self,
ServerType.Auth,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.auth_port)
auth_connect = self.loop.create_datagram_endpoint(
protocol_auth, reuse_port=True, local_addr=bind_addr
)
self.auth_protocols.append(protocol_auth)
task_list.append(auth_connect)
if enable_coa and not self.__is_present_proto__(
addr, self.coa_port
):
protocol_coa = DatagramProtocolServer(
addr,
self.coa_port,
self.logger,
self,
ServerType.Coa,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.coa_port)
coa_connect = self.loop.create_datagram_endpoint(
protocol_coa, reuse_port=True, local_addr=bind_addr
)
self.coa_protocols.append(protocol_coa)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if deinit_coa:
for proto in self.coa_protocols:
await proto.close_transport()
del proto
self.coa_protocols = []
if deinit_auth:
for proto in self.auth_protocols:
await proto.close_transport()
del proto
self.auth_protocols = []
if deinit_acct:
for proto in self.acct_protocols:
await proto.close_transport()
del proto
self.acct_protocols = []
@abstractmethod
def handle_auth_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_acct_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_coa_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_disconnect_packet(self, protocol, pkt, addr):
pass

243
src/pyrad3/tools.py Normal file
View File

@@ -0,0 +1,243 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Collections of functions to en- and decode RADIUS Attributes"""
from typing import Union
from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address
import struct
def encode_string(string: str) -> bytes:
"""Encode a RADIUS value of type string"""
if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters")
if isinstance(string, str):
return string.encode("utf-8")
return string
def encode_octets(string: bytes) -> bytes:
"""Encode a RADIUS value of type octet"""
if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters")
return string
def encode_address(addr: Union[str, IPv4Address]) -> bytes:
"""Encode a RADIUS value of type ipaddr"""
return IPv4Address(addr).packed
def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes:
"""Encode a RADIUS value of type ipv6prefix"""
address = IPv6Network(addr)
return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type ipv6addr"""
return IPv6Address(addr).packed
def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type combo-ip"""
return ip_address(addr).packed
def encode_ascend_binary(string: str) -> bytes:
"""
struct_format: List of type=value pairs sperated by spaces.
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
Type:
family ipv4(default) or ipv6
action discard(default) or accept
direction in(default) or out
src source prefix (default ignore)
dst destination prefix (default ignore)
proto protocol number / next-header number (default ignore)
sport source port (default ignore)
dport destination port (default ignore)
sportq source port qualifier (default 0)
dportq destination port qualifier (default 0)
Source/Destination Port Qualifier:
0 no compare
1 less than
2 equal to
3 greater than
4 not equal to
"""
terms = {
"family": b"\x01",
"action": b"\x00",
"direction": b"\x01",
"src": b"\x00\x00\x00\x00",
"dst": b"\x00\x00\x00\x00",
"srcl": b"\x00",
"dstl": b"\x00",
"proto": b"\x00",
"sport": b"\x00\x00",
"dport": b"\x00\x00",
"sportq": b"\x00",
"dportq": b"\x00",
}
for term in string.split(" "):
key, value = term.split("=")
if key == "family" and value == "ipv6":
terms[key] = b"\x03"
if terms["src"] == b"\x00\x00\x00\x00":
terms["src"] = 16 * b"\x00"
if terms["dst"] == b"\x00\x00\x00\x00":
terms["dst"] = 16 * b"\x00"
elif key == "action" and value == "accept":
terms[key] = b"\x01"
elif key == "direction" and value == "out":
terms[key] = b"\x00"
elif key in ("src", "dst"):
address = ip_network(value)
terms[key] = address.network_address.packed
terms[key + "l"] = struct.pack("B", address.prefixlen)
elif key in ("sport", "dport"):
terms[key] = struct.pack("!H", int(value))
elif key in ("sportq", "dportq", "proto"):
terms[key] = struct.pack("B", int(value))
trailer = 8 * b"\x00"
result = b"".join(
(
terms["family"],
terms["action"],
terms["direction"],
b"\x00",
terms["src"],
terms["dst"],
terms["srcl"],
terms["dstl"],
terms["proto"],
b"\x00",
terms["sport"],
terms["dport"],
terms["sportq"],
terms["dportq"],
b"\x00\x00",
trailer,
)
)
return result
def encode_integer(num: Union[int, str], struct_format="!I") -> bytes:
"""Encode a RADIUS value of some type integer"""
return struct.pack(struct_format, int(num))
def encode_date(num: Union[int, str]) -> bytes:
"""Encode a RADIUS value of type date"""
return struct.pack("!I", int(num))
def decode_string(string: bytes) -> Union[str, bytes]:
"""Decode a RADIUS value of type string"""
try:
return string.decode("utf-8")
except UnicodeDecodeError:
return string
def decode_octets(string: bytes) -> bytes:
"""Decode a RADIUS value of type octet"""
return string
def decode_address(addr: bytes) -> IPv4Address:
"""Decode a RADIUS value of type ipaddr"""
return IPv4Address(addr)
def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
"""Decode a RADIUS value of type ipv6prefix"""
addr = addr + b"\x00" * (18 - len(addr))
prefix = addr[:2]
addr = addr[2:]
return IPv6Network((prefix, addr))
def decode_ipv6_address(addr: bytes) -> IPv6Address:
"""Decode a RADIUS value of type ipv6addr"""
addr = addr + b"\x00" * (16 - len(addr))
return IPv6Address(addr)
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
"""Decode a RADIUS value of type combo-ip"""
return ip_address(addr).packed
def decode_ascend_binary(string):
"""Decode a RADIUS value of type abinary"""
raise NotImplementedError
def decode_integer(num: bytes, struct_format="!I") -> int:
"""Decode a RADIUS value of some integer type"""
return (struct.unpack(struct_format, num))[0]
def decode_date(num):
"""Decode a RADIUS value of type date"""
return (struct.unpack("!I", num))[0]
ENCODE_MAP = {
"string": encode_string,
"octets": encode_octets,
"integer": encode_integer,
"ipaddr": encode_address,
"ipv6prefix": encode_ipv6_prefix,
"ipv6addr": encode_ipv6_address,
"abinary": encode_ascend_binary,
"signed": lambda value: encode_integer(value, "!i"),
"short": lambda value: encode_integer(value, "!H"),
"byte": lambda value: encode_integer(value, "!B"),
"integer64": lambda value: encode_integer(value, '!Q'),
"date": encode_date,
}
def encode_attr(datatype, value):
"""Encode a RADIUS attribute"""
try:
return ENCODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")
DECODE_MAP = {
"string": decode_string,
"octets": decode_octets,
"integer": decode_integer,
"ipaddr": decode_address,
"ipv6prefix": decode_ipv6_prefix,
"ipv6addr": decode_ipv6_address,
"abinary": decode_ascend_binary,
"signed": lambda value: decode_integer(value, "!i"),
"short": lambda value: decode_integer(value, "!H"),
"byte": lambda value: decode_integer(value, "!B"),
"integer64": lambda value: decode_integer(value, "!Q"),
"date": decode_date,
}
def decode_attr(datatype, value):
"""Decode a RADIUS attribute"""
try:
return DECODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")

234
src/pyrad3/utils.py Normal file
View File

@@ -0,0 +1,234 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Collection of functions to deal with RADIUS packet en- and decoding."""
from collections import namedtuple
from typing import List, Union
import hashlib
import secrets
import struct
from pyrad3.dictionary import Dictionary
RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5
class PacketError(Exception):
"""Exception for Invalid Packets"""
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
def parse_header(raw_packet: bytes) -> Header:
"""Parse the Header of a RADIUS Packet."""
try:
header = struct.unpack("!BBH16s", raw_packet)
except struct.error:
raise PacketError("Packet header is corrupt")
length = header[3]
if len(raw_packet) != length:
raise PacketError(
f"RADIUS Packet ({len(raw_packet)}) has an invalid length ({length})"
)
if length > 4096:
raise PacketError(f"Packet length is too big ({length})")
return Header(*header)
def parse_attributes(
rad_dict: Dictionary, raw_packet: bytes
) -> List[Attribute]:
"""Parse the Attributes in a RADIUS Packet.
This function skips the Header. The Header must be parsed and verified
separately.
"""
attributes = []
packet = raw_packet[20:]
while packet:
try:
(key, length) = struct.unpack("!BB", packet[0:2])
except struct.error:
raise PacketError("Attribute header is corrupt")
if length < 2:
raise PacketError(f"Attribute length ({length}) is too small")
value = packet[2:length]
offset = len(raw_packet) - len(packet) + length
if key == 26:
try:
attributes.extend(
parse_vendor_attributes(rad_dict, offset, value)
)
except (PacketError, IndexError):
attributes.append(
Attribute(
name="Unknown-Vendor-Attribute",
pos=offset,
type="octets",
length=int(packet[1]),
value=packet[2:],
)
)
else:
key = parse_key(rad_dict, key)
attributes.extend(parse_value(rad_dict, key, offset, value))
packet = packet[length:]
return attributes
def parse_vendor_attributes(
rad_dict: Dictionary, offset: int, vendor_value: bytes
) -> List[Attribute]:
"""Parse A Vendor Attribute"""
if len(vendor_value) < 4:
raise PacketError
vendor_id = int.from_bytes(vendor_value[:4], "big")
vendor_dict = rad_dict.vendor_by_id[vendor_id]
vendor_name = vendor_dict.name
attributes = []
vendor_tlv = vendor_value[4:]
while vendor_tlv:
try:
(key, length) = struct.unpack("!BB", vendor_tlv[0:2])
except struct.error:
attribute = [
Attribute(
name=f"Unknown-{vendor_name}-Attribute",
pos=offset - len(vendor_value),
type="octets",
length=len(vendor_value) - 4,
value=vendor_value,
)
]
else:
offset = offset - len(vendor_tlv) + length
key = parse_key(vendor_dict, key)
attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
attributes.extend(attribute)
vendor_tlv = vendor_tlv[length:]
return attributes
def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
"""Parse the key in the Dictionary Context"""
try:
return rad_dict.attrs[key_id].name
except KeyError:
return key_id
def parse_value(
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
) -> List[Attribute]:
"""Parse the Value in the given Key/Dictionary Context"""
return []
def calculate_authenticator(
secret: bytes, authenticator: bytes, raw_packet: bytes
) -> bytes:
"""Calculate the Authenticator for the RADIUS Packet"""
return MD5(
raw_packet[0:4] + authenticator + raw_packet[20:] + secret
).digest()
def validate_pap_password(
secret: bytes,
authenticator: bytes,
obfuscated_password: bytes,
plaintext_password: bytes,
) -> bool:
"""Check if the plaintext and the RADIUS passwords match
This function does not "decrypt" the received password.
"""
obf_pass = password_encode(secret, authenticator, plaintext_password)
return obfuscated_password == obf_pass
def password_encode(
secret: bytes, authenticator: bytes, password: bytes
) -> bytes:
"""Obfuscate the plaintext Password for RADIUS"""
password += b"\x00" * (16 - (len(password) % 16))
return obfuscation_algorithm(secret, authenticator, password)
def password_decode(
secret: bytes, authenticator: bytes, obfuscated_password: bytes
) -> str:
"""Reverse the RADIUS obfuscation on a given password
The password password is padded with \\x00 to a 16 byte boundary. The padding will
be removed by this function.
If the original password had some trailing \\x00 it will get lost. Therefore it is
not recommended to use (trailing) \\x00 in passwords.
"""
deobfuscated = obfuscation_algorithm(
secret, authenticator, obfuscated_password
)
return deobfuscated.rstrip(b"\x00").decode("utf-8")
def obfuscation_algorithm(
secret: bytes, authenticator: bytes, password: bytes
) -> bytes:
"""Obfuscate the plaintext password.
This function does not deal with the padding (which the
RADIUS Protocol requires.)
The User has to pad the password themself, or better use
the `password_encode` or `password_decode` function.
"""
result = b""
buf = password
last = authenticator
while buf:
cur_hash = MD5(secret + last).digest()
for cbuf, chash in zip(buf, cur_hash):
result += bytes([cbuf ^ chash])
(last, buf) = (buf[:16], buf[16:])
return result
def validate_chap_password(
chap_id: bytes,
challenge: bytes,
chap_password: bytes,
plaintext_password: bytes,
) -> bool:
"""Validate the CHAP password against the given plaintext password"""
return (
chap_password == MD5(chap_id + plaintext_password + challenge).digest()
)
def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes:
"""Salt Encrypt the given value"""
# The highest bit MUST be 1
random_value = RANDOM_GENERATOR.randrange(32768, 65535)
salt = struct.pack("!H", random_value)
salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, value)
def salt_decrypt(
secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes
) -> bytes:
"""Decrypt the given value"""
salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, encrypted_value)