cleanup and a lot more tests

This commit is contained in:
Istvan Ruzman
2020-08-07 20:12:10 +02:00
parent e8dce3e2cf
commit 360927e559
35 changed files with 187 additions and 2849 deletions

View File

@@ -36,11 +36,11 @@ This package contains four modules:
- tools: utility functions
"""
__docformat__ = 'epytext en'
__docformat__ = "epytext en"
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
__copyright__ = 'Copyright 2020 Istvan Ruzman'
__version__ = '0.1.0'
__author__ = "Istvan Ruzman <istvan@ruzman.eu>"
__url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest"
__copyright__ = "Copyright 2020 Istvan Ruzman"
__version__ = "0.1.0"
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']
__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"]

View File

@@ -1,39 +0,0 @@
# bidict.py
#
# Bidirectional map
class BiDict:
def __init__(self):
self.forward = {}
self.backward = {}
def add(self, one, two):
self.forward[one] = two
self.backward[two] = one
def __len__(self):
return len(self.forward)
def __getitem__(self, key):
return self.get_forward(key)
def __delitem__(self, key):
try:
del self.backward[self.forward[key]]
del self.forward[key]
except KeyError:
del self.forward[self.backward[key]]
del self.backward[key]
def get_forward(self, key):
return self.forward[key]
def has_forward(self, key):
return key in self.forward
def get_backward(self, key):
return self.backward[key]
def has_backward(self, key):
return key in self.backward

View File

@@ -1,442 +0,0 @@
# client_async.py
#
# Copyright 2018-2020 Geaaru <geaaru<@>gmail.com>
__docformat__ = "epytext en"
from datetime import datetime
import asyncio
import logging
import random
from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
class DatagramProtocolClient(asyncio.Protocol):
def __init__(self, server, port, logger, client, retries=3, timeout=30):
self.transport = None
self.port = port
self.server = server
self.logger = logger
self.retries = retries
self.timeout = timeout
self.client = client
# Map of pending requests
self.pending_requests = {}
# Use cryptographic-safe random generator as provided by the OS.
random_generator = random.SystemRandom()
self.packet_id = random_generator.randrange(0, 256)
self.timeout_future = None
async def __timeout_handler__(self):
try:
while True:
req2delete = []
now = datetime.now()
next_weak_up = self.timeout
# noinspection PyShadowingBuiltins
for id, req in self.pending_requests.items():
secs = (req["send_date"] - now).seconds
if secs > self.timeout:
if req["retries"] == self.retries:
self.logger.debug(
"[%s:%d] For request %d execute all retries",
self.server,
self.port,
id,
)
req["future"].set_exception(
TimeoutError("Timeout on Reply")
)
req2delete.append(id)
else:
# Send again packet
req["send_date"] = now
req["retries"] += 1
self.logger.debug(
"[%s:%d] For request %d execute retry %d",
self.server,
self.port,
id,
req["retries"],
)
self.transport.sendto(req["packet"].RequestPacket())
elif next_weak_up > secs:
next_weak_up = secs
# noinspection PyShadowingBuiltins
for id in req2delete:
# Remove request for map
del self.pending_requests[id]
await asyncio.sleep(next_weak_up)
except asyncio.CancelledError:
pass
def send_packet(self, packet, future):
if packet.id in self.pending_requests:
raise Exception(f"Packet with id {packet.id} already present")
# Store packet on pending requests map
self.pending_requests[packet.id] = {
"packet": packet,
"creation_date": datetime.now(),
"retries": 0,
"future": future,
"send_date": datetime.now(),
}
# In queue packet raw on socket buffer
self.transport.sendto(packet.RequestPacket())
def connection_made(self, transport):
self.transport = transport
socket = transport.get_extra_info("socket")
self.logger.info(
"[%s:%d] Transport created with binding in %s:%d",
self.server,
self.port,
socket.getsockname()[0],
socket.getsockname()[1],
)
pre_loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop=self.client.loop)
# Start asynchronous timer handler
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
asyncio.set_event_loop(loop=pre_loop)
def error_received(self, exc):
self.logger.error(
"[%s:%d] Error received: %s", self.server, self.port, exc
)
def connection_lost(self, exc):
if exc:
self.logger.warn(
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
)
else:
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
# noinspection PyUnusedLocal
def datagram_received(self, data, addr):
try:
req = self.pending_requests[data[0]]
reply = req.VerifyPacket(data)
req["future"].set_result(reply)
# Remove request for map
del self.pending_requests[reply.id]
except KeyError:
self.logger.warn(
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
)
except PacketError as exc:
self.logger.error(
"[%s:%d] Error on decode or verify packet: %s",
self.server,
self.port,
exc,
)
async def close_transport(self):
if self.transport:
self.logger.debug(
"[%s:%d] Closing transport...", self.server, self.port
)
self.transport.close()
self.transport = None
if self.timeout_future:
self.timeout_future.cancel()
await self.timeout_future
self.timeout_future = None
def create_id(self):
self.packet_id = (self.packet_id + 1) % 256
return self.packet_id
def __str__(self):
return (
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
)
# Used as protocol_factory
def __call__(self):
return self
class ClientAsync:
"""Basic RADIUS client.
This class implements a basic RADIUS client. It can send requests
to a RADIUS server, taking care of timeouts and retries, and
validate its replies.
:ivar retries: number of times to retry sending a RADIUS request
:type retries: integer
:ivar timeout: number of seconds to wait for an answer
:type timeout: integer
"""
# noinspection PyShadowingBuiltins
def __init__(
self,
server,
auth_port=1812,
acct_port=1813,
coa_port=3799,
secret=b"",
dict=None,
loop=None,
retries=3,
timeout=30,
logger_name="pyrad",
):
"""Constructor.
:param server: hostname or IP address of RADIUS server
:type server: string
:param auth_port: port to use for authentication packets
:type auth_port: integer
:param acct_port: port to use for accounting packets
:type acct_port: integer
:param coa_port: port to use for CoA packets
:type coa_port: integer
:param secret: RADIUS secret
:type secret: string
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary
:param loop: Python loop handler
:type loop: asyncio event loop
"""
if not loop:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.logger = logging.getLogger(logger_name)
self.server = server
self.secret = secret
self.retries = retries
self.timeout = timeout
self.dict = dict
self.auth_port = auth_port
self.protocol_auth = None
self.acct_port = acct_port
self.protocol_acct = None
self.protocol_coa = None
self.coa_port = coa_port
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
local_addr=None,
local_auth_port=None,
local_acct_port=None,
local_coa_port=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception("No transports selected")
if enable_acct and not self.protocol_acct:
self.protocol_acct = DatagramProtocolClient(
self.server,
self.acct_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_acct_port:
bind_addr = (local_addr, local_acct_port)
acct_connect = self.loop.create_datagram_endpoint(
self.protocol_acct,
reuse_port=True,
remote_addr=(self.server, self.acct_port),
local_addr=bind_addr,
)
task_list.append(acct_connect)
if enable_auth and not self.protocol_auth:
self.protocol_auth = DatagramProtocolClient(
self.server,
self.auth_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_auth_port:
bind_addr = (local_addr, local_auth_port)
auth_connect = self.loop.create_datagram_endpoint(
self.protocol_auth,
reuse_port=True,
remote_addr=(self.server, self.auth_port),
local_addr=bind_addr,
)
task_list.append(auth_connect)
if enable_coa and not self.protocol_coa:
self.protocol_coa = DatagramProtocolClient(
self.server,
self.coa_port,
self.logger,
self,
retries=self.retries,
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_coa_port:
bind_addr = (local_addr, local_coa_port)
coa_connect = self.loop.create_datagram_endpoint(
self.protocol_coa,
reuse_port=True,
remote_addr=(self.server, self.coa_port),
local_addr=bind_addr,
)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if self.protocol_coa and deinit_coa:
await self.protocol_coa.close_transport()
del self.protocol_coa
self.protocol_coa = None
if self.protocol_auth and deinit_auth:
await self.protocol_auth.close_transport()
del self.protocol_auth
self.protocol_auth = None
if self.protocol_acct and deinit_acct:
await self.protocol_acct.close_transport()
del self.protocol_acct
self.protocol_acct = None
# noinspection PyPep8Naming
def CreateAuthPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_auth:
raise Exception("Transport not initialized")
return AuthPacket(
dict=self.dict,
id=self.protocol_auth.create_id(),
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateAcctPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_acct:
raise Exception("Transport not initialized")
return AcctPacket(
id=self.protocol_acct.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateCoAPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
if not self.protocol_acct:
raise Exception("Transport not initialized")
return CoAPacket(
id=self.protocol_coa.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
# noinspection PyShadowingBuiltins
def CreatePacket(self, id, **args):
if not id:
raise Exception("Missing mandatory packet id")
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
# noinspection PyPep8Naming
def SendPacket(self, pkt):
"""Send a packet to a RADIUS server.
:param pkt: the packet to send
:type pkt: pyrad.packet.Packet
:return: Future related with packet to send
:rtype: asyncio.Future
"""
ans = asyncio.Future(loop=self.loop)
if isinstance(pkt, AuthPacket):
if not self.protocol_auth:
raise Exception("Transport not initialized")
self.protocol_auth.send_packet(pkt, ans)
elif isinstance(pkt, AcctPacket):
if not self.protocol_acct:
raise Exception("Transport not initialized")
self.protocol_acct.send_packet(pkt, ans)
elif isinstance(pkt, CoAPacket):
if not self.protocol_coa:
raise Exception("Transport not initialized")
self.protocol_coa.send_packet(pkt, ans)
else:
raise Exception("Unsupported packet")
return ans

View File

@@ -10,7 +10,6 @@ from enum import IntEnum, Enum, auto
from dataclasses import dataclass
from os.path import dirname, isabs, join, normpath
from typing import (
cast,
Dict,
Generator,
IO,
@@ -167,18 +166,6 @@ def _parse_number(num: str) -> int:
return int(num)
def _parse_attribute_code(attr_code: str) -> List[int]:
"""Parse attribute codes from (Free)RADIUS dictionaries
Codes can be either decimal, octal, or hexadecimal.
TLV typed can
"""
codes = []
for code in attr_code.split("."):
codes.append(_parse_number(code))
return codes
class Dictionary:
"""(Free)RADIUS Dictionary.
@@ -358,7 +345,7 @@ class Dictionary:
for flag in flags:
flag_len = len(flag)
if flag == 1:
if flag_len == 1:
value = None
elif flag_len == 2:
value = flag[1]
@@ -372,6 +359,8 @@ class Dictionary:
has_tag = True
elif key == "encrypt":
try:
if value == "0":
raise ValueError
encrypt = Encrypt(int(value)) # type: ignore
except (ValueError, TypeError):
raise ParseError(
@@ -386,6 +375,32 @@ class Dictionary:
return has_tag, encrypt
def _parse_attribute_code(self, attr_code: str, line_num: int) -> List[int]:
filename = self.filestack[-1]
tlength = self.cur_vendor.tlength
codes = []
for code in attr_code.split("."):
try:
code_num = _parse_number(code)
except ValueError:
raise ParseError(
filename, f'invalid attribute code {attr_code}""', line_num
)
if 2 ** (8 * tlength) <= code_num:
raise ParseError(
filename,
f"attribute code is too big, must be smaller than 2**{tlength}",
line_num,
)
if code_num < 0:
raise ParseError(
filename,
"negative attribute codes are not allowed",
line_num,
)
codes.append(code_num)
return codes
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
@@ -409,27 +424,7 @@ class Dictionary:
line_num,
)
try:
codes = _parse_attribute_code(attr_code)
except ValueError:
raise ParseError(
filename, f'invalid attribute code {attr_code}""', line_num
)
for code in codes:
tlength = self.cur_vendor.tlength
if 2 ** (8 * tlength) <= code:
raise ParseError(
filename,
f"attribute code is too big, must be smaller than 2**{tlength}",
line_num,
)
if code < 0:
raise ParseError(
filename,
"negative attribute codes are not allowed",
line_num,
)
codes = self._parse_attribute_code(attr_code, line_num)
# TODO: Do we some explicit handling of tlvs?
# if len(codes) > 1:
@@ -485,7 +480,14 @@ class Dictionary:
filename, f"Invalid number {vvalue} for VALUE {key}", line_num
)
attribute = self.attrindex[attr_name]
try:
attribute = self.attrindex[attr_name]
except KeyError:
raise ParseError(
filename,
f"ATTRIBUTE {attr_name} has not been defined yet",
line_num,
)
try:
datatype = str(attribute.datatype).split(".")[1]
lmin, lmax = INTEGER_TYPES[datatype]
@@ -504,3 +506,6 @@ class Dictionary:
)
attribute.values[value] = key
attribute.values[key] = value
def __getitem__(self, key: Union[str, int, Tuple[int, ...]]) -> Attribute:
return self.attrindex[key]

View File

@@ -7,8 +7,9 @@ from pyrad3.dictionary import Dictionary
from pyrad3 import packet
class Host: # pylint: disable=too-many-arguments
class Host: # pylint: disable=too-many-arguments
"""Interface Class for RADIUS Clients and Servers"""
def __init__(
self,
secret: bytes,

View File

@@ -169,7 +169,8 @@ class Packet(OrderedDict):
self["Message-Authenticator"] = 16 * b"\00"
attr = self.ordered_attributes[-1]
generated = self._generate_message_authenticator(attr)
self[attr.pos + 2 :] = generated
index = attr.pos + 2
self[index:] = generated
def refresh_message_authenticator(self):
self.add_message_authenticator()
@@ -283,7 +284,7 @@ class AcctPacket(Packet):
def increase_acct_delay_time(self, delay_time: float):
try:
self['Acct-Delay-Time'] += int(delay_time)
self["Acct-Delay-Time"] += int(delay_time)
except KeyError:
pass

View File

@@ -1,73 +0,0 @@
# proxy.py
#
# Copyright 2005,2007 Wichert Akkerman <wichert@wiggy.net>
#
# A RADIUS proxy as defined in RFC 2138
import select
import socket
from pyrad.server import Server, ServerPacketError
from pyrad import packet
class Proxy(Server):
"""Base class for RADIUS proxies.
This class extends tha RADIUS server class with the capability to
handle communication with other RADIUS servers as well.
:ivar _proxyfd: network socket used to communicate with other servers
:type _proxyfd: socket class instance
"""
def _prepare_sockets(self):
Server._prepare_sockets(self)
self._proxyfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
self._poll.register(
self._proxyfd.fileno(),
(select.POLLIN | select.POLLPRI | select.POLLERR),
)
def _handle_proxy_packet(self, pkt):
"""Process a packet received on the reply socket.
If this packet should be dropped instead of processed a
:obj:`ServerPacketError` exception should be raised. The main loop
will drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
if pkt.source[0] not in self.hosts:
raise ServerPacketError("Received packet from unknown host")
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code not in [
packet.AccessAccept,
packet.AccessReject,
packet.AccountingResponse,
]:
raise ServerPacketError("Received non-response on proxy socket")
def _process_input(self, fd):
"""Process available data.
If this packet should be dropped instead of processed a
`ServerPacketError` exception should be raised. The main loop
will drop the packet and log the reason.
This function calls either :obj:`HandleAuthPacket`,
:obj:`HandleAcctPacket` or :obj:`_handle_proxy_packet` depending on
which socket is being processed.
:param fd: socket to read packet from
:type fd: socket class instance
:param pkt: packet to process
:type pkt: Packet class instance
"""
if fd.fileno() == self._proxyfd.fileno():
pkt = self._grab_packet(
lambda data, s=self: s.CreatePacket(packet=data), fd
)
self._handle_proxy_packet(pkt)
else:
Server._process_input(self, fd)

View File

@@ -1,367 +0,0 @@
# server.py
#
# Copyright 2003-2004,2007,2016 Wichert Akkerman <wichert@wiggy.net>
import logging
import select
import socket
from pyrad import host
from pyrad import packet
LOGGER = logging.getLogger("pyrad")
class RemoteHost:
"""Remote RADIUS capable host we can talk to."""
def __init__(
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
):
"""Constructor.
:param address: IP address
:type address: string
:param secret: RADIUS secret
:type secret: string
:param name: short name (used for logging only)
:type name: string
:param authport: port used for authentication packets
:type authport: integer
:param acctport: port used for accounting packets
:type acctport: integer
:param coaport: port used for CoA packets
:type coaport: integer
"""
self.address = address
self.secret = secret
self.authport = authport
self.acctport = acctport
self.coaport = coaport
self.name = name
class ServerPacketError(Exception):
"""Exception class for bogus packets.
ServerPacketError exceptions are only used inside the Server class to
abort processing of a packet.
"""
class Server(host.Host):
"""Basic RADIUS server.
This class implements the basics of a RADIUS server. It takes care
of the details of receiving and decoding requests; processing of
the requests should be done by overloading the appropriate methods
in derived classes.
:ivar hosts: hosts who are allowed to talk to us
:type hosts: dictionary of Host class instances
:ivar _poll: poll object for network sockets
:type _poll: select.poll class instance
:ivar _fdmap: map of filedescriptors to network sockets
:type _fdmap: dictionary
:cvar MaxPacketSize: maximum size of a RADIUS packet
:type MaxPacketSize: integer
"""
MaxPacketSize = 4096
def __init__(
self,
addresses=[],
authport=1812,
acctport=1813,
coaport=3799,
hosts=None,
dict=None,
auth_enabled=True,
acct_enabled=True,
coa_enabled=False,
):
"""Constructor.
:param addresses: IP addresses to listen on
:type addresses: sequence of strings
:param authport: port to listen on for authentication packets
:type authport: integer
:param acctport: port to listen on for accounting packets
:type acctport: integer
:param coaport: port to listen on for CoA packets
:type coaport: integer
:param hosts: hosts who we can talk to
:type hosts: dictionary mapping IP to RemoteHost class instances
:param dict: RADIUS dictionary to use
:type dict: Dictionary class instance
:param auth_enabled: enable auth server (default True)
:type auth_enabled: bool
:param acct_enabled: enable accounting server (default True)
:type acct_enabled: bool
:param coa_enabled: enable coa server (default False)
:type coa_enabled: bool
"""
host.Host.__init__(self, authport, acctport, coaport, dict)
if hosts is None:
self.hosts = {}
else:
self.hosts = hosts
self.auth_enabled = auth_enabled
self.authfds = []
self.acct_enabled = acct_enabled
self.acctfds = []
self.coa_enabled = coa_enabled
self.coafds = []
for addr in addresses:
self.BindToAddress(addr)
def _get_addr_info(self, addr):
"""Use getaddrinfo to lookup all addresses for each address.
Returns a list of tuples or an empty list:
[(family, address)]
:param addr: IP address to lookup
:type addr: string
"""
results = set()
try:
tmp = socket.getaddrinfo(addr, "www")
except socket.gaierror:
return []
for el in tmp:
results.add((el[0], el[4][0]))
return results
def BindToAddress(self, addr):
"""Add an address to listen to.
An empty string indicated you want to listen on all addresses.
:param addr: IP address to listen on
:type addr: string
"""
addrFamily = self._get_addr_info(addr)
for (family, address) in addrFamily:
if self.auth_enabled:
authfd = socket.socket(family, socket.SOCK_DGRAM)
authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
authfd.bind((address, self.authport))
self.authfds.append(authfd)
if self.acct_enabled:
acctfd = socket.socket(family, socket.SOCK_DGRAM)
acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
acctfd.bind((address, self.acctport))
self.acctfds.append(acctfd)
if self.coa_enabled:
coafd = socket.socket(family, socket.SOCK_DGRAM)
coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
coafd.bind((address, self.coaport))
self.coafds.append(coafd)
def HandleAuthPacket(self, pkt):
"""Authentication packet handler.
This is an empty function that is called when a valid
authentication packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleAcctPacket(self, pkt):
"""Accounting packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleCoaPacket(self, pkt):
"""CoA packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def HandleDisconnectPacket(self, pkt):
"""CoA packet handler.
This is an empty function that is called when a valid
accounting packet has been received. It can be overriden in
derived classes to add custom behaviour.
:param pkt: packet to process
:type pkt: Packet class instance
"""
def _add_secret(self, pkt):
"""Add secret to packets received and raise ServerPacketError
for unknown hosts.
:param pkt: packet to process
:type pkt: Packet class instance
"""
if pkt.source[0] in self.hosts:
pkt.secret = self.hosts[pkt.source[0]].secret
elif "0.0.0.0" in self.hosts:
pkt.secret = self.hosts["0.0.0.0"].secret
else:
raise ServerPacketError("Received packet from unknown host")
def _handle_auth_packet(self, pkt):
"""Process a packet received on the authentication port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
if pkt.code != packet.AccessRequest:
raise ServerPacketError(
"Received non-authentication packet on authentication port"
)
self.HandleAuthPacket(pkt)
def _handle_acct_packet(self, pkt):
"""Process a packet received on the accounting port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
if pkt.code not in [
packet.AccountingRequest,
packet.AccountingResponse,
]:
raise ServerPacketError(
"Received non-accounting packet on accounting port"
)
self.HandleAcctPacket(pkt)
def _handle_coa_packet(self, pkt):
"""Process a packet received on the coa port.
If this packet should be dropped instead of processed a
ServerPacketError exception should be raised. The main loop will
drop the packet and log the reason.
:param pkt: packet to process
:type pkt: Packet class instance
"""
self._add_secret(pkt)
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code == packet.CoARequest:
self.HandleCoaPacket(pkt)
elif pkt.code == packet.DisconnectRequest:
self.HandleDisconnectPacket(pkt)
else:
raise ServerPacketError("Received non-coa packet on coa port")
def _grab_packet(self, pktgen, fd):
"""Read a packet from a network connection.
This method assumes there is data waiting for to be read.
:param fd: socket to read packet from
:type fd: socket class instance
:return: RADIUS packet
:rtype: Packet class instance
"""
(data, source) = fd.recvfrom(self.MaxPacketSize)
pkt = pktgen(data)
pkt.source = source
pkt.fd = fd
return pkt
def _prepare_sockets(self):
"""Prepare all sockets to receive packets.
"""
for fd in self.authfds + self.acctfds + self.coafds:
self._fdmap[fd.fileno()] = fd
self._poll.register(
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
)
if self.auth_enabled:
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
if self.acct_enabled:
self._realacctfds = list(map(lambda x: x.fileno(), self.acctfds))
if self.coa_enabled:
self._realcoafds = list(map(lambda x: x.fileno(), self.coafds))
def CreateReplyPacket(self, pkt, **attributes):
"""Create a reply packet.
Create a new packet which can be returned as a reply to a received
packet.
:param pkt: original packet
:type pkt: Packet instance
"""
reply = pkt.CreateReply(**attributes)
reply.source = pkt.source
return reply
def _process_input(self, fd):
"""Process available data.
If this packet should be dropped instead of processed a
PacketError exception should be raised. The main loop will
drop the packet and log the reason.
This function calls either HandleAuthPacket() or
HandleAcctPacket() depending on which socket is being
processed.
:param fd: socket to read packet from
:type fd: socket class instance
"""
if self.auth_enabled and fd.fileno() in self._realauthfds:
pkt = self._grab_packet(
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
)
self._handle_auth_packet(pkt)
elif self.acct_enabled and fd.fileno() in self._realacctfds:
pkt = self._grab_packet(
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
)
self._handle_acct_packet(pkt)
elif self.coa_enabled:
pkt = self._grab_packet(
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
)
self._handle_coa_packet(pkt)
else:
raise ServerPacketError("Received packet for unknown handler")
def Run(self):
"""Main loop.
This method is the main loop for a RADIUS server. It waits
for packets to arrive via the network and calls other methods
to process them.
"""
self._poll = select.poll()
self._fdmap = {}
self._prepare_sockets()
while True:
for (fd, event) in self._poll.poll():
if event == select.POLLIN:
try:
fdo = self._fdmap[fd]
self._process_input(fdo)
except ServerPacketError as err:
LOGGER.info("Dropping packet: %s", err)
except packet.PacketError as err:
LOGGER.info("Received a broken packet: %s", err)
else:
LOGGER.error("Unexpected event in server main loop")

View File

@@ -1,429 +0,0 @@
# server_async.py
#
# Copyright 2018-2019 Geaaru <geaaru@gmail.com>
import asyncio
import logging
from abc import abstractmethod, ABCMeta
from enum import Enum
from datetime import datetime
from pyrad.packet import (
Packet,
AccessAccept,
AccessReject,
AccountingRequest,
AccountingResponse,
DisconnectACK,
DisconnectNAK,
DisconnectRequest,
CoARequest,
CoAACK,
CoANAK,
AccessRequest,
AuthPacket,
AcctPacket,
CoAPacket,
PacketError,
)
from pyrad.server import ServerPacketError
class ServerType(Enum):
Auth = "Authentication"
Acct = "Accounting"
Coa = "Coa"
class DatagramProtocolServer(asyncio.Protocol):
def __init__(
self, ip, port, logger, server, server_type, hosts, request_callback
):
self.transport = None
self.ip = ip
self.port = port
self.logger = logger
self.server = server
self.hosts = hosts
self.server_type = server_type
self.request_callback = request_callback
def connection_made(self, transport):
self.transport = transport
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
def connection_lost(self, exc):
if exc:
self.logger.warn(
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
)
else:
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
def send_response(self, reply, addr):
self.transport.sendto(reply.ReplyPacket(), addr)
def datagram_received(self, data, addr):
self.logger.debug(
"[%s:%d] Received %d bytes from %s",
self.ip,
self.port,
len(data),
addr,
)
receive_date = datetime.utcnow()
if addr[0] in self.hosts:
remote_host = self.hosts[addr[0]]
elif "0.0.0.0" in self.hosts:
remote_host = self.hosts["0.0.0.0"]
else:
self.logger.warn(
"[%s:%d] Drop package from unknown source %s",
self.ip,
self.port,
addr,
)
return
try:
self.logger.debug(
"[%s:%d] Received from %s packet: %s",
self.ip,
self.port,
addr,
data.hex(),
)
req = Packet(packet=data, dict=self.server.dict)
except Exception as exc:
self.logger.error(
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
)
return
try:
if req.code in (
AccountingResponse,
AccessAccept,
AccessReject,
CoANAK,
CoAACK,
DisconnectNAK,
DisconnectACK,
):
raise ServerPacketError(f"Invalid response packet {req.code}")
elif self.server_type == ServerType.Auth:
if req.code != AccessRequest:
raise ServerPacketError(
"Received non-auth packet on auth port"
)
req = AuthPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAuthRequest():
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Coa:
if req.code != DisconnectRequest and req.code != CoARequest:
raise ServerPacketError(
"Received non-coa packet on coa port"
)
req = CoAPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyCoARequest():
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Acct:
if req.code != AccountingRequest:
raise ServerPacketError(
"Received non-acct packet on acct port"
)
req = AcctPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAcctRequest():
raise PacketError("Packet verification failed")
# Call request callback
self.request_callback(self, req, addr)
except Exception as exc:
if self.server.debug:
self.logger.exception(
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
)
else:
self.logger.error(
"[%s:%d] Error for packet from %s: %s",
self.ip,
self.port,
addr,
exc,
)
process_date = datetime.utcnow()
self.logger.debug(
"[%s:%d] Request from %s processed in %d ms",
self.ip,
self.port,
addr,
(process_date - receive_date).microseconds / 1000,
)
def error_received(self, exc):
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
async def close_transport(self):
if self.transport:
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
self.transport.close()
self.transport = None
def __str__(self):
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
# Used as protocol_factory
def __call__(self):
return self
class ServerAsync(metaclass=ABCMeta):
def __init__(
self,
auth_port=1812,
acct_port=1813,
coa_port=3799,
hosts=None,
dictionary=None,
loop=None,
logger_name="pyrad",
enable_pkt_verify=False,
debug=False,
):
if not loop:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.logger = logging.getLogger(logger_name)
if hosts is None:
self.hosts = {}
else:
self.hosts = hosts
self.auth_port = auth_port
self.auth_protocols = []
self.acct_port = acct_port
self.acct_protocols = []
self.coa_port = coa_port
self.coa_protocols = []
self.dict = dictionary
self.enable_pkt_verify = enable_pkt_verify
self.debug = debug
def __request_handler__(self, protocol, req, addr):
try:
if protocol.server_type == ServerType.Acct:
self.handle_acct_packet(protocol, req, addr)
elif protocol.server_type == ServerType.Auth:
self.handle_auth_packet(protocol, req, addr)
elif (
protocol.server_type == ServerType.Coa
and req.code == CoARequest
):
self.handle_coa_packet(protocol, req, addr)
elif (
protocol.server_type == ServerType.Coa
and req.code == DisconnectRequest
):
self.handle_disconnect_packet(protocol, req, addr)
else:
self.logger.error(
"[%s:%s] Unexpected request found",
protocol.ip,
protocol.port,
)
except Exception as exc:
if self.debug:
self.logger.exception(
"[%s:%s] Unexpected error", protocol.ip, protocol.port
)
else:
self.logger.error(
"[%s:%s] Unexpected error: %s",
protocol.ip,
protocol.port,
exc,
)
def __is_present_proto__(self, ip, port):
if port == self.auth_port:
for proto in self.auth_protocols:
if proto.ip == ip:
return True
elif port == self.acct_port:
for proto in self.acct_protocols:
if proto.ip == ip:
return True
elif port == self.coa_port:
for proto in self.coa_protocols:
if proto.ip == ip:
return True
return False
# noinspection PyPep8Naming
@staticmethod
def CreateReplyPacket(pkt, **attributes):
"""Create a reply packet.
Create a new packet which can be returned as a reply to a received
packet.
:param pkt: original packet
:type pkt: Packet instance
"""
reply = pkt.CreateReply(**attributes)
return reply
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
addresses=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception("No transports selected")
if not addresses or len(addresses) == 0:
addresses = ["127.0.0.1"]
# noinspection SpellCheckingInspection
for addr in addresses:
if enable_acct and not self.__is_present_proto__(
addr, self.acct_port
):
protocol_acct = DatagramProtocolServer(
addr,
self.acct_port,
self.logger,
self,
ServerType.Acct,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.acct_port)
acct_connect = self.loop.create_datagram_endpoint(
protocol_acct, reuse_port=True, local_addr=bind_addr
)
self.acct_protocols.append(protocol_acct)
task_list.append(acct_connect)
if enable_auth and not self.__is_present_proto__(
addr, self.auth_port
):
protocol_auth = DatagramProtocolServer(
addr,
self.auth_port,
self.logger,
self,
ServerType.Auth,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.auth_port)
auth_connect = self.loop.create_datagram_endpoint(
protocol_auth, reuse_port=True, local_addr=bind_addr
)
self.auth_protocols.append(protocol_auth)
task_list.append(auth_connect)
if enable_coa and not self.__is_present_proto__(
addr, self.coa_port
):
protocol_coa = DatagramProtocolServer(
addr,
self.coa_port,
self.logger,
self,
ServerType.Coa,
self.hosts,
self.__request_handler__,
)
bind_addr = (addr, self.coa_port)
coa_connect = self.loop.create_datagram_endpoint(
protocol_coa, reuse_port=True, local_addr=bind_addr
)
self.coa_protocols.append(protocol_coa)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if deinit_coa:
for proto in self.coa_protocols:
await proto.close_transport()
del proto
self.coa_protocols = []
if deinit_auth:
for proto in self.auth_protocols:
await proto.close_transport()
del proto
self.auth_protocols = []
if deinit_acct:
for proto in self.acct_protocols:
await proto.close_transport()
del proto
self.acct_protocols = []
@abstractmethod
def handle_auth_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_acct_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_coa_packet(self, protocol, pkt, addr):
pass
@abstractmethod
def handle_disconnect_packet(self, protocol, pkt, addr):
pass

View File

@@ -4,7 +4,14 @@
"""Collections of functions to en- and decode RADIUS Attributes"""
from typing import Union
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network, ip_address
from ipaddress import (
IPv4Address,
IPv4Network,
IPv6Address,
IPv6Network,
ip_network,
ip_address,
)
import struct
@@ -33,13 +40,17 @@ def encode_address(addr: Union[str, IPv4Address]) -> bytes:
def encode_network(network: Union[str, IPv4Network]) -> bytes:
"""Encode a RADIUS value of type ipv4prefix"""
address = IPv4Network(network)
return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed
return (
struct.pack("2B", 0, address.prefixlen) + address.network_address.packed
)
def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes:
"""Encode a RADIUS value of type ipv6prefix"""
address = IPv6Network(network)
return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed.rstrip(b'\0')
return struct.pack(
"2B", 0, address.prefixlen
) + address.network_address.packed.rstrip(b"\0")
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
@@ -212,7 +223,7 @@ ENCODE_MAP = {
"signed": lambda value: encode_integer(value, "!i"),
"short": lambda value: encode_integer(value, "!H"),
"byte": lambda value: encode_integer(value, "!B"),
"integer64": lambda value: encode_integer(value, '!Q'),
"integer64": lambda value: encode_integer(value, "!Q"),
"date": encode_date,
}

View File

@@ -127,11 +127,16 @@ def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
return key_id
def parse_value(
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
) -> List[Attribute]:
def parse_value(*_):
"""Parse the Value in the given Key/Dictionary Context"""
return []
raise NotImplementedError
# def parse_value(
# rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
# ) -> List[Attribute]:
# """Parse the Value in the given Key/Dictionary Context"""
# return []
def calculate_authenticator(