cleanup and a lot more tests
This commit is contained in:
@@ -36,11 +36,11 @@ This package contains four modules:
|
||||
- tools: utility functions
|
||||
"""
|
||||
|
||||
__docformat__ = 'epytext en'
|
||||
__docformat__ = "epytext en"
|
||||
|
||||
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
|
||||
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
|
||||
__copyright__ = 'Copyright 2020 Istvan Ruzman'
|
||||
__version__ = '0.1.0'
|
||||
__author__ = "Istvan Ruzman <istvan@ruzman.eu>"
|
||||
__url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest"
|
||||
__copyright__ = "Copyright 2020 Istvan Ruzman"
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']
|
||||
__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"]
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# bidict.py
|
||||
#
|
||||
# Bidirectional map
|
||||
|
||||
|
||||
class BiDict:
|
||||
def __init__(self):
|
||||
self.forward = {}
|
||||
self.backward = {}
|
||||
|
||||
def add(self, one, two):
|
||||
self.forward[one] = two
|
||||
self.backward[two] = one
|
||||
|
||||
def __len__(self):
|
||||
return len(self.forward)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.get_forward(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
del self.backward[self.forward[key]]
|
||||
del self.forward[key]
|
||||
except KeyError:
|
||||
del self.forward[self.backward[key]]
|
||||
del self.backward[key]
|
||||
|
||||
def get_forward(self, key):
|
||||
return self.forward[key]
|
||||
|
||||
def has_forward(self, key):
|
||||
return key in self.forward
|
||||
|
||||
def get_backward(self, key):
|
||||
return self.backward[key]
|
||||
|
||||
def has_backward(self, key):
|
||||
return key in self.backward
|
||||
@@ -1,442 +0,0 @@
|
||||
# client_async.py
|
||||
#
|
||||
# Copyright 2018-2020 Geaaru <geaaru<@>gmail.com>
|
||||
|
||||
__docformat__ = "epytext en"
|
||||
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
|
||||
from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
|
||||
|
||||
|
||||
class DatagramProtocolClient(asyncio.Protocol):
|
||||
def __init__(self, server, port, logger, client, retries=3, timeout=30):
|
||||
self.transport = None
|
||||
self.port = port
|
||||
self.server = server
|
||||
self.logger = logger
|
||||
self.retries = retries
|
||||
self.timeout = timeout
|
||||
self.client = client
|
||||
|
||||
# Map of pending requests
|
||||
self.pending_requests = {}
|
||||
|
||||
# Use cryptographic-safe random generator as provided by the OS.
|
||||
random_generator = random.SystemRandom()
|
||||
self.packet_id = random_generator.randrange(0, 256)
|
||||
|
||||
self.timeout_future = None
|
||||
|
||||
async def __timeout_handler__(self):
|
||||
try:
|
||||
while True:
|
||||
req2delete = []
|
||||
now = datetime.now()
|
||||
next_weak_up = self.timeout
|
||||
# noinspection PyShadowingBuiltins
|
||||
for id, req in self.pending_requests.items():
|
||||
|
||||
secs = (req["send_date"] - now).seconds
|
||||
if secs > self.timeout:
|
||||
if req["retries"] == self.retries:
|
||||
self.logger.debug(
|
||||
"[%s:%d] For request %d execute all retries",
|
||||
self.server,
|
||||
self.port,
|
||||
id,
|
||||
)
|
||||
req["future"].set_exception(
|
||||
TimeoutError("Timeout on Reply")
|
||||
)
|
||||
req2delete.append(id)
|
||||
else:
|
||||
# Send again packet
|
||||
req["send_date"] = now
|
||||
req["retries"] += 1
|
||||
self.logger.debug(
|
||||
"[%s:%d] For request %d execute retry %d",
|
||||
self.server,
|
||||
self.port,
|
||||
id,
|
||||
req["retries"],
|
||||
)
|
||||
self.transport.sendto(req["packet"].RequestPacket())
|
||||
elif next_weak_up > secs:
|
||||
next_weak_up = secs
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
for id in req2delete:
|
||||
# Remove request for map
|
||||
del self.pending_requests[id]
|
||||
|
||||
await asyncio.sleep(next_weak_up)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def send_packet(self, packet, future):
|
||||
if packet.id in self.pending_requests:
|
||||
raise Exception(f"Packet with id {packet.id} already present")
|
||||
|
||||
# Store packet on pending requests map
|
||||
self.pending_requests[packet.id] = {
|
||||
"packet": packet,
|
||||
"creation_date": datetime.now(),
|
||||
"retries": 0,
|
||||
"future": future,
|
||||
"send_date": datetime.now(),
|
||||
}
|
||||
|
||||
# In queue packet raw on socket buffer
|
||||
self.transport.sendto(packet.RequestPacket())
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
socket = transport.get_extra_info("socket")
|
||||
self.logger.info(
|
||||
"[%s:%d] Transport created with binding in %s:%d",
|
||||
self.server,
|
||||
self.port,
|
||||
socket.getsockname()[0],
|
||||
socket.getsockname()[1],
|
||||
)
|
||||
|
||||
pre_loop = asyncio.get_event_loop()
|
||||
asyncio.set_event_loop(loop=self.client.loop)
|
||||
# Start asynchronous timer handler
|
||||
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
|
||||
asyncio.set_event_loop(loop=pre_loop)
|
||||
|
||||
def error_received(self, exc):
|
||||
self.logger.error(
|
||||
"[%s:%d] Error received: %s", self.server, self.port, exc
|
||||
)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
|
||||
)
|
||||
else:
|
||||
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def datagram_received(self, data, addr):
|
||||
try:
|
||||
req = self.pending_requests[data[0]]
|
||||
reply = req.VerifyPacket(data)
|
||||
req["future"].set_result(reply)
|
||||
# Remove request for map
|
||||
del self.pending_requests[reply.id]
|
||||
except KeyError:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
|
||||
)
|
||||
except PacketError as exc:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error on decode or verify packet: %s",
|
||||
self.server,
|
||||
self.port,
|
||||
exc,
|
||||
)
|
||||
|
||||
async def close_transport(self):
|
||||
if self.transport:
|
||||
self.logger.debug(
|
||||
"[%s:%d] Closing transport...", self.server, self.port
|
||||
)
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
if self.timeout_future:
|
||||
self.timeout_future.cancel()
|
||||
await self.timeout_future
|
||||
self.timeout_future = None
|
||||
|
||||
def create_id(self):
|
||||
self.packet_id = (self.packet_id + 1) % 256
|
||||
return self.packet_id
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
|
||||
)
|
||||
|
||||
# Used as protocol_factory
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
|
||||
class ClientAsync:
|
||||
"""Basic RADIUS client.
|
||||
This class implements a basic RADIUS client. It can send requests
|
||||
to a RADIUS server, taking care of timeouts and retries, and
|
||||
validate its replies.
|
||||
|
||||
:ivar retries: number of times to retry sending a RADIUS request
|
||||
:type retries: integer
|
||||
:ivar timeout: number of seconds to wait for an answer
|
||||
:type timeout: integer
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
auth_port=1812,
|
||||
acct_port=1813,
|
||||
coa_port=3799,
|
||||
secret=b"",
|
||||
dict=None,
|
||||
loop=None,
|
||||
retries=3,
|
||||
timeout=30,
|
||||
logger_name="pyrad",
|
||||
):
|
||||
|
||||
"""Constructor.
|
||||
|
||||
:param server: hostname or IP address of RADIUS server
|
||||
:type server: string
|
||||
:param auth_port: port to use for authentication packets
|
||||
:type auth_port: integer
|
||||
:param acct_port: port to use for accounting packets
|
||||
:type acct_port: integer
|
||||
:param coa_port: port to use for CoA packets
|
||||
:type coa_port: integer
|
||||
:param secret: RADIUS secret
|
||||
:type secret: string
|
||||
:param dict: RADIUS dictionary
|
||||
:type dict: pyrad.dictionary.Dictionary
|
||||
:param loop: Python loop handler
|
||||
:type loop: asyncio event loop
|
||||
"""
|
||||
if not loop:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
self.server = server
|
||||
self.secret = secret
|
||||
self.retries = retries
|
||||
self.timeout = timeout
|
||||
self.dict = dict
|
||||
|
||||
self.auth_port = auth_port
|
||||
self.protocol_auth = None
|
||||
|
||||
self.acct_port = acct_port
|
||||
self.protocol_acct = None
|
||||
|
||||
self.protocol_coa = None
|
||||
self.coa_port = coa_port
|
||||
|
||||
async def initialize_transports(
|
||||
self,
|
||||
enable_acct=False,
|
||||
enable_auth=False,
|
||||
enable_coa=False,
|
||||
local_addr=None,
|
||||
local_auth_port=None,
|
||||
local_acct_port=None,
|
||||
local_coa_port=None,
|
||||
):
|
||||
|
||||
task_list = []
|
||||
|
||||
if not enable_acct and not enable_auth and not enable_coa:
|
||||
raise Exception("No transports selected")
|
||||
|
||||
if enable_acct and not self.protocol_acct:
|
||||
self.protocol_acct = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.acct_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_acct_port:
|
||||
bind_addr = (local_addr, local_acct_port)
|
||||
|
||||
acct_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_acct,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.acct_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(acct_connect)
|
||||
|
||||
if enable_auth and not self.protocol_auth:
|
||||
self.protocol_auth = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.auth_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_auth_port:
|
||||
bind_addr = (local_addr, local_auth_port)
|
||||
|
||||
auth_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_auth,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.auth_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(auth_connect)
|
||||
|
||||
if enable_coa and not self.protocol_coa:
|
||||
self.protocol_coa = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.coa_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_coa_port:
|
||||
bind_addr = (local_addr, local_coa_port)
|
||||
|
||||
coa_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_coa,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.coa_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(coa_connect)
|
||||
|
||||
await asyncio.ensure_future(
|
||||
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||
)
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
async def deinitialize_transports(
|
||||
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||
):
|
||||
if self.protocol_coa and deinit_coa:
|
||||
await self.protocol_coa.close_transport()
|
||||
del self.protocol_coa
|
||||
self.protocol_coa = None
|
||||
if self.protocol_auth and deinit_auth:
|
||||
await self.protocol_auth.close_transport()
|
||||
del self.protocol_auth
|
||||
self.protocol_auth = None
|
||||
if self.protocol_acct and deinit_acct:
|
||||
await self.protocol_acct.close_transport()
|
||||
del self.protocol_acct
|
||||
self.protocol_acct = None
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateAuthPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
if not self.protocol_auth:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return AuthPacket(
|
||||
dict=self.dict,
|
||||
id=self.protocol_auth.create_id(),
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateAcctPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return AcctPacket(
|
||||
id=self.protocol_acct.create_id(),
|
||||
dict=self.dict,
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateCoAPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return CoAPacket(
|
||||
id=self.protocol_coa.create_id(),
|
||||
dict=self.dict,
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
# noinspection PyShadowingBuiltins
|
||||
def CreatePacket(self, id, **args):
|
||||
if not id:
|
||||
raise Exception("Missing mandatory packet id")
|
||||
|
||||
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def SendPacket(self, pkt):
|
||||
"""Send a packet to a RADIUS server.
|
||||
|
||||
:param pkt: the packet to send
|
||||
:type pkt: pyrad.packet.Packet
|
||||
:return: Future related with packet to send
|
||||
:rtype: asyncio.Future
|
||||
"""
|
||||
|
||||
ans = asyncio.Future(loop=self.loop)
|
||||
|
||||
if isinstance(pkt, AuthPacket):
|
||||
if not self.protocol_auth:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_auth.send_packet(pkt, ans)
|
||||
|
||||
elif isinstance(pkt, AcctPacket):
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_acct.send_packet(pkt, ans)
|
||||
|
||||
elif isinstance(pkt, CoAPacket):
|
||||
if not self.protocol_coa:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_coa.send_packet(pkt, ans)
|
||||
|
||||
else:
|
||||
raise Exception("Unsupported packet")
|
||||
|
||||
return ans
|
||||
@@ -10,7 +10,6 @@ from enum import IntEnum, Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from os.path import dirname, isabs, join, normpath
|
||||
from typing import (
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
IO,
|
||||
@@ -167,18 +166,6 @@ def _parse_number(num: str) -> int:
|
||||
return int(num)
|
||||
|
||||
|
||||
def _parse_attribute_code(attr_code: str) -> List[int]:
|
||||
"""Parse attribute codes from (Free)RADIUS dictionaries
|
||||
|
||||
Codes can be either decimal, octal, or hexadecimal.
|
||||
TLV typed can
|
||||
"""
|
||||
codes = []
|
||||
for code in attr_code.split("."):
|
||||
codes.append(_parse_number(code))
|
||||
return codes
|
||||
|
||||
|
||||
class Dictionary:
|
||||
"""(Free)RADIUS Dictionary.
|
||||
|
||||
@@ -358,7 +345,7 @@ class Dictionary:
|
||||
|
||||
for flag in flags:
|
||||
flag_len = len(flag)
|
||||
if flag == 1:
|
||||
if flag_len == 1:
|
||||
value = None
|
||||
elif flag_len == 2:
|
||||
value = flag[1]
|
||||
@@ -372,6 +359,8 @@ class Dictionary:
|
||||
has_tag = True
|
||||
elif key == "encrypt":
|
||||
try:
|
||||
if value == "0":
|
||||
raise ValueError
|
||||
encrypt = Encrypt(int(value)) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
raise ParseError(
|
||||
@@ -386,6 +375,32 @@ class Dictionary:
|
||||
|
||||
return has_tag, encrypt
|
||||
|
||||
def _parse_attribute_code(self, attr_code: str, line_num: int) -> List[int]:
|
||||
filename = self.filestack[-1]
|
||||
tlength = self.cur_vendor.tlength
|
||||
codes = []
|
||||
for code in attr_code.split("."):
|
||||
try:
|
||||
code_num = _parse_number(code)
|
||||
except ValueError:
|
||||
raise ParseError(
|
||||
filename, f'invalid attribute code {attr_code}""', line_num
|
||||
)
|
||||
if 2 ** (8 * tlength) <= code_num:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"attribute code is too big, must be smaller than 2**{tlength}",
|
||||
line_num,
|
||||
)
|
||||
if code_num < 0:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"negative attribute codes are not allowed",
|
||||
line_num,
|
||||
)
|
||||
codes.append(code_num)
|
||||
return codes
|
||||
|
||||
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
@@ -409,27 +424,7 @@ class Dictionary:
|
||||
line_num,
|
||||
)
|
||||
|
||||
try:
|
||||
codes = _parse_attribute_code(attr_code)
|
||||
except ValueError:
|
||||
raise ParseError(
|
||||
filename, f'invalid attribute code {attr_code}""', line_num
|
||||
)
|
||||
|
||||
for code in codes:
|
||||
tlength = self.cur_vendor.tlength
|
||||
if 2 ** (8 * tlength) <= code:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"attribute code is too big, must be smaller than 2**{tlength}",
|
||||
line_num,
|
||||
)
|
||||
if code < 0:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"negative attribute codes are not allowed",
|
||||
line_num,
|
||||
)
|
||||
codes = self._parse_attribute_code(attr_code, line_num)
|
||||
|
||||
# TODO: Do we some explicit handling of tlvs?
|
||||
# if len(codes) > 1:
|
||||
@@ -485,7 +480,14 @@ class Dictionary:
|
||||
filename, f"Invalid number {vvalue} for VALUE {key}", line_num
|
||||
)
|
||||
|
||||
attribute = self.attrindex[attr_name]
|
||||
try:
|
||||
attribute = self.attrindex[attr_name]
|
||||
except KeyError:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"ATTRIBUTE {attr_name} has not been defined yet",
|
||||
line_num,
|
||||
)
|
||||
try:
|
||||
datatype = str(attribute.datatype).split(".")[1]
|
||||
lmin, lmax = INTEGER_TYPES[datatype]
|
||||
@@ -504,3 +506,6 @@ class Dictionary:
|
||||
)
|
||||
attribute.values[value] = key
|
||||
attribute.values[key] = value
|
||||
|
||||
def __getitem__(self, key: Union[str, int, Tuple[int, ...]]) -> Attribute:
|
||||
return self.attrindex[key]
|
||||
|
||||
@@ -7,8 +7,9 @@ from pyrad3.dictionary import Dictionary
|
||||
from pyrad3 import packet
|
||||
|
||||
|
||||
class Host: # pylint: disable=too-many-arguments
|
||||
class Host: # pylint: disable=too-many-arguments
|
||||
"""Interface Class for RADIUS Clients and Servers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret: bytes,
|
||||
|
||||
@@ -169,7 +169,8 @@ class Packet(OrderedDict):
|
||||
self["Message-Authenticator"] = 16 * b"\00"
|
||||
attr = self.ordered_attributes[-1]
|
||||
generated = self._generate_message_authenticator(attr)
|
||||
self[attr.pos + 2 :] = generated
|
||||
index = attr.pos + 2
|
||||
self[index:] = generated
|
||||
|
||||
def refresh_message_authenticator(self):
|
||||
self.add_message_authenticator()
|
||||
@@ -283,7 +284,7 @@ class AcctPacket(Packet):
|
||||
|
||||
def increase_acct_delay_time(self, delay_time: float):
|
||||
try:
|
||||
self['Acct-Delay-Time'] += int(delay_time)
|
||||
self["Acct-Delay-Time"] += int(delay_time)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# proxy.py
|
||||
#
|
||||
# Copyright 2005,2007 Wichert Akkerman <wichert@wiggy.net>
|
||||
#
|
||||
# A RADIUS proxy as defined in RFC 2138
|
||||
|
||||
import select
|
||||
import socket
|
||||
|
||||
from pyrad.server import Server, ServerPacketError
|
||||
from pyrad import packet
|
||||
|
||||
|
||||
class Proxy(Server):
|
||||
"""Base class for RADIUS proxies.
|
||||
This class extends tha RADIUS server class with the capability to
|
||||
handle communication with other RADIUS servers as well.
|
||||
|
||||
:ivar _proxyfd: network socket used to communicate with other servers
|
||||
:type _proxyfd: socket class instance
|
||||
"""
|
||||
|
||||
def _prepare_sockets(self):
|
||||
Server._prepare_sockets(self)
|
||||
self._proxyfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
|
||||
self._poll.register(
|
||||
self._proxyfd.fileno(),
|
||||
(select.POLLIN | select.POLLPRI | select.POLLERR),
|
||||
)
|
||||
|
||||
def _handle_proxy_packet(self, pkt):
|
||||
"""Process a packet received on the reply socket.
|
||||
If this packet should be dropped instead of processed a
|
||||
:obj:`ServerPacketError` exception should be raised. The main loop
|
||||
will drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if pkt.source[0] not in self.hosts:
|
||||
raise ServerPacketError("Received packet from unknown host")
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
|
||||
if pkt.code not in [
|
||||
packet.AccessAccept,
|
||||
packet.AccessReject,
|
||||
packet.AccountingResponse,
|
||||
]:
|
||||
raise ServerPacketError("Received non-response on proxy socket")
|
||||
|
||||
def _process_input(self, fd):
|
||||
"""Process available data.
|
||||
If this packet should be dropped instead of processed a
|
||||
`ServerPacketError` exception should be raised. The main loop
|
||||
will drop the packet and log the reason.
|
||||
|
||||
This function calls either :obj:`HandleAuthPacket`,
|
||||
:obj:`HandleAcctPacket` or :obj:`_handle_proxy_packet` depending on
|
||||
which socket is being processed.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if fd.fileno() == self._proxyfd.fileno():
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreatePacket(packet=data), fd
|
||||
)
|
||||
self._handle_proxy_packet(pkt)
|
||||
else:
|
||||
Server._process_input(self, fd)
|
||||
@@ -1,367 +0,0 @@
|
||||
# server.py
|
||||
#
|
||||
# Copyright 2003-2004,2007,2016 Wichert Akkerman <wichert@wiggy.net>
|
||||
|
||||
import logging
|
||||
import select
|
||||
import socket
|
||||
from pyrad import host
|
||||
from pyrad import packet
|
||||
|
||||
|
||||
LOGGER = logging.getLogger("pyrad")
|
||||
|
||||
|
||||
class RemoteHost:
|
||||
"""Remote RADIUS capable host we can talk to."""
|
||||
|
||||
def __init__(
|
||||
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
:param address: IP address
|
||||
:type address: string
|
||||
:param secret: RADIUS secret
|
||||
:type secret: string
|
||||
:param name: short name (used for logging only)
|
||||
:type name: string
|
||||
:param authport: port used for authentication packets
|
||||
:type authport: integer
|
||||
:param acctport: port used for accounting packets
|
||||
:type acctport: integer
|
||||
:param coaport: port used for CoA packets
|
||||
:type coaport: integer
|
||||
"""
|
||||
self.address = address
|
||||
self.secret = secret
|
||||
self.authport = authport
|
||||
self.acctport = acctport
|
||||
self.coaport = coaport
|
||||
self.name = name
|
||||
|
||||
|
||||
class ServerPacketError(Exception):
|
||||
"""Exception class for bogus packets.
|
||||
ServerPacketError exceptions are only used inside the Server class to
|
||||
abort processing of a packet.
|
||||
"""
|
||||
|
||||
|
||||
class Server(host.Host):
|
||||
"""Basic RADIUS server.
|
||||
This class implements the basics of a RADIUS server. It takes care
|
||||
of the details of receiving and decoding requests; processing of
|
||||
the requests should be done by overloading the appropriate methods
|
||||
in derived classes.
|
||||
|
||||
:ivar hosts: hosts who are allowed to talk to us
|
||||
:type hosts: dictionary of Host class instances
|
||||
:ivar _poll: poll object for network sockets
|
||||
:type _poll: select.poll class instance
|
||||
:ivar _fdmap: map of filedescriptors to network sockets
|
||||
:type _fdmap: dictionary
|
||||
:cvar MaxPacketSize: maximum size of a RADIUS packet
|
||||
:type MaxPacketSize: integer
|
||||
"""
|
||||
|
||||
MaxPacketSize = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
addresses=[],
|
||||
authport=1812,
|
||||
acctport=1813,
|
||||
coaport=3799,
|
||||
hosts=None,
|
||||
dict=None,
|
||||
auth_enabled=True,
|
||||
acct_enabled=True,
|
||||
coa_enabled=False,
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
:param addresses: IP addresses to listen on
|
||||
:type addresses: sequence of strings
|
||||
:param authport: port to listen on for authentication packets
|
||||
:type authport: integer
|
||||
:param acctport: port to listen on for accounting packets
|
||||
:type acctport: integer
|
||||
:param coaport: port to listen on for CoA packets
|
||||
:type coaport: integer
|
||||
:param hosts: hosts who we can talk to
|
||||
:type hosts: dictionary mapping IP to RemoteHost class instances
|
||||
:param dict: RADIUS dictionary to use
|
||||
:type dict: Dictionary class instance
|
||||
:param auth_enabled: enable auth server (default True)
|
||||
:type auth_enabled: bool
|
||||
:param acct_enabled: enable accounting server (default True)
|
||||
:type acct_enabled: bool
|
||||
:param coa_enabled: enable coa server (default False)
|
||||
:type coa_enabled: bool
|
||||
"""
|
||||
host.Host.__init__(self, authport, acctport, coaport, dict)
|
||||
if hosts is None:
|
||||
self.hosts = {}
|
||||
else:
|
||||
self.hosts = hosts
|
||||
|
||||
self.auth_enabled = auth_enabled
|
||||
self.authfds = []
|
||||
self.acct_enabled = acct_enabled
|
||||
self.acctfds = []
|
||||
self.coa_enabled = coa_enabled
|
||||
self.coafds = []
|
||||
|
||||
for addr in addresses:
|
||||
self.BindToAddress(addr)
|
||||
|
||||
def _get_addr_info(self, addr):
|
||||
"""Use getaddrinfo to lookup all addresses for each address.
|
||||
|
||||
Returns a list of tuples or an empty list:
|
||||
[(family, address)]
|
||||
|
||||
:param addr: IP address to lookup
|
||||
:type addr: string
|
||||
"""
|
||||
results = set()
|
||||
try:
|
||||
tmp = socket.getaddrinfo(addr, "www")
|
||||
except socket.gaierror:
|
||||
return []
|
||||
|
||||
for el in tmp:
|
||||
results.add((el[0], el[4][0]))
|
||||
|
||||
return results
|
||||
|
||||
def BindToAddress(self, addr):
|
||||
"""Add an address to listen to.
|
||||
An empty string indicated you want to listen on all addresses.
|
||||
|
||||
:param addr: IP address to listen on
|
||||
:type addr: string
|
||||
"""
|
||||
addrFamily = self._get_addr_info(addr)
|
||||
for (family, address) in addrFamily:
|
||||
if self.auth_enabled:
|
||||
authfd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
authfd.bind((address, self.authport))
|
||||
self.authfds.append(authfd)
|
||||
|
||||
if self.acct_enabled:
|
||||
acctfd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
acctfd.bind((address, self.acctport))
|
||||
self.acctfds.append(acctfd)
|
||||
|
||||
if self.coa_enabled:
|
||||
coafd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
coafd.bind((address, self.coaport))
|
||||
self.coafds.append(coafd)
|
||||
|
||||
def HandleAuthPacket(self, pkt):
|
||||
"""Authentication packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
authentication packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleAcctPacket(self, pkt):
|
||||
"""Accounting packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleCoaPacket(self, pkt):
|
||||
"""CoA packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleDisconnectPacket(self, pkt):
|
||||
"""CoA packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def _add_secret(self, pkt):
|
||||
"""Add secret to packets received and raise ServerPacketError
|
||||
for unknown hosts.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if pkt.source[0] in self.hosts:
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
elif "0.0.0.0" in self.hosts:
|
||||
pkt.secret = self.hosts["0.0.0.0"].secret
|
||||
else:
|
||||
raise ServerPacketError("Received packet from unknown host")
|
||||
|
||||
def _handle_auth_packet(self, pkt):
|
||||
"""Process a packet received on the authentication port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
if pkt.code != packet.AccessRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-authentication packet on authentication port"
|
||||
)
|
||||
self.HandleAuthPacket(pkt)
|
||||
|
||||
def _handle_acct_packet(self, pkt):
|
||||
"""Process a packet received on the accounting port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
if pkt.code not in [
|
||||
packet.AccountingRequest,
|
||||
packet.AccountingResponse,
|
||||
]:
|
||||
raise ServerPacketError(
|
||||
"Received non-accounting packet on accounting port"
|
||||
)
|
||||
self.HandleAcctPacket(pkt)
|
||||
|
||||
def _handle_coa_packet(self, pkt):
|
||||
"""Process a packet received on the coa port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
if pkt.code == packet.CoARequest:
|
||||
self.HandleCoaPacket(pkt)
|
||||
elif pkt.code == packet.DisconnectRequest:
|
||||
self.HandleDisconnectPacket(pkt)
|
||||
else:
|
||||
raise ServerPacketError("Received non-coa packet on coa port")
|
||||
|
||||
def _grab_packet(self, pktgen, fd):
|
||||
"""Read a packet from a network connection.
|
||||
This method assumes there is data waiting for to be read.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
:return: RADIUS packet
|
||||
:rtype: Packet class instance
|
||||
"""
|
||||
(data, source) = fd.recvfrom(self.MaxPacketSize)
|
||||
pkt = pktgen(data)
|
||||
pkt.source = source
|
||||
pkt.fd = fd
|
||||
return pkt
|
||||
|
||||
def _prepare_sockets(self):
|
||||
"""Prepare all sockets to receive packets.
|
||||
"""
|
||||
for fd in self.authfds + self.acctfds + self.coafds:
|
||||
self._fdmap[fd.fileno()] = fd
|
||||
self._poll.register(
|
||||
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
|
||||
)
|
||||
if self.auth_enabled:
|
||||
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
|
||||
if self.acct_enabled:
|
||||
self._realacctfds = list(map(lambda x: x.fileno(), self.acctfds))
|
||||
if self.coa_enabled:
|
||||
self._realcoafds = list(map(lambda x: x.fileno(), self.coafds))
|
||||
|
||||
def CreateReplyPacket(self, pkt, **attributes):
|
||||
"""Create a reply packet.
|
||||
Create a new packet which can be returned as a reply to a received
|
||||
packet.
|
||||
|
||||
:param pkt: original packet
|
||||
:type pkt: Packet instance
|
||||
"""
|
||||
reply = pkt.CreateReply(**attributes)
|
||||
reply.source = pkt.source
|
||||
return reply
|
||||
|
||||
def _process_input(self, fd):
|
||||
"""Process available data.
|
||||
If this packet should be dropped instead of processed a
|
||||
PacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
This function calls either HandleAuthPacket() or
|
||||
HandleAcctPacket() depending on which socket is being
|
||||
processed.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
"""
|
||||
if self.auth_enabled and fd.fileno() in self._realauthfds:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
|
||||
)
|
||||
self._handle_auth_packet(pkt)
|
||||
elif self.acct_enabled and fd.fileno() in self._realacctfds:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
|
||||
)
|
||||
self._handle_acct_packet(pkt)
|
||||
elif self.coa_enabled:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
|
||||
)
|
||||
self._handle_coa_packet(pkt)
|
||||
else:
|
||||
raise ServerPacketError("Received packet for unknown handler")
|
||||
|
||||
def Run(self):
|
||||
"""Main loop.
|
||||
This method is the main loop for a RADIUS server. It waits
|
||||
for packets to arrive via the network and calls other methods
|
||||
to process them.
|
||||
"""
|
||||
self._poll = select.poll()
|
||||
self._fdmap = {}
|
||||
self._prepare_sockets()
|
||||
|
||||
while True:
|
||||
for (fd, event) in self._poll.poll():
|
||||
if event == select.POLLIN:
|
||||
try:
|
||||
fdo = self._fdmap[fd]
|
||||
self._process_input(fdo)
|
||||
except ServerPacketError as err:
|
||||
LOGGER.info("Dropping packet: %s", err)
|
||||
except packet.PacketError as err:
|
||||
LOGGER.info("Received a broken packet: %s", err)
|
||||
else:
|
||||
LOGGER.error("Unexpected event in server main loop")
|
||||
@@ -1,429 +0,0 @@
|
||||
# server_async.py
|
||||
#
|
||||
# Copyright 2018-2019 Geaaru <geaaru@gmail.com>
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from pyrad.packet import (
|
||||
Packet,
|
||||
AccessAccept,
|
||||
AccessReject,
|
||||
AccountingRequest,
|
||||
AccountingResponse,
|
||||
DisconnectACK,
|
||||
DisconnectNAK,
|
||||
DisconnectRequest,
|
||||
CoARequest,
|
||||
CoAACK,
|
||||
CoANAK,
|
||||
AccessRequest,
|
||||
AuthPacket,
|
||||
AcctPacket,
|
||||
CoAPacket,
|
||||
PacketError,
|
||||
)
|
||||
|
||||
from pyrad.server import ServerPacketError
|
||||
|
||||
|
||||
class ServerType(Enum):
|
||||
Auth = "Authentication"
|
||||
Acct = "Accounting"
|
||||
Coa = "Coa"
|
||||
|
||||
|
||||
class DatagramProtocolServer(asyncio.Protocol):
|
||||
def __init__(
|
||||
self, ip, port, logger, server, server_type, hosts, request_callback
|
||||
):
|
||||
self.transport = None
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.logger = logger
|
||||
self.server = server
|
||||
self.hosts = hosts
|
||||
self.server_type = server_type
|
||||
self.request_callback = request_callback
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
|
||||
)
|
||||
else:
|
||||
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
|
||||
|
||||
def send_response(self, reply, addr):
|
||||
self.transport.sendto(reply.ReplyPacket(), addr)
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
self.logger.debug(
|
||||
"[%s:%d] Received %d bytes from %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
len(data),
|
||||
addr,
|
||||
)
|
||||
|
||||
receive_date = datetime.utcnow()
|
||||
|
||||
if addr[0] in self.hosts:
|
||||
remote_host = self.hosts[addr[0]]
|
||||
elif "0.0.0.0" in self.hosts:
|
||||
remote_host = self.hosts["0.0.0.0"]
|
||||
else:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Drop package from unknown source %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.debug(
|
||||
"[%s:%d] Received from %s packet: %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
data.hex(),
|
||||
)
|
||||
req = Packet(packet=data, dict=self.server.dict)
|
||||
except Exception as exc:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if req.code in (
|
||||
AccountingResponse,
|
||||
AccessAccept,
|
||||
AccessReject,
|
||||
CoANAK,
|
||||
CoAACK,
|
||||
DisconnectNAK,
|
||||
DisconnectACK,
|
||||
):
|
||||
raise ServerPacketError(f"Invalid response packet {req.code}")
|
||||
|
||||
elif self.server_type == ServerType.Auth:
|
||||
if req.code != AccessRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-auth packet on auth port"
|
||||
)
|
||||
req = AuthPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyAuthRequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
elif self.server_type == ServerType.Coa:
|
||||
if req.code != DisconnectRequest and req.code != CoARequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-coa packet on coa port"
|
||||
)
|
||||
req = CoAPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyCoARequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
elif self.server_type == ServerType.Acct:
|
||||
|
||||
if req.code != AccountingRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-acct packet on acct port"
|
||||
)
|
||||
req = AcctPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyAcctRequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
# Call request callback
|
||||
self.request_callback(self, req, addr)
|
||||
except Exception as exc:
|
||||
if self.server.debug:
|
||||
self.logger.exception(
|
||||
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
|
||||
)
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error for packet from %s: %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
exc,
|
||||
)
|
||||
|
||||
process_date = datetime.utcnow()
|
||||
self.logger.debug(
|
||||
"[%s:%d] Request from %s processed in %d ms",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
(process_date - receive_date).microseconds / 1000,
|
||||
)
|
||||
|
||||
def error_received(self, exc):
|
||||
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
|
||||
|
||||
async def close_transport(self):
|
||||
if self.transport:
|
||||
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
def __str__(self):
|
||||
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
|
||||
|
||||
# Used as protocol_factory
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
|
||||
class ServerAsync(metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
auth_port=1812,
|
||||
acct_port=1813,
|
||||
coa_port=3799,
|
||||
hosts=None,
|
||||
dictionary=None,
|
||||
loop=None,
|
||||
logger_name="pyrad",
|
||||
enable_pkt_verify=False,
|
||||
debug=False,
|
||||
):
|
||||
|
||||
if not loop:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
if hosts is None:
|
||||
self.hosts = {}
|
||||
else:
|
||||
self.hosts = hosts
|
||||
|
||||
self.auth_port = auth_port
|
||||
self.auth_protocols = []
|
||||
|
||||
self.acct_port = acct_port
|
||||
self.acct_protocols = []
|
||||
|
||||
self.coa_port = coa_port
|
||||
self.coa_protocols = []
|
||||
|
||||
self.dict = dictionary
|
||||
self.enable_pkt_verify = enable_pkt_verify
|
||||
|
||||
self.debug = debug
|
||||
|
||||
def __request_handler__(self, protocol, req, addr):
|
||||
|
||||
try:
|
||||
if protocol.server_type == ServerType.Acct:
|
||||
self.handle_acct_packet(protocol, req, addr)
|
||||
elif protocol.server_type == ServerType.Auth:
|
||||
self.handle_auth_packet(protocol, req, addr)
|
||||
elif (
|
||||
protocol.server_type == ServerType.Coa
|
||||
and req.code == CoARequest
|
||||
):
|
||||
self.handle_coa_packet(protocol, req, addr)
|
||||
elif (
|
||||
protocol.server_type == ServerType.Coa
|
||||
and req.code == DisconnectRequest
|
||||
):
|
||||
self.handle_disconnect_packet(protocol, req, addr)
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%s] Unexpected request found",
|
||||
protocol.ip,
|
||||
protocol.port,
|
||||
)
|
||||
except Exception as exc:
|
||||
if self.debug:
|
||||
self.logger.exception(
|
||||
"[%s:%s] Unexpected error", protocol.ip, protocol.port
|
||||
)
|
||||
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%s] Unexpected error: %s",
|
||||
protocol.ip,
|
||||
protocol.port,
|
||||
exc,
|
||||
)
|
||||
|
||||
def __is_present_proto__(self, ip, port):
|
||||
if port == self.auth_port:
|
||||
for proto in self.auth_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
elif port == self.acct_port:
|
||||
for proto in self.acct_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
elif port == self.coa_port:
|
||||
for proto in self.coa_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
return False
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
@staticmethod
|
||||
def CreateReplyPacket(pkt, **attributes):
|
||||
"""Create a reply packet.
|
||||
Create a new packet which can be returned as a reply to a received
|
||||
packet.
|
||||
|
||||
:param pkt: original packet
|
||||
:type pkt: Packet instance
|
||||
"""
|
||||
reply = pkt.CreateReply(**attributes)
|
||||
return reply
|
||||
|
||||
async def initialize_transports(
|
||||
self,
|
||||
enable_acct=False,
|
||||
enable_auth=False,
|
||||
enable_coa=False,
|
||||
addresses=None,
|
||||
):
|
||||
|
||||
task_list = []
|
||||
|
||||
if not enable_acct and not enable_auth and not enable_coa:
|
||||
raise Exception("No transports selected")
|
||||
if not addresses or len(addresses) == 0:
|
||||
addresses = ["127.0.0.1"]
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
for addr in addresses:
|
||||
|
||||
if enable_acct and not self.__is_present_proto__(
|
||||
addr, self.acct_port
|
||||
):
|
||||
protocol_acct = DatagramProtocolServer(
|
||||
addr,
|
||||
self.acct_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Acct,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
|
||||
bind_addr = (addr, self.acct_port)
|
||||
acct_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_acct, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.acct_protocols.append(protocol_acct)
|
||||
task_list.append(acct_connect)
|
||||
|
||||
if enable_auth and not self.__is_present_proto__(
|
||||
addr, self.auth_port
|
||||
):
|
||||
protocol_auth = DatagramProtocolServer(
|
||||
addr,
|
||||
self.auth_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Auth,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
bind_addr = (addr, self.auth_port)
|
||||
|
||||
auth_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_auth, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.auth_protocols.append(protocol_auth)
|
||||
task_list.append(auth_connect)
|
||||
|
||||
if enable_coa and not self.__is_present_proto__(
|
||||
addr, self.coa_port
|
||||
):
|
||||
protocol_coa = DatagramProtocolServer(
|
||||
addr,
|
||||
self.coa_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Coa,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
bind_addr = (addr, self.coa_port)
|
||||
|
||||
coa_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_coa, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.coa_protocols.append(protocol_coa)
|
||||
task_list.append(coa_connect)
|
||||
|
||||
await asyncio.ensure_future(
|
||||
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||
)
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
async def deinitialize_transports(
|
||||
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||
):
|
||||
|
||||
if deinit_coa:
|
||||
for proto in self.coa_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.coa_protocols = []
|
||||
|
||||
if deinit_auth:
|
||||
for proto in self.auth_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.auth_protocols = []
|
||||
|
||||
if deinit_acct:
|
||||
for proto in self.acct_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.acct_protocols = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_auth_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_acct_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_coa_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_disconnect_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
@@ -4,7 +4,14 @@
|
||||
"""Collections of functions to en- and decode RADIUS Attributes"""
|
||||
|
||||
from typing import Union
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network, ip_address
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Network,
|
||||
ip_network,
|
||||
ip_address,
|
||||
)
|
||||
|
||||
import struct
|
||||
|
||||
@@ -33,13 +40,17 @@ def encode_address(addr: Union[str, IPv4Address]) -> bytes:
|
||||
def encode_network(network: Union[str, IPv4Network]) -> bytes:
|
||||
"""Encode a RADIUS value of type ipv4prefix"""
|
||||
address = IPv4Network(network)
|
||||
return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed
|
||||
return (
|
||||
struct.pack("2B", 0, address.prefixlen) + address.network_address.packed
|
||||
)
|
||||
|
||||
|
||||
def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes:
|
||||
"""Encode a RADIUS value of type ipv6prefix"""
|
||||
address = IPv6Network(network)
|
||||
return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed.rstrip(b'\0')
|
||||
return struct.pack(
|
||||
"2B", 0, address.prefixlen
|
||||
) + address.network_address.packed.rstrip(b"\0")
|
||||
|
||||
|
||||
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
|
||||
@@ -212,7 +223,7 @@ ENCODE_MAP = {
|
||||
"signed": lambda value: encode_integer(value, "!i"),
|
||||
"short": lambda value: encode_integer(value, "!H"),
|
||||
"byte": lambda value: encode_integer(value, "!B"),
|
||||
"integer64": lambda value: encode_integer(value, '!Q'),
|
||||
"integer64": lambda value: encode_integer(value, "!Q"),
|
||||
"date": encode_date,
|
||||
}
|
||||
|
||||
|
||||
@@ -127,11 +127,16 @@ def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
|
||||
return key_id
|
||||
|
||||
|
||||
def parse_value(
|
||||
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
|
||||
) -> List[Attribute]:
|
||||
def parse_value(*_):
|
||||
"""Parse the Value in the given Key/Dictionary Context"""
|
||||
return []
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# def parse_value(
|
||||
# rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
|
||||
# ) -> List[Attribute]:
|
||||
# """Parse the Value in the given Key/Dictionary Context"""
|
||||
# return []
|
||||
|
||||
|
||||
def calculate_authenticator(
|
||||
|
||||
Reference in New Issue
Block a user