safe progress
This commit is contained in:
46
src/pyrad3/__init__.py
Normal file
46
src/pyrad3/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Python RADIUS client code.
|
||||
|
||||
pyrad is an implementation of a RADIUS client as described in RFC2865.
|
||||
It takes care of all the details like building RADIUS packets, sending
|
||||
them and decoding responses.
|
||||
|
||||
Here is an example of doing a authentication request::
|
||||
|
||||
import pyrad.packet
|
||||
from pyrad.client import Client
|
||||
from pyrad.dictionary import Dictionary
|
||||
|
||||
srv = Client(server="radius.my.domain", secret="s3cr3t",
|
||||
dict = Dictionary("dicts/dictionary", "dictionary.acc"))
|
||||
|
||||
req = srv.CreatePacket(code=pyrad.packet.AccessRequest,
|
||||
User_Name = "wichert", NAS_Identifier="localhost")
|
||||
req["User-Password"] = req.PwCrypt("password")
|
||||
|
||||
reply = srv.SendPacket(req)
|
||||
if reply.code = =pyrad.packet.AccessAccept:
|
||||
print "access accepted"
|
||||
else:
|
||||
print "access denied"
|
||||
|
||||
print "Attributes returned by server:"
|
||||
for key, value in reply.items():
|
||||
print f'{key}: {value}')
|
||||
|
||||
|
||||
This package contains four modules:
|
||||
|
||||
- client: RADIUS client code
|
||||
- dictionary: RADIUS attribute dictionary
|
||||
- packet: a RADIUS packet as send to/from servers
|
||||
- tools: utility functions
|
||||
"""
|
||||
|
||||
__docformat__ = 'epytext en'
|
||||
|
||||
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
|
||||
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
|
||||
__copyright__ = 'Copyright 2020 Istvan Ruzman'
|
||||
__version__ = '0.1.0'
|
||||
|
||||
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']
|
||||
39
src/pyrad3/bidict.py
Normal file
39
src/pyrad3/bidict.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# bidict.py
|
||||
#
|
||||
# Bidirectional map
|
||||
|
||||
|
||||
class BiDict:
|
||||
def __init__(self):
|
||||
self.forward = {}
|
||||
self.backward = {}
|
||||
|
||||
def add(self, one, two):
|
||||
self.forward[one] = two
|
||||
self.backward[two] = one
|
||||
|
||||
def __len__(self):
|
||||
return len(self.forward)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.get_forward(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
try:
|
||||
del self.backward[self.forward[key]]
|
||||
del self.forward[key]
|
||||
except KeyError:
|
||||
del self.forward[self.backward[key]]
|
||||
del self.backward[key]
|
||||
|
||||
def get_forward(self, key):
|
||||
return self.forward[key]
|
||||
|
||||
def has_forward(self, key):
|
||||
return key in self.forward
|
||||
|
||||
def get_backward(self, key):
|
||||
return self.backward[key]
|
||||
|
||||
def has_backward(self, key):
|
||||
return key in self.backward
|
||||
131
src/pyrad3/client.py
Normal file
131
src/pyrad3/client.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""Implementation of a simple but extensible RADIUS Client"""
|
||||
|
||||
from typing import Optional, Union
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
|
||||
import select
|
||||
import socket
|
||||
import time
|
||||
|
||||
import pyrad3.packet as P
|
||||
from pyrad3 import host
|
||||
|
||||
SUPPORTED_SEND_TYPES = [
|
||||
P.Code.AccessRequest,
|
||||
P.Code.AccountingRequest,
|
||||
P.Code.CoARequest,
|
||||
]
|
||||
|
||||
PACKET_TYPE_PORT_MAPPING = {
|
||||
P.Code.AccessRequest: "authport",
|
||||
P.Code.AccountingRequest: "acctport",
|
||||
P.Code.CoARequest: "coaport",
|
||||
}
|
||||
|
||||
|
||||
class Timeout(Exception):
|
||||
"""Exception for wait timeouts"""
|
||||
|
||||
|
||||
class UnsupportedPacketType(Exception):
|
||||
"""Exception for received packets"""
|
||||
|
||||
|
||||
class Client(host.Host):
|
||||
"""A simple and extensible RADIUS Client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: Union[str, IPv4Address, IPv6Address],
|
||||
secret: bytes,
|
||||
radius_dictionary: dict,
|
||||
interface: Optional[str],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(secret, radius_dictionary, **kwargs)
|
||||
self.server = server
|
||||
self.interface = interface
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._poll: Optional[select.poll] = None
|
||||
|
||||
def bind(self, addr):
|
||||
"""Bind the Address to some socket"""
|
||||
self._socket_close()
|
||||
self._socket_open()
|
||||
self._socket.bind(addr)
|
||||
|
||||
def _socket_open(self):
|
||||
"""Open a client socket"""
|
||||
if self._socket is not None:
|
||||
return
|
||||
try:
|
||||
family = socket.getaddrinfo(self.server, "www")[0][0]
|
||||
except socket.gaierror:
|
||||
family = socket.AF_INET
|
||||
self._socket = socket.socket(family, socket.SOCK_DGRAM)
|
||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
if self.interface is not None:
|
||||
# This will fail on non-Linux systems - that's ok, because we don't
|
||||
# have any implementation yet for non-Linux systems.
|
||||
# Better to fail loudly than to do something unexpected silently.
|
||||
self._socket.setsockopt(
|
||||
socket.SOL_SOCKET,
|
||||
socket.SO_BINDTODEVICE,
|
||||
self.interface,
|
||||
len(self.interface),
|
||||
)
|
||||
self._poll = select.poll()
|
||||
self._poll.register(self._socket, select.POLLIN)
|
||||
|
||||
def _socket_close(self):
|
||||
"""Close the Client socket"""
|
||||
if self._socket is not None:
|
||||
self._poll.unregister(self._socket)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
|
||||
def _select_port(self, packet: P.Packet):
|
||||
"""Select the RADIUS Port depending on the RADIUS Packet type"""
|
||||
try:
|
||||
port_type = PACKET_TYPE_PORT_MAPPING[packet.code]
|
||||
return getattr(self, port_type)
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
raise UnsupportedPacketType(f"The packet type {packet.code} by Client")
|
||||
|
||||
def _send_packet(self, packet: P.Packet):
|
||||
"""Send a RADIUS Packet and wait for the reply"""
|
||||
assert self._socket is not None
|
||||
assert self._poll is not None
|
||||
|
||||
port = self._select_port(packet)
|
||||
|
||||
raw_packet = packet.serialize()
|
||||
|
||||
for _attempt in range(self.retries):
|
||||
now = time.time()
|
||||
waitto = now + self.timeout
|
||||
|
||||
self._socket.sendto(raw_packet, (self.server, port))
|
||||
|
||||
while now < waitto:
|
||||
if not self._poll.poll((waitto - now) * 1000):
|
||||
# socket is not ready for some reason
|
||||
now = time.time()
|
||||
continue
|
||||
|
||||
rawreply = self._socket.recv(4096)
|
||||
try:
|
||||
return packet.verify_reply(rawreply)
|
||||
except P.PacketError:
|
||||
pass
|
||||
|
||||
# timed out: try the next attempt after increasing the acct delay time
|
||||
if packet.code == packet.AccountingRequest:
|
||||
packet.increase_acct_delay_time(self.timeout)
|
||||
raw_packet = packet.serialize()
|
||||
|
||||
raise Timeout
|
||||
442
src/pyrad3/client_async.py
Normal file
442
src/pyrad3/client_async.py
Normal file
@@ -0,0 +1,442 @@
|
||||
# client_async.py
|
||||
#
|
||||
# Copyright 2018-2020 Geaaru <geaaru<@>gmail.com>
|
||||
|
||||
__docformat__ = "epytext en"
|
||||
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
|
||||
from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
|
||||
|
||||
|
||||
class DatagramProtocolClient(asyncio.Protocol):
|
||||
def __init__(self, server, port, logger, client, retries=3, timeout=30):
|
||||
self.transport = None
|
||||
self.port = port
|
||||
self.server = server
|
||||
self.logger = logger
|
||||
self.retries = retries
|
||||
self.timeout = timeout
|
||||
self.client = client
|
||||
|
||||
# Map of pending requests
|
||||
self.pending_requests = {}
|
||||
|
||||
# Use cryptographic-safe random generator as provided by the OS.
|
||||
random_generator = random.SystemRandom()
|
||||
self.packet_id = random_generator.randrange(0, 256)
|
||||
|
||||
self.timeout_future = None
|
||||
|
||||
async def __timeout_handler__(self):
|
||||
try:
|
||||
while True:
|
||||
req2delete = []
|
||||
now = datetime.now()
|
||||
next_weak_up = self.timeout
|
||||
# noinspection PyShadowingBuiltins
|
||||
for id, req in self.pending_requests.items():
|
||||
|
||||
secs = (req["send_date"] - now).seconds
|
||||
if secs > self.timeout:
|
||||
if req["retries"] == self.retries:
|
||||
self.logger.debug(
|
||||
"[%s:%d] For request %d execute all retries",
|
||||
self.server,
|
||||
self.port,
|
||||
id,
|
||||
)
|
||||
req["future"].set_exception(
|
||||
TimeoutError("Timeout on Reply")
|
||||
)
|
||||
req2delete.append(id)
|
||||
else:
|
||||
# Send again packet
|
||||
req["send_date"] = now
|
||||
req["retries"] += 1
|
||||
self.logger.debug(
|
||||
"[%s:%d] For request %d execute retry %d",
|
||||
self.server,
|
||||
self.port,
|
||||
id,
|
||||
req["retries"],
|
||||
)
|
||||
self.transport.sendto(req["packet"].RequestPacket())
|
||||
elif next_weak_up > secs:
|
||||
next_weak_up = secs
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
for id in req2delete:
|
||||
# Remove request for map
|
||||
del self.pending_requests[id]
|
||||
|
||||
await asyncio.sleep(next_weak_up)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def send_packet(self, packet, future):
|
||||
if packet.id in self.pending_requests:
|
||||
raise Exception(f"Packet with id {packet.id} already present")
|
||||
|
||||
# Store packet on pending requests map
|
||||
self.pending_requests[packet.id] = {
|
||||
"packet": packet,
|
||||
"creation_date": datetime.now(),
|
||||
"retries": 0,
|
||||
"future": future,
|
||||
"send_date": datetime.now(),
|
||||
}
|
||||
|
||||
# In queue packet raw on socket buffer
|
||||
self.transport.sendto(packet.RequestPacket())
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
socket = transport.get_extra_info("socket")
|
||||
self.logger.info(
|
||||
"[%s:%d] Transport created with binding in %s:%d",
|
||||
self.server,
|
||||
self.port,
|
||||
socket.getsockname()[0],
|
||||
socket.getsockname()[1],
|
||||
)
|
||||
|
||||
pre_loop = asyncio.get_event_loop()
|
||||
asyncio.set_event_loop(loop=self.client.loop)
|
||||
# Start asynchronous timer handler
|
||||
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
|
||||
asyncio.set_event_loop(loop=pre_loop)
|
||||
|
||||
def error_received(self, exc):
|
||||
self.logger.error(
|
||||
"[%s:%d] Error received: %s", self.server, self.port, exc
|
||||
)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
|
||||
)
|
||||
else:
|
||||
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def datagram_received(self, data, addr):
|
||||
try:
|
||||
req = self.pending_requests[data[0]]
|
||||
reply = req.VerifyPacket(data)
|
||||
req["future"].set_result(reply)
|
||||
# Remove request for map
|
||||
del self.pending_requests[reply.id]
|
||||
except KeyError:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
|
||||
)
|
||||
except PacketError as exc:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error on decode or verify packet: %s",
|
||||
self.server,
|
||||
self.port,
|
||||
exc,
|
||||
)
|
||||
|
||||
async def close_transport(self):
|
||||
if self.transport:
|
||||
self.logger.debug(
|
||||
"[%s:%d] Closing transport...", self.server, self.port
|
||||
)
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
if self.timeout_future:
|
||||
self.timeout_future.cancel()
|
||||
await self.timeout_future
|
||||
self.timeout_future = None
|
||||
|
||||
def create_id(self):
|
||||
self.packet_id = (self.packet_id + 1) % 256
|
||||
return self.packet_id
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
|
||||
)
|
||||
|
||||
# Used as protocol_factory
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
|
||||
class ClientAsync:
|
||||
"""Basic RADIUS client.
|
||||
This class implements a basic RADIUS client. It can send requests
|
||||
to a RADIUS server, taking care of timeouts and retries, and
|
||||
validate its replies.
|
||||
|
||||
:ivar retries: number of times to retry sending a RADIUS request
|
||||
:type retries: integer
|
||||
:ivar timeout: number of seconds to wait for an answer
|
||||
:type timeout: integer
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
auth_port=1812,
|
||||
acct_port=1813,
|
||||
coa_port=3799,
|
||||
secret=b"",
|
||||
dict=None,
|
||||
loop=None,
|
||||
retries=3,
|
||||
timeout=30,
|
||||
logger_name="pyrad",
|
||||
):
|
||||
|
||||
"""Constructor.
|
||||
|
||||
:param server: hostname or IP address of RADIUS server
|
||||
:type server: string
|
||||
:param auth_port: port to use for authentication packets
|
||||
:type auth_port: integer
|
||||
:param acct_port: port to use for accounting packets
|
||||
:type acct_port: integer
|
||||
:param coa_port: port to use for CoA packets
|
||||
:type coa_port: integer
|
||||
:param secret: RADIUS secret
|
||||
:type secret: string
|
||||
:param dict: RADIUS dictionary
|
||||
:type dict: pyrad.dictionary.Dictionary
|
||||
:param loop: Python loop handler
|
||||
:type loop: asyncio event loop
|
||||
"""
|
||||
if not loop:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
self.server = server
|
||||
self.secret = secret
|
||||
self.retries = retries
|
||||
self.timeout = timeout
|
||||
self.dict = dict
|
||||
|
||||
self.auth_port = auth_port
|
||||
self.protocol_auth = None
|
||||
|
||||
self.acct_port = acct_port
|
||||
self.protocol_acct = None
|
||||
|
||||
self.protocol_coa = None
|
||||
self.coa_port = coa_port
|
||||
|
||||
async def initialize_transports(
|
||||
self,
|
||||
enable_acct=False,
|
||||
enable_auth=False,
|
||||
enable_coa=False,
|
||||
local_addr=None,
|
||||
local_auth_port=None,
|
||||
local_acct_port=None,
|
||||
local_coa_port=None,
|
||||
):
|
||||
|
||||
task_list = []
|
||||
|
||||
if not enable_acct and not enable_auth and not enable_coa:
|
||||
raise Exception("No transports selected")
|
||||
|
||||
if enable_acct and not self.protocol_acct:
|
||||
self.protocol_acct = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.acct_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_acct_port:
|
||||
bind_addr = (local_addr, local_acct_port)
|
||||
|
||||
acct_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_acct,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.acct_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(acct_connect)
|
||||
|
||||
if enable_auth and not self.protocol_auth:
|
||||
self.protocol_auth = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.auth_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_auth_port:
|
||||
bind_addr = (local_addr, local_auth_port)
|
||||
|
||||
auth_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_auth,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.auth_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(auth_connect)
|
||||
|
||||
if enable_coa and not self.protocol_coa:
|
||||
self.protocol_coa = DatagramProtocolClient(
|
||||
self.server,
|
||||
self.coa_port,
|
||||
self.logger,
|
||||
self,
|
||||
retries=self.retries,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
bind_addr = None
|
||||
if local_addr and local_coa_port:
|
||||
bind_addr = (local_addr, local_coa_port)
|
||||
|
||||
coa_connect = self.loop.create_datagram_endpoint(
|
||||
self.protocol_coa,
|
||||
reuse_port=True,
|
||||
remote_addr=(self.server, self.coa_port),
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
task_list.append(coa_connect)
|
||||
|
||||
await asyncio.ensure_future(
|
||||
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||
)
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
async def deinitialize_transports(
|
||||
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||
):
|
||||
if self.protocol_coa and deinit_coa:
|
||||
await self.protocol_coa.close_transport()
|
||||
del self.protocol_coa
|
||||
self.protocol_coa = None
|
||||
if self.protocol_auth and deinit_auth:
|
||||
await self.protocol_auth.close_transport()
|
||||
del self.protocol_auth
|
||||
self.protocol_auth = None
|
||||
if self.protocol_acct and deinit_acct:
|
||||
await self.protocol_acct.close_transport()
|
||||
del self.protocol_acct
|
||||
self.protocol_acct = None
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateAuthPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
if not self.protocol_auth:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return AuthPacket(
|
||||
dict=self.dict,
|
||||
id=self.protocol_auth.create_id(),
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateAcctPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return AcctPacket(
|
||||
id=self.protocol_acct.create_id(),
|
||||
dict=self.dict,
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def CreateCoAPacket(self, **args):
|
||||
"""Create a new RADIUS packet.
|
||||
This utility function creates a new RADIUS packet which can
|
||||
be used to communicate with the RADIUS server this client
|
||||
talks to. This is initializing the new packet with the
|
||||
dictionary and secret used for the client.
|
||||
|
||||
:return: a new empty packet instance
|
||||
:rtype: pyrad.packet.Packet
|
||||
"""
|
||||
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
return CoAPacket(
|
||||
id=self.protocol_coa.create_id(),
|
||||
dict=self.dict,
|
||||
secret=self.secret,
|
||||
**args,
|
||||
)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
# noinspection PyShadowingBuiltins
|
||||
def CreatePacket(self, id, **args):
|
||||
if not id:
|
||||
raise Exception("Missing mandatory packet id")
|
||||
|
||||
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
def SendPacket(self, pkt):
|
||||
"""Send a packet to a RADIUS server.
|
||||
|
||||
:param pkt: the packet to send
|
||||
:type pkt: pyrad.packet.Packet
|
||||
:return: Future related with packet to send
|
||||
:rtype: asyncio.Future
|
||||
"""
|
||||
|
||||
ans = asyncio.Future(loop=self.loop)
|
||||
|
||||
if isinstance(pkt, AuthPacket):
|
||||
if not self.protocol_auth:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_auth.send_packet(pkt, ans)
|
||||
|
||||
elif isinstance(pkt, AcctPacket):
|
||||
if not self.protocol_acct:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_acct.send_packet(pkt, ans)
|
||||
|
||||
elif isinstance(pkt, CoAPacket):
|
||||
if not self.protocol_coa:
|
||||
raise Exception("Transport not initialized")
|
||||
|
||||
self.protocol_coa.send_packet(pkt, ans)
|
||||
|
||||
else:
|
||||
raise Exception("Unsupported packet")
|
||||
|
||||
return ans
|
||||
483
src/pyrad3/dictionary.py
Normal file
483
src/pyrad3/dictionary.py
Normal file
@@ -0,0 +1,483 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""RADIUS Dictionary.
|
||||
|
||||
Classes and Types to parse and represent a RADIUS dictionary.
|
||||
"""
|
||||
|
||||
from enum import IntEnum, Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from os.path import dirname, isabs, join, normpath
|
||||
from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
INTEGER_TYPES = {
|
||||
"byte": (0, 255),
|
||||
"short": (0, 2 ** 16 - 1),
|
||||
"signed": (-(2 ** 31), 2 ** 31 - 1),
|
||||
"integer": (0, 2 ** 32 - 1),
|
||||
"integer64": (0, 2 ** 64 - 1),
|
||||
}
|
||||
|
||||
|
||||
class Datatype(Enum):
|
||||
"""Possible Datatypes for ATTRIBUTES"""
|
||||
|
||||
string = auto()
|
||||
octets = auto()
|
||||
date = auto()
|
||||
abinary = auto()
|
||||
byte = auto()
|
||||
short = auto()
|
||||
integer = auto()
|
||||
signed = auto()
|
||||
integer64 = auto()
|
||||
ipaddr = auto()
|
||||
ipv4prefix = auto()
|
||||
ipv6addr = auto()
|
||||
ipv6prefix = auto()
|
||||
comboip = auto()
|
||||
ifid = auto()
|
||||
ether = auto()
|
||||
concat = auto()
|
||||
tlv = auto()
|
||||
extended = auto()
|
||||
longextended = auto()
|
||||
evs = auto()
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
"""RADIUS Dictionary Parser Error"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
msg: str = None,
|
||||
line: Optional[int] = None,
|
||||
**data,
|
||||
):
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
self.file = filename
|
||||
self.line = line
|
||||
self.data = data
|
||||
|
||||
def __str__(self):
|
||||
line = f"({self.line}" if self.line is not None else ""
|
||||
return f"{self.file}{line}: ParseError: {self.msg}"
|
||||
|
||||
|
||||
class Encrypt(IntEnum):
|
||||
"""Enum for different RADIUS Encryption types."""
|
||||
|
||||
NoEncrpytion = 0
|
||||
RadiusCrypt = 1
|
||||
SaltCrypt = 2
|
||||
AscendCrypt = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class Attribute: # pylint: disable=too-many-instance-attributes
|
||||
"""RADIUS Attribute definition"""
|
||||
|
||||
name: str
|
||||
code: int
|
||||
datatype: Datatype
|
||||
has_tag: bool = False
|
||||
encrypt: Encrypt = Encrypt(0)
|
||||
is_sub_attr: bool = False
|
||||
# vendor = Dictionary
|
||||
values: Dict[Union[int, str], Union[int, str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Vendor:
|
||||
"""Representation of a vendor"""
|
||||
|
||||
name: str
|
||||
code: int
|
||||
tlength: int
|
||||
llength: int
|
||||
continuation: bool
|
||||
attrs: Dict[Union[int, Tuple[int, ...]], Attribute]
|
||||
|
||||
|
||||
def dict_parser(
|
||||
filename: str, rad_dict: IO
|
||||
) -> Generator[Tuple[int, List[str]], None, None]:
|
||||
"""Tokenstream of RADIUS Dictionary files
|
||||
|
||||
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
|
||||
and "FILE_CLOSED" tokens will be emitted.
|
||||
"""
|
||||
yield (-1, ["FILE_OPENED", filename])
|
||||
for line_num, line in enumerate(rad_dict.readlines()):
|
||||
tokens = line.split("#", 1)[0].strip().split()
|
||||
if tokens:
|
||||
first_tok = tokens[0] = tokens[0].upper()
|
||||
if first_tok == "$INCLUDE":
|
||||
try:
|
||||
inner_filename = tokens[1]
|
||||
except IndexError:
|
||||
raise ParseError(
|
||||
filename, "$INCLUDE is missing a filename", line_num,
|
||||
)
|
||||
if not isabs(tokens[1]):
|
||||
path = dirname(filename)
|
||||
inner_filename = normpath(join(path, inner_filename))
|
||||
yield from dict_loader(inner_filename)
|
||||
yield (line_num, tokens)
|
||||
yield (-1, ["FILE_CLOSED"])
|
||||
|
||||
|
||||
def dict_loader(filename: str) -> Generator[Tuple[int, List[str]], None, None]:
|
||||
"""Tokenstream of RADIUS Dictionary files
|
||||
|
||||
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
|
||||
and "FILE_CLOSED" tokens will be emitted.
|
||||
"""
|
||||
with open(filename, "r") as rad_dict:
|
||||
yield from dict_parser(filename, rad_dict)
|
||||
|
||||
|
||||
def _parse_number(num: str) -> int:
|
||||
"""Parse a number from (Free)RADIUS dictionaries
|
||||
|
||||
Numbers can be either decimal, octal, or hexadecimal.
|
||||
"""
|
||||
if num.startswith("0x"):
|
||||
return int(num, 16)
|
||||
if num.startswith("0o"):
|
||||
return int(num, 8)
|
||||
return int(num)
|
||||
|
||||
|
||||
def _parse_attribute_code(attr_code: str) -> List[int]:
|
||||
"""Parse attribute codes from (Free)RADIUS dictionaries
|
||||
|
||||
Codes can be either decimal, octal, or hexadecimal.
|
||||
TLV typed can
|
||||
"""
|
||||
codes = []
|
||||
for code in attr_code.split("."):
|
||||
codes.append(_parse_number(code))
|
||||
return codes
|
||||
|
||||
|
||||
class Dictionary:
|
||||
"""(Free)RADIUS Dictionary.
|
||||
|
||||
#TODO: Better documentation
|
||||
This dictionary can "contain" multiple dictionaries.
|
||||
"""
|
||||
|
||||
# there must be some nicer way to unittest this...
|
||||
def __init__(self, dictionary: str, __dictio: Optional[IO] = None):
|
||||
self.vendor: Dict[int, Vendor] = {}
|
||||
self.vendor_lookup_id_by_name: Dict[str, int] = {}
|
||||
self.attrindex: Dict[Union[int, str], Attribute] = {}
|
||||
self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {})
|
||||
self.cur_vendor = self.rfc_vendor
|
||||
if __dictio is not None:
|
||||
loader = dict_parser(dictionary, __dictio)
|
||||
else:
|
||||
loader = dict_loader(dictionary)
|
||||
self.read_dictionary(loader)
|
||||
|
||||
def read_dictionary(
|
||||
self, reader: Generator[Tuple[int, List[str]], None, None]
|
||||
):
|
||||
"""Read and parse a (Free)RADIUS dictionary."""
|
||||
self.filestack = []
|
||||
for line_num, tokens in reader:
|
||||
key = tokens[0]
|
||||
if key == "ATTRIBUTE":
|
||||
self._parse_attribute(tokens, line_num)
|
||||
elif key == "VALUE":
|
||||
self._parse_value(tokens, line_num)
|
||||
elif key == "FILE_OPENED":
|
||||
LOG.info("Parsing file: %s", tokens[1])
|
||||
if tokens[1] in self.filestack:
|
||||
raise ParseError(
|
||||
self.filestack[-1], "Include recursion detected"
|
||||
)
|
||||
self.filestack.append(tokens[1])
|
||||
elif key == "FILE_CLOSED":
|
||||
filename = self.filestack.pop()
|
||||
LOG.info("Finished parsing file: %s", filename)
|
||||
elif key == "VENDOR":
|
||||
self._parse_vendor(tokens, line_num)
|
||||
elif key == "BEGIN-VENDOR":
|
||||
self._parse_begin_vendor(tokens, line_num)
|
||||
elif key == "END-VENDOR":
|
||||
self._parse_end_vendor(tokens, line_num)
|
||||
elif key == "BEGIN-TLV":
|
||||
raise NotImplementedError(
|
||||
"BEGIN-TLV is deprecated and not supported by pyrad3"
|
||||
)
|
||||
elif key == "END-TLV":
|
||||
raise NotImplementedError(
|
||||
"END-TLV is deprecated and not supported by pyrad3"
|
||||
)
|
||||
else:
|
||||
raise ParseError(
|
||||
self.filestack[-1], f"Invalid Token key {key}", line_num
|
||||
)
|
||||
|
||||
def _parse_vendor(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse the vendor definition"""
|
||||
filename = self.filestack[-1]
|
||||
if len(tokens) not in {3, 4}:
|
||||
raise ParseError(
|
||||
filename, "Incorrect number of tokens for vendor statement"
|
||||
)
|
||||
vendor_name = tokens[1]
|
||||
vendor_id = int(tokens[2], 0)
|
||||
continuation = False
|
||||
|
||||
# Parse optional vendor specification
|
||||
try:
|
||||
vendor_format = tokens[3].split("=")
|
||||
if vendor_format[0] != "format":
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"Unknown option {vendor_format[0]} for vendor definition",
|
||||
line_num,
|
||||
)
|
||||
try:
|
||||
vendor_format = vendor_format[1].split(",")
|
||||
t_len, l_len = (int(a) for a in vendor_format[:2])
|
||||
if t_len not in {1, 2, 4}:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f'Invalid type length definition "{t_len}" for vendor {vendor_name}',
|
||||
line_num,
|
||||
)
|
||||
if l_len not in {0, 1, 2}:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f'Invalid length definition "{l_len}" for vendor {vendor_name}',
|
||||
line_num,
|
||||
)
|
||||
try:
|
||||
if vendor_format[2] == "c":
|
||||
if not vendor_name.upper() == "WIMAX":
|
||||
# Not sure why, but FreeRADIUS has this limit,
|
||||
# so we just do the same cause they know better than me
|
||||
raise ParseError(
|
||||
filename,
|
||||
"continuation-bit is only supported for WiMAX",
|
||||
line_num,
|
||||
)
|
||||
continuation = True
|
||||
except IndexError:
|
||||
pass
|
||||
except ValueError:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"Syntax error in specification for vendor {vendor_name}",
|
||||
line_num,
|
||||
)
|
||||
except IndexError:
|
||||
# no format definition
|
||||
t_len, l_len = 1, 1
|
||||
|
||||
vendor = Vendor(vendor_name, vendor_id, t_len, l_len, continuation, {})
|
||||
self.vendor_lookup_id_by_name[vendor_name] = vendor_id
|
||||
self.vendor[vendor_id] = vendor
|
||||
|
||||
def _parse_begin_vendor(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse the BEGIN-VENDOR line of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
if self.cur_vendor != self.rfc_vendor:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"vendor-begin sections are not allowed to be nested",
|
||||
line_num,
|
||||
)
|
||||
if len(tokens) != 2:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"Incorrect number of tokens for begin-vendor statement",
|
||||
line_num,
|
||||
)
|
||||
try:
|
||||
vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
|
||||
self.cur_vendor = self.vendor[vendor_id]
|
||||
except KeyError:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"Unknown vendor {tokens[1]} in begin-vendor statement",
|
||||
line_num,
|
||||
)
|
||||
|
||||
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
if len(tokens) != 2:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"Incorrect number of tokens for end-vendor statement",
|
||||
line_num,
|
||||
)
|
||||
if self.cur_vendor.name != tokens[1]:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"Closing non-opened vendor {tokens[1]} in end-vendor statement",
|
||||
line_num,
|
||||
)
|
||||
self.cur_vendor = self.rfc_vendor
|
||||
|
||||
def _parse_attribute_flags(
|
||||
self, tokens: Sequence[str], line_num: int
|
||||
) -> Tuple[bool, Encrypt]:
|
||||
"""Parse Attribute flags of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
has_tag = False
|
||||
encrypt = Encrypt.NoEncrpytion
|
||||
|
||||
try:
|
||||
flags = [flag.split("=") for flag in tokens[4].split(",")]
|
||||
except IndexError:
|
||||
return False, Encrypt.NoEncrpytion
|
||||
|
||||
for flag in flags:
|
||||
flag_len = len(flag)
|
||||
if flag == 1:
|
||||
value = None
|
||||
elif flag_len == 2:
|
||||
value = flag[1]
|
||||
else:
|
||||
raise ParseError(
|
||||
filename, f"Incorrect attribute flag {flag}", line_num
|
||||
)
|
||||
key = flag[0]
|
||||
|
||||
if key == "has_tag":
|
||||
has_tag = True
|
||||
elif key == "encrypt":
|
||||
try:
|
||||
encrypt = Encrypt(int(value)) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"Illegal attribute encryption {value}",
|
||||
line_num,
|
||||
)
|
||||
else:
|
||||
raise ParseError(
|
||||
filename, "Unknown attribute flag {key}", line_num
|
||||
)
|
||||
|
||||
return has_tag, encrypt
|
||||
|
||||
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
if not len(tokens) in {4, 5}:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"Incorrect number of tokens for attribute definition",
|
||||
line_num,
|
||||
)
|
||||
|
||||
has_tag, encrypt = self._parse_attribute_flags(tokens, line_num)
|
||||
name, code, datatype = tokens[1:4]
|
||||
|
||||
if datatype == "concat" and self.cur_vendor != self.rfc_vendor:
|
||||
raise ParseError(
|
||||
filename,
|
||||
'vendor attributes are not allowed to have the datatype "concat"',
|
||||
line_num,
|
||||
)
|
||||
|
||||
try:
|
||||
codes = _parse_attribute_code(code)
|
||||
except ValueError:
|
||||
raise ParseError(
|
||||
filename, f'invalid attribute code {code}""', line_num
|
||||
)
|
||||
|
||||
for code in codes:
|
||||
tlength = self.cur_vendor.tlength
|
||||
if 2 ** (8 * tlength) <= code:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"attribute code is too big, must be smaller than 2**{tlength}",
|
||||
line_num,
|
||||
)
|
||||
if code < 0:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"negative attribute codes are not allowed",
|
||||
line_num,
|
||||
)
|
||||
|
||||
# TODO: Do we some explicit handling of tlvs?
|
||||
# if len(codes) > 1:
|
||||
# self._parse_attribute_tlv(codes, line_num)
|
||||
# else:
|
||||
# pass
|
||||
|
||||
base_datatype = datatype.split("[")[0].replace("-", "")
|
||||
try:
|
||||
attribute_type = Datatype[base_datatype]
|
||||
except KeyError:
|
||||
raise ParseError(filename, f"Illegal type: {datatype}", line_num)
|
||||
|
||||
attribute = Attribute(
|
||||
name,
|
||||
codes[-1],
|
||||
attribute_type,
|
||||
has_tag,
|
||||
encrypt,
|
||||
len(codes) > 1,
|
||||
{},
|
||||
)
|
||||
|
||||
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
|
||||
self.cur_vendor.attrs[attrcode] = attribute
|
||||
|
||||
if self.cur_vendor != self.rfc_vendor:
|
||||
codes = tuple([26] + codes)
|
||||
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
|
||||
self.attrindex[attrcode] = attribute
|
||||
self.attrindex[name] = attribute
|
||||
|
||||
def _parse_value(self, tokens: Sequence[str], line_num: int):
|
||||
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
|
||||
filename = self.filestack[-1]
|
||||
if len(tokens) != 4:
|
||||
raise ParseError(
|
||||
filename,
|
||||
"Incorrect number of tokens for VALUE definition",
|
||||
line_num,
|
||||
)
|
||||
|
||||
(attr_name, key, value) = tokens[1:]
|
||||
value = _parse_number(value)
|
||||
|
||||
attribute = self.attrindex[attr_name]
|
||||
try:
|
||||
datatype = str(attribute.datatype).split(".")[1]
|
||||
lmin, lmax = INTEGER_TYPES[datatype]
|
||||
if value < lmin or value > lmax:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"VALUE {key}({value}) is not in the limit of type {datatype}",
|
||||
line_num,
|
||||
)
|
||||
except KeyError:
|
||||
raise ParseError(
|
||||
filename,
|
||||
f"only attributes with integer typed datatypes can have"
|
||||
f"value definitions {attribute.datatype}",
|
||||
line_num,
|
||||
)
|
||||
attribute.values[value] = key
|
||||
attribute.values[key] = value
|
||||
46
src/pyrad3/host.py
Normal file
46
src/pyrad3/host.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""Interface Class for RADIUS Clients and Servers"""
|
||||
|
||||
from pyrad3.dictionary import Dictionary
|
||||
from pyrad3 import packet
|
||||
|
||||
|
||||
class Host: # pylint: disable=too-many-arguments
|
||||
"""Interface Class for RADIUS Clients and Servers"""
|
||||
def __init__(
|
||||
self,
|
||||
secret: bytes,
|
||||
radius_dict: Dictionary,
|
||||
authport: int = 1812,
|
||||
acctport: int = 1813,
|
||||
coaport: int = 3799,
|
||||
timeout: float = 30,
|
||||
retries: int = 3,
|
||||
):
|
||||
self.secret = secret
|
||||
self.dictionary = radius_dict
|
||||
|
||||
self.authport = authport
|
||||
self.acctport = acctport
|
||||
self.coaport = coaport
|
||||
|
||||
self.timeout = timeout
|
||||
self.retries = retries
|
||||
|
||||
def create_packet(self, **kwargs):
|
||||
"""Create a generic RADIUS Packet"""
|
||||
return packet.Packet(self, **kwargs)
|
||||
|
||||
def create_auth_packet(self, **kwargs):
|
||||
"""Create an Authentictaion packet (request per default)"""
|
||||
return packet.AuthPacket(self, **kwargs)
|
||||
|
||||
def create_acct_packet(self, **kwargs):
|
||||
"""Create an Accounting packet (request per default)"""
|
||||
return packet.AcctPacket(self, **kwargs)
|
||||
|
||||
def create_coa_packet(self, **kwargs):
|
||||
"""Create an CoA packet (requset per default)"""
|
||||
return packet.CoAPacket(self, **kwargs)
|
||||
305
src/pyrad3/packet.py
Normal file
305
src/pyrad3/packet.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
from collections import OrderedDict
|
||||
from enum import IntEnum
|
||||
from secrets import token_bytes
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
from pyrad3.host import Host
|
||||
from pyrad3.utils import (
|
||||
PacketError,
|
||||
Attribute,
|
||||
parse_header,
|
||||
parse_attributes,
|
||||
calculate_authenticator,
|
||||
validate_pap_password,
|
||||
validate_chap_password,
|
||||
)
|
||||
|
||||
HMAC = hmac.new
|
||||
|
||||
|
||||
# Packet codes
|
||||
class Code(IntEnum):
|
||||
AccessRequest = 1
|
||||
AccessAccept = 2
|
||||
AccessReject = 3
|
||||
AccountingRequest = 4
|
||||
AccountingResponse = 5
|
||||
AccessChallenge = 11
|
||||
StatusServer = 12
|
||||
StatusClient = 13
|
||||
DisconnectRequest = 40
|
||||
DisconnectACK = 41
|
||||
DisconnectNAK = 42
|
||||
CoARequest = 43
|
||||
CoAACK = 44
|
||||
CoANAK = 45
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Packet(OrderedDict):
|
||||
def __init__(
|
||||
self,
|
||||
host: Host,
|
||||
code: Code,
|
||||
radius_id: int,
|
||||
*,
|
||||
request: "Packet" = None,
|
||||
**attributes
|
||||
):
|
||||
super().__init__(**attributes)
|
||||
self.code = code
|
||||
self.id = radius_id
|
||||
self.host = host
|
||||
self.request = request
|
||||
self.ordered_attributes: Sequence[Attribute] = []
|
||||
self.raw_packet: Optional[bytearray] = None
|
||||
self.authenticator: Optional[bytes] = None
|
||||
|
||||
@staticmethod
|
||||
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
|
||||
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
|
||||
|
||||
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
|
||||
|
||||
# Can we do better than Any with type hinting?
|
||||
attrs: Dict[str, Any] = {}
|
||||
for attr in ordered_attrs:
|
||||
try:
|
||||
attrs[attr.name].append(attr.value)
|
||||
except KeyError:
|
||||
attrs[attr.name] = [attr.value]
|
||||
|
||||
parsed_packet = Packet(host, code, radius_id, **attrs)
|
||||
parsed_packet.authenticator = authenticator
|
||||
parsed_packet.raw_packet = raw_packet
|
||||
parsed_packet.ordered_attributes = ordered_attrs
|
||||
return parsed_packet
|
||||
|
||||
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
|
||||
self.verify_reply(raw_packet)
|
||||
reply = Packet.from_raw(self.host, raw_packet)
|
||||
reply.request = self
|
||||
try:
|
||||
if not reply.validate_message_authenticator():
|
||||
raise PacketError("Packet has a wrong message authenticator")
|
||||
except KeyError:
|
||||
if "EAP-Message" in reply:
|
||||
raise PacketError("Packet is missing a message authenticator")
|
||||
return reply
|
||||
|
||||
def send(self):
|
||||
"""Send the packet to the Client/Server.
|
||||
"""
|
||||
self.host._send_packet(self)
|
||||
|
||||
def verify_reply(self, raw_reply: bytes):
|
||||
"""Verify the reply to this packet.
|
||||
"""
|
||||
if self.id != raw_reply[1]:
|
||||
raise PacketError("Response has a wrong id")
|
||||
|
||||
# self.authenticator MUST be set, this packet got send so by definitation
|
||||
# self.authenticator will not be non, but bytes
|
||||
radius_hash = calculate_authenticator(
|
||||
self.host.secret,
|
||||
self.authenticator, # type: ignore
|
||||
raw_reply,
|
||||
)
|
||||
|
||||
if radius_hash != raw_reply[4:20]:
|
||||
raise PacketError("Reply Packet has a wrong authenticator")
|
||||
|
||||
def validate_message_authenticator(self):
|
||||
message_authenticator = self["Message-Authenticator"]
|
||||
if isinstance(list, message_authenticator):
|
||||
# There are multiple Message Authenticators, but a packet MUST NOT have
|
||||
# more than one
|
||||
return False
|
||||
ma_attribute = self.find_first_attribute("Message-Authenticator")
|
||||
generated = self._generate_message_authenticator(ma_attribute)
|
||||
return message_authenticator == generated
|
||||
|
||||
def _generate_message_authenticator(self, ma_attr: Attribute):
|
||||
assert self.authenticator is not None
|
||||
assert self.request is not None
|
||||
assert self.request.authenticator is not None
|
||||
assert self.raw_packet is not None
|
||||
|
||||
# The message authenticator must be treated as 16 * \00
|
||||
start_pos = ma_attr.pos + 2
|
||||
end_pos = start_pos + 16
|
||||
original_ma: bytes = ma_attr.value
|
||||
self.raw_packet[start_pos:end_pos] = 16 * b"\00"
|
||||
|
||||
hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5)
|
||||
hmac_builder.update(self.raw_packet)
|
||||
|
||||
if self.code in (Code.AccessRequest, Code.StatusServer):
|
||||
hmac_builder.update(self.authenticator)
|
||||
elif self.code in (
|
||||
Code.AccessAccept,
|
||||
Code.AccessChallenge,
|
||||
Code.AccessReject,
|
||||
):
|
||||
hmac_builder.update(self.request.authenticator)
|
||||
else:
|
||||
hmac_builder.update(16 * b"\00")
|
||||
|
||||
hmac_builder.update(self.raw_packet[20:])
|
||||
self.raw_packet[start_pos:end_pos] = original_ma
|
||||
return hmac_builder.digest()
|
||||
|
||||
def add_message_authenticator(self):
|
||||
self._encode_packet()
|
||||
self._generate_message_authenticator(self)
|
||||
try:
|
||||
# quick lookup before we iterate over the whole packet
|
||||
_ = self["Message-Authenticator"]
|
||||
attr = self.find_first_attribute("Message-Authenticator")
|
||||
except KeyError:
|
||||
self["Message-Authenticator"] = 16 * b"\00"
|
||||
attr = self.ordered_attributes[-1]
|
||||
generated = self._generate_message_authenticator(attr)
|
||||
self[attr.pos + 2 :] = generated
|
||||
|
||||
def refresh_message_authenticator(self):
|
||||
self.add_message_authenticator()
|
||||
|
||||
def find_first_attribute(self, attr_type_name: str) -> Attribute:
|
||||
for attr in self.ordered_attributes:
|
||||
if attr.type == attr_type_name:
|
||||
return attr.type
|
||||
raise KeyError
|
||||
|
||||
def _encode_packet(self):
|
||||
self.raw_packet = None
|
||||
|
||||
|
||||
class AuthPacket(Packet):
|
||||
def __init__(
|
||||
self,
|
||||
host: Host,
|
||||
radius_id: int,
|
||||
auth_type,
|
||||
*,
|
||||
code: Code = Code.AccessRequest,
|
||||
request: Optional[Packet] = None,
|
||||
**attributes
|
||||
):
|
||||
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||
self.auth_type = auth_type
|
||||
if code == Code.AccessRequest:
|
||||
self.authenticator = token_bytes(16)
|
||||
|
||||
def create_accept(self, **attributes):
|
||||
return AuthPacket(
|
||||
self.host,
|
||||
self.id,
|
||||
self.auth_type,
|
||||
request=self,
|
||||
code=Code.AccessAccept,
|
||||
**attributes
|
||||
)
|
||||
|
||||
def create_reject(self, **attributes):
|
||||
return AuthPacket(
|
||||
self.host,
|
||||
self.id,
|
||||
self.auth_type,
|
||||
request=self,
|
||||
code=Code.AccessReject,
|
||||
**attributes
|
||||
)
|
||||
|
||||
def create_challange(self, **attributes):
|
||||
return AuthPacket(
|
||||
self.host,
|
||||
self.id,
|
||||
self.auth_type,
|
||||
request=self,
|
||||
code=Code.AccessChallenge,
|
||||
**attributes
|
||||
)
|
||||
|
||||
def validate_password(self, password: bytes) -> bool:
|
||||
try:
|
||||
return self.validate_pap(password)
|
||||
except KeyError:
|
||||
pass
|
||||
# Will throw KeyError if no chap password exists
|
||||
return self.validate_chap(password)
|
||||
|
||||
def validate_pap(self, password: bytes) -> bool:
|
||||
packet_password = self["User-Password"]
|
||||
return validate_pap_password(
|
||||
self.host.secret,
|
||||
self.authenticator, # type: ignore
|
||||
packet_password,
|
||||
password,
|
||||
)
|
||||
|
||||
def validate_chap(self, password: bytes) -> bool:
|
||||
packet_password = self["Chap-Password"]
|
||||
chap_id = packet_password[:1]
|
||||
chap_password = packet_password[1:]
|
||||
try:
|
||||
challenge = self["Chap-Challenge"]
|
||||
except KeyError:
|
||||
challenge = self.authenticator
|
||||
return validate_chap_password(
|
||||
chap_id, challenge, chap_password, password, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class AcctPacket(Packet):
|
||||
def __init__(
|
||||
self,
|
||||
host: Host,
|
||||
radius_id: int,
|
||||
*,
|
||||
code: Code = Code.AccountingRequest,
|
||||
request: Optional[Packet] = None,
|
||||
**attributes
|
||||
):
|
||||
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||
|
||||
def create_response(self, **attributes):
|
||||
return AcctPacket(
|
||||
self.host,
|
||||
self.id,
|
||||
code=Code.AccountingResponse,
|
||||
request=self,
|
||||
**attributes
|
||||
)
|
||||
|
||||
|
||||
class CoAPacket(Packet):
|
||||
def __init__(
|
||||
self,
|
||||
host: Host,
|
||||
radius_id: int,
|
||||
*,
|
||||
code: Code = Code.CoARequest,
|
||||
request: Optional[Packet] = None,
|
||||
**attributes
|
||||
):
|
||||
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||
|
||||
def create_ack(self, **attributes):
|
||||
return CoAPacket(
|
||||
self.host, self.id, code=Code.CoAACK, request=self, **attributes
|
||||
)
|
||||
|
||||
def create_nack(self, **attributes):
|
||||
return CoAPacket(
|
||||
self.host, self.id, code=Code.CoANAK, request=self, **attributes
|
||||
)
|
||||
73
src/pyrad3/proxy.py
Normal file
73
src/pyrad3/proxy.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# proxy.py
|
||||
#
|
||||
# Copyright 2005,2007 Wichert Akkerman <wichert@wiggy.net>
|
||||
#
|
||||
# A RADIUS proxy as defined in RFC 2138
|
||||
|
||||
import select
|
||||
import socket
|
||||
|
||||
from pyrad.server import Server, ServerPacketError
|
||||
from pyrad import packet
|
||||
|
||||
|
||||
class Proxy(Server):
|
||||
"""Base class for RADIUS proxies.
|
||||
This class extends tha RADIUS server class with the capability to
|
||||
handle communication with other RADIUS servers as well.
|
||||
|
||||
:ivar _proxyfd: network socket used to communicate with other servers
|
||||
:type _proxyfd: socket class instance
|
||||
"""
|
||||
|
||||
def _prepare_sockets(self):
|
||||
Server._prepare_sockets(self)
|
||||
self._proxyfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
|
||||
self._poll.register(
|
||||
self._proxyfd.fileno(),
|
||||
(select.POLLIN | select.POLLPRI | select.POLLERR),
|
||||
)
|
||||
|
||||
def _handle_proxy_packet(self, pkt):
|
||||
"""Process a packet received on the reply socket.
|
||||
If this packet should be dropped instead of processed a
|
||||
:obj:`ServerPacketError` exception should be raised. The main loop
|
||||
will drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if pkt.source[0] not in self.hosts:
|
||||
raise ServerPacketError("Received packet from unknown host")
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
|
||||
if pkt.code not in [
|
||||
packet.AccessAccept,
|
||||
packet.AccessReject,
|
||||
packet.AccountingResponse,
|
||||
]:
|
||||
raise ServerPacketError("Received non-response on proxy socket")
|
||||
|
||||
def _process_input(self, fd):
|
||||
"""Process available data.
|
||||
If this packet should be dropped instead of processed a
|
||||
`ServerPacketError` exception should be raised. The main loop
|
||||
will drop the packet and log the reason.
|
||||
|
||||
This function calls either :obj:`HandleAuthPacket`,
|
||||
:obj:`HandleAcctPacket` or :obj:`_handle_proxy_packet` depending on
|
||||
which socket is being processed.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if fd.fileno() == self._proxyfd.fileno():
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreatePacket(packet=data), fd
|
||||
)
|
||||
self._handle_proxy_packet(pkt)
|
||||
else:
|
||||
Server._process_input(self, fd)
|
||||
367
src/pyrad3/server.py
Normal file
367
src/pyrad3/server.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# server.py
|
||||
#
|
||||
# Copyright 2003-2004,2007,2016 Wichert Akkerman <wichert@wiggy.net>
|
||||
|
||||
import logging
|
||||
import select
|
||||
import socket
|
||||
from pyrad import host
|
||||
from pyrad import packet
|
||||
|
||||
|
||||
LOGGER = logging.getLogger("pyrad")
|
||||
|
||||
|
||||
class RemoteHost:
|
||||
"""Remote RADIUS capable host we can talk to."""
|
||||
|
||||
def __init__(
|
||||
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
:param address: IP address
|
||||
:type address: string
|
||||
:param secret: RADIUS secret
|
||||
:type secret: string
|
||||
:param name: short name (used for logging only)
|
||||
:type name: string
|
||||
:param authport: port used for authentication packets
|
||||
:type authport: integer
|
||||
:param acctport: port used for accounting packets
|
||||
:type acctport: integer
|
||||
:param coaport: port used for CoA packets
|
||||
:type coaport: integer
|
||||
"""
|
||||
self.address = address
|
||||
self.secret = secret
|
||||
self.authport = authport
|
||||
self.acctport = acctport
|
||||
self.coaport = coaport
|
||||
self.name = name
|
||||
|
||||
|
||||
class ServerPacketError(Exception):
|
||||
"""Exception class for bogus packets.
|
||||
ServerPacketError exceptions are only used inside the Server class to
|
||||
abort processing of a packet.
|
||||
"""
|
||||
|
||||
|
||||
class Server(host.Host):
|
||||
"""Basic RADIUS server.
|
||||
This class implements the basics of a RADIUS server. It takes care
|
||||
of the details of receiving and decoding requests; processing of
|
||||
the requests should be done by overloading the appropriate methods
|
||||
in derived classes.
|
||||
|
||||
:ivar hosts: hosts who are allowed to talk to us
|
||||
:type hosts: dictionary of Host class instances
|
||||
:ivar _poll: poll object for network sockets
|
||||
:type _poll: select.poll class instance
|
||||
:ivar _fdmap: map of filedescriptors to network sockets
|
||||
:type _fdmap: dictionary
|
||||
:cvar MaxPacketSize: maximum size of a RADIUS packet
|
||||
:type MaxPacketSize: integer
|
||||
"""
|
||||
|
||||
MaxPacketSize = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
addresses=[],
|
||||
authport=1812,
|
||||
acctport=1813,
|
||||
coaport=3799,
|
||||
hosts=None,
|
||||
dict=None,
|
||||
auth_enabled=True,
|
||||
acct_enabled=True,
|
||||
coa_enabled=False,
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
:param addresses: IP addresses to listen on
|
||||
:type addresses: sequence of strings
|
||||
:param authport: port to listen on for authentication packets
|
||||
:type authport: integer
|
||||
:param acctport: port to listen on for accounting packets
|
||||
:type acctport: integer
|
||||
:param coaport: port to listen on for CoA packets
|
||||
:type coaport: integer
|
||||
:param hosts: hosts who we can talk to
|
||||
:type hosts: dictionary mapping IP to RemoteHost class instances
|
||||
:param dict: RADIUS dictionary to use
|
||||
:type dict: Dictionary class instance
|
||||
:param auth_enabled: enable auth server (default True)
|
||||
:type auth_enabled: bool
|
||||
:param acct_enabled: enable accounting server (default True)
|
||||
:type acct_enabled: bool
|
||||
:param coa_enabled: enable coa server (default False)
|
||||
:type coa_enabled: bool
|
||||
"""
|
||||
host.Host.__init__(self, authport, acctport, coaport, dict)
|
||||
if hosts is None:
|
||||
self.hosts = {}
|
||||
else:
|
||||
self.hosts = hosts
|
||||
|
||||
self.auth_enabled = auth_enabled
|
||||
self.authfds = []
|
||||
self.acct_enabled = acct_enabled
|
||||
self.acctfds = []
|
||||
self.coa_enabled = coa_enabled
|
||||
self.coafds = []
|
||||
|
||||
for addr in addresses:
|
||||
self.BindToAddress(addr)
|
||||
|
||||
def _get_addr_info(self, addr):
|
||||
"""Use getaddrinfo to lookup all addresses for each address.
|
||||
|
||||
Returns a list of tuples or an empty list:
|
||||
[(family, address)]
|
||||
|
||||
:param addr: IP address to lookup
|
||||
:type addr: string
|
||||
"""
|
||||
results = set()
|
||||
try:
|
||||
tmp = socket.getaddrinfo(addr, "www")
|
||||
except socket.gaierror:
|
||||
return []
|
||||
|
||||
for el in tmp:
|
||||
results.add((el[0], el[4][0]))
|
||||
|
||||
return results
|
||||
|
||||
def BindToAddress(self, addr):
|
||||
"""Add an address to listen to.
|
||||
An empty string indicated you want to listen on all addresses.
|
||||
|
||||
:param addr: IP address to listen on
|
||||
:type addr: string
|
||||
"""
|
||||
addrFamily = self._get_addr_info(addr)
|
||||
for (family, address) in addrFamily:
|
||||
if self.auth_enabled:
|
||||
authfd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
authfd.bind((address, self.authport))
|
||||
self.authfds.append(authfd)
|
||||
|
||||
if self.acct_enabled:
|
||||
acctfd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
acctfd.bind((address, self.acctport))
|
||||
self.acctfds.append(acctfd)
|
||||
|
||||
if self.coa_enabled:
|
||||
coafd = socket.socket(family, socket.SOCK_DGRAM)
|
||||
coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
coafd.bind((address, self.coaport))
|
||||
self.coafds.append(coafd)
|
||||
|
||||
def HandleAuthPacket(self, pkt):
|
||||
"""Authentication packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
authentication packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleAcctPacket(self, pkt):
|
||||
"""Accounting packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleCoaPacket(self, pkt):
|
||||
"""CoA packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def HandleDisconnectPacket(self, pkt):
|
||||
"""CoA packet handler.
|
||||
This is an empty function that is called when a valid
|
||||
accounting packet has been received. It can be overriden in
|
||||
derived classes to add custom behaviour.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
|
||||
def _add_secret(self, pkt):
|
||||
"""Add secret to packets received and raise ServerPacketError
|
||||
for unknown hosts.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
if pkt.source[0] in self.hosts:
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
elif "0.0.0.0" in self.hosts:
|
||||
pkt.secret = self.hosts["0.0.0.0"].secret
|
||||
else:
|
||||
raise ServerPacketError("Received packet from unknown host")
|
||||
|
||||
def _handle_auth_packet(self, pkt):
|
||||
"""Process a packet received on the authentication port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
if pkt.code != packet.AccessRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-authentication packet on authentication port"
|
||||
)
|
||||
self.HandleAuthPacket(pkt)
|
||||
|
||||
def _handle_acct_packet(self, pkt):
|
||||
"""Process a packet received on the accounting port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
if pkt.code not in [
|
||||
packet.AccountingRequest,
|
||||
packet.AccountingResponse,
|
||||
]:
|
||||
raise ServerPacketError(
|
||||
"Received non-accounting packet on accounting port"
|
||||
)
|
||||
self.HandleAcctPacket(pkt)
|
||||
|
||||
def _handle_coa_packet(self, pkt):
|
||||
"""Process a packet received on the coa port.
|
||||
If this packet should be dropped instead of processed a
|
||||
ServerPacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
:param pkt: packet to process
|
||||
:type pkt: Packet class instance
|
||||
"""
|
||||
self._add_secret(pkt)
|
||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||
if pkt.code == packet.CoARequest:
|
||||
self.HandleCoaPacket(pkt)
|
||||
elif pkt.code == packet.DisconnectRequest:
|
||||
self.HandleDisconnectPacket(pkt)
|
||||
else:
|
||||
raise ServerPacketError("Received non-coa packet on coa port")
|
||||
|
||||
def _grab_packet(self, pktgen, fd):
|
||||
"""Read a packet from a network connection.
|
||||
This method assumes there is data waiting for to be read.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
:return: RADIUS packet
|
||||
:rtype: Packet class instance
|
||||
"""
|
||||
(data, source) = fd.recvfrom(self.MaxPacketSize)
|
||||
pkt = pktgen(data)
|
||||
pkt.source = source
|
||||
pkt.fd = fd
|
||||
return pkt
|
||||
|
||||
def _prepare_sockets(self):
|
||||
"""Prepare all sockets to receive packets.
|
||||
"""
|
||||
for fd in self.authfds + self.acctfds + self.coafds:
|
||||
self._fdmap[fd.fileno()] = fd
|
||||
self._poll.register(
|
||||
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
|
||||
)
|
||||
if self.auth_enabled:
|
||||
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
|
||||
if self.acct_enabled:
|
||||
self._realacctfds = list(map(lambda x: x.fileno(), self.acctfds))
|
||||
if self.coa_enabled:
|
||||
self._realcoafds = list(map(lambda x: x.fileno(), self.coafds))
|
||||
|
||||
def CreateReplyPacket(self, pkt, **attributes):
|
||||
"""Create a reply packet.
|
||||
Create a new packet which can be returned as a reply to a received
|
||||
packet.
|
||||
|
||||
:param pkt: original packet
|
||||
:type pkt: Packet instance
|
||||
"""
|
||||
reply = pkt.CreateReply(**attributes)
|
||||
reply.source = pkt.source
|
||||
return reply
|
||||
|
||||
def _process_input(self, fd):
|
||||
"""Process available data.
|
||||
If this packet should be dropped instead of processed a
|
||||
PacketError exception should be raised. The main loop will
|
||||
drop the packet and log the reason.
|
||||
|
||||
This function calls either HandleAuthPacket() or
|
||||
HandleAcctPacket() depending on which socket is being
|
||||
processed.
|
||||
|
||||
:param fd: socket to read packet from
|
||||
:type fd: socket class instance
|
||||
"""
|
||||
if self.auth_enabled and fd.fileno() in self._realauthfds:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
|
||||
)
|
||||
self._handle_auth_packet(pkt)
|
||||
elif self.acct_enabled and fd.fileno() in self._realacctfds:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
|
||||
)
|
||||
self._handle_acct_packet(pkt)
|
||||
elif self.coa_enabled:
|
||||
pkt = self._grab_packet(
|
||||
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
|
||||
)
|
||||
self._handle_coa_packet(pkt)
|
||||
else:
|
||||
raise ServerPacketError("Received packet for unknown handler")
|
||||
|
||||
def Run(self):
|
||||
"""Main loop.
|
||||
This method is the main loop for a RADIUS server. It waits
|
||||
for packets to arrive via the network and calls other methods
|
||||
to process them.
|
||||
"""
|
||||
self._poll = select.poll()
|
||||
self._fdmap = {}
|
||||
self._prepare_sockets()
|
||||
|
||||
while True:
|
||||
for (fd, event) in self._poll.poll():
|
||||
if event == select.POLLIN:
|
||||
try:
|
||||
fdo = self._fdmap[fd]
|
||||
self._process_input(fdo)
|
||||
except ServerPacketError as err:
|
||||
LOGGER.info("Dropping packet: %s", err)
|
||||
except packet.PacketError as err:
|
||||
LOGGER.info("Received a broken packet: %s", err)
|
||||
else:
|
||||
LOGGER.error("Unexpected event in server main loop")
|
||||
429
src/pyrad3/server_async.py
Normal file
429
src/pyrad3/server_async.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# server_async.py
|
||||
#
|
||||
# Copyright 2018-2019 Geaaru <geaaru@gmail.com>
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from pyrad.packet import (
|
||||
Packet,
|
||||
AccessAccept,
|
||||
AccessReject,
|
||||
AccountingRequest,
|
||||
AccountingResponse,
|
||||
DisconnectACK,
|
||||
DisconnectNAK,
|
||||
DisconnectRequest,
|
||||
CoARequest,
|
||||
CoAACK,
|
||||
CoANAK,
|
||||
AccessRequest,
|
||||
AuthPacket,
|
||||
AcctPacket,
|
||||
CoAPacket,
|
||||
PacketError,
|
||||
)
|
||||
|
||||
from pyrad.server import ServerPacketError
|
||||
|
||||
|
||||
class ServerType(Enum):
|
||||
Auth = "Authentication"
|
||||
Acct = "Accounting"
|
||||
Coa = "Coa"
|
||||
|
||||
|
||||
class DatagramProtocolServer(asyncio.Protocol):
|
||||
def __init__(
|
||||
self, ip, port, logger, server, server_type, hosts, request_callback
|
||||
):
|
||||
self.transport = None
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.logger = logger
|
||||
self.server = server
|
||||
self.hosts = hosts
|
||||
self.server_type = server_type
|
||||
self.request_callback = request_callback
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
|
||||
)
|
||||
else:
|
||||
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
|
||||
|
||||
def send_response(self, reply, addr):
|
||||
self.transport.sendto(reply.ReplyPacket(), addr)
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
self.logger.debug(
|
||||
"[%s:%d] Received %d bytes from %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
len(data),
|
||||
addr,
|
||||
)
|
||||
|
||||
receive_date = datetime.utcnow()
|
||||
|
||||
if addr[0] in self.hosts:
|
||||
remote_host = self.hosts[addr[0]]
|
||||
elif "0.0.0.0" in self.hosts:
|
||||
remote_host = self.hosts["0.0.0.0"]
|
||||
else:
|
||||
self.logger.warn(
|
||||
"[%s:%d] Drop package from unknown source %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.debug(
|
||||
"[%s:%d] Received from %s packet: %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
data.hex(),
|
||||
)
|
||||
req = Packet(packet=data, dict=self.server.dict)
|
||||
except Exception as exc:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if req.code in (
|
||||
AccountingResponse,
|
||||
AccessAccept,
|
||||
AccessReject,
|
||||
CoANAK,
|
||||
CoAACK,
|
||||
DisconnectNAK,
|
||||
DisconnectACK,
|
||||
):
|
||||
raise ServerPacketError(f"Invalid response packet {req.code}")
|
||||
|
||||
elif self.server_type == ServerType.Auth:
|
||||
if req.code != AccessRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-auth packet on auth port"
|
||||
)
|
||||
req = AuthPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyAuthRequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
elif self.server_type == ServerType.Coa:
|
||||
if req.code != DisconnectRequest and req.code != CoARequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-coa packet on coa port"
|
||||
)
|
||||
req = CoAPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyCoARequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
elif self.server_type == ServerType.Acct:
|
||||
|
||||
if req.code != AccountingRequest:
|
||||
raise ServerPacketError(
|
||||
"Received non-acct packet on acct port"
|
||||
)
|
||||
req = AcctPacket(
|
||||
secret=remote_host.secret,
|
||||
dict=self.server.dict,
|
||||
packet=data,
|
||||
)
|
||||
if self.server.enable_pkt_verify:
|
||||
if req.VerifyAcctRequest():
|
||||
raise PacketError("Packet verification failed")
|
||||
|
||||
# Call request callback
|
||||
self.request_callback(self, req, addr)
|
||||
except Exception as exc:
|
||||
if self.server.debug:
|
||||
self.logger.exception(
|
||||
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
|
||||
)
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%d] Error for packet from %s: %s",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
exc,
|
||||
)
|
||||
|
||||
process_date = datetime.utcnow()
|
||||
self.logger.debug(
|
||||
"[%s:%d] Request from %s processed in %d ms",
|
||||
self.ip,
|
||||
self.port,
|
||||
addr,
|
||||
(process_date - receive_date).microseconds / 1000,
|
||||
)
|
||||
|
||||
def error_received(self, exc):
|
||||
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
|
||||
|
||||
async def close_transport(self):
|
||||
if self.transport:
|
||||
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
def __str__(self):
|
||||
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
|
||||
|
||||
# Used as protocol_factory
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
|
||||
class ServerAsync(metaclass=ABCMeta):
|
||||
def __init__(
|
||||
self,
|
||||
auth_port=1812,
|
||||
acct_port=1813,
|
||||
coa_port=3799,
|
||||
hosts=None,
|
||||
dictionary=None,
|
||||
loop=None,
|
||||
logger_name="pyrad",
|
||||
enable_pkt_verify=False,
|
||||
debug=False,
|
||||
):
|
||||
|
||||
if not loop:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
if hosts is None:
|
||||
self.hosts = {}
|
||||
else:
|
||||
self.hosts = hosts
|
||||
|
||||
self.auth_port = auth_port
|
||||
self.auth_protocols = []
|
||||
|
||||
self.acct_port = acct_port
|
||||
self.acct_protocols = []
|
||||
|
||||
self.coa_port = coa_port
|
||||
self.coa_protocols = []
|
||||
|
||||
self.dict = dictionary
|
||||
self.enable_pkt_verify = enable_pkt_verify
|
||||
|
||||
self.debug = debug
|
||||
|
||||
def __request_handler__(self, protocol, req, addr):
|
||||
|
||||
try:
|
||||
if protocol.server_type == ServerType.Acct:
|
||||
self.handle_acct_packet(protocol, req, addr)
|
||||
elif protocol.server_type == ServerType.Auth:
|
||||
self.handle_auth_packet(protocol, req, addr)
|
||||
elif (
|
||||
protocol.server_type == ServerType.Coa
|
||||
and req.code == CoARequest
|
||||
):
|
||||
self.handle_coa_packet(protocol, req, addr)
|
||||
elif (
|
||||
protocol.server_type == ServerType.Coa
|
||||
and req.code == DisconnectRequest
|
||||
):
|
||||
self.handle_disconnect_packet(protocol, req, addr)
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%s] Unexpected request found",
|
||||
protocol.ip,
|
||||
protocol.port,
|
||||
)
|
||||
except Exception as exc:
|
||||
if self.debug:
|
||||
self.logger.exception(
|
||||
"[%s:%s] Unexpected error", protocol.ip, protocol.port
|
||||
)
|
||||
|
||||
else:
|
||||
self.logger.error(
|
||||
"[%s:%s] Unexpected error: %s",
|
||||
protocol.ip,
|
||||
protocol.port,
|
||||
exc,
|
||||
)
|
||||
|
||||
def __is_present_proto__(self, ip, port):
|
||||
if port == self.auth_port:
|
||||
for proto in self.auth_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
elif port == self.acct_port:
|
||||
for proto in self.acct_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
elif port == self.coa_port:
|
||||
for proto in self.coa_protocols:
|
||||
if proto.ip == ip:
|
||||
return True
|
||||
return False
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
@staticmethod
|
||||
def CreateReplyPacket(pkt, **attributes):
|
||||
"""Create a reply packet.
|
||||
Create a new packet which can be returned as a reply to a received
|
||||
packet.
|
||||
|
||||
:param pkt: original packet
|
||||
:type pkt: Packet instance
|
||||
"""
|
||||
reply = pkt.CreateReply(**attributes)
|
||||
return reply
|
||||
|
||||
async def initialize_transports(
|
||||
self,
|
||||
enable_acct=False,
|
||||
enable_auth=False,
|
||||
enable_coa=False,
|
||||
addresses=None,
|
||||
):
|
||||
|
||||
task_list = []
|
||||
|
||||
if not enable_acct and not enable_auth and not enable_coa:
|
||||
raise Exception("No transports selected")
|
||||
if not addresses or len(addresses) == 0:
|
||||
addresses = ["127.0.0.1"]
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
for addr in addresses:
|
||||
|
||||
if enable_acct and not self.__is_present_proto__(
|
||||
addr, self.acct_port
|
||||
):
|
||||
protocol_acct = DatagramProtocolServer(
|
||||
addr,
|
||||
self.acct_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Acct,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
|
||||
bind_addr = (addr, self.acct_port)
|
||||
acct_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_acct, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.acct_protocols.append(protocol_acct)
|
||||
task_list.append(acct_connect)
|
||||
|
||||
if enable_auth and not self.__is_present_proto__(
|
||||
addr, self.auth_port
|
||||
):
|
||||
protocol_auth = DatagramProtocolServer(
|
||||
addr,
|
||||
self.auth_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Auth,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
bind_addr = (addr, self.auth_port)
|
||||
|
||||
auth_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_auth, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.auth_protocols.append(protocol_auth)
|
||||
task_list.append(auth_connect)
|
||||
|
||||
if enable_coa and not self.__is_present_proto__(
|
||||
addr, self.coa_port
|
||||
):
|
||||
protocol_coa = DatagramProtocolServer(
|
||||
addr,
|
||||
self.coa_port,
|
||||
self.logger,
|
||||
self,
|
||||
ServerType.Coa,
|
||||
self.hosts,
|
||||
self.__request_handler__,
|
||||
)
|
||||
bind_addr = (addr, self.coa_port)
|
||||
|
||||
coa_connect = self.loop.create_datagram_endpoint(
|
||||
protocol_coa, reuse_port=True, local_addr=bind_addr
|
||||
)
|
||||
self.coa_protocols.append(protocol_coa)
|
||||
task_list.append(coa_connect)
|
||||
|
||||
await asyncio.ensure_future(
|
||||
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||
)
|
||||
|
||||
# noinspection SpellCheckingInspection
|
||||
async def deinitialize_transports(
|
||||
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||
):
|
||||
|
||||
if deinit_coa:
|
||||
for proto in self.coa_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.coa_protocols = []
|
||||
|
||||
if deinit_auth:
|
||||
for proto in self.auth_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.auth_protocols = []
|
||||
|
||||
if deinit_acct:
|
||||
for proto in self.acct_protocols:
|
||||
await proto.close_transport()
|
||||
del proto
|
||||
|
||||
self.acct_protocols = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_auth_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_acct_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_coa_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_disconnect_packet(self, protocol, pkt, addr):
|
||||
pass
|
||||
243
src/pyrad3/tools.py
Normal file
243
src/pyrad3/tools.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""Collections of functions to en- and decode RADIUS Attributes"""
|
||||
|
||||
from typing import Union
|
||||
from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address
|
||||
|
||||
import struct
|
||||
|
||||
|
||||
def encode_string(string: str) -> bytes:
|
||||
"""Encode a RADIUS value of type string"""
|
||||
if len(string) > 253:
|
||||
raise ValueError("Can only encode strings of <= 253 characters")
|
||||
if isinstance(string, str):
|
||||
return string.encode("utf-8")
|
||||
return string
|
||||
|
||||
|
||||
def encode_octets(string: bytes) -> bytes:
|
||||
"""Encode a RADIUS value of type octet"""
|
||||
if len(string) > 253:
|
||||
raise ValueError("Can only encode strings of <= 253 characters")
|
||||
return string
|
||||
|
||||
|
||||
def encode_address(addr: Union[str, IPv4Address]) -> bytes:
|
||||
"""Encode a RADIUS value of type ipaddr"""
|
||||
return IPv4Address(addr).packed
|
||||
|
||||
|
||||
def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes:
|
||||
"""Encode a RADIUS value of type ipv6prefix"""
|
||||
address = IPv6Network(addr)
|
||||
return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed
|
||||
|
||||
|
||||
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
|
||||
"""Encode a RADIUS value of type ipv6addr"""
|
||||
return IPv6Address(addr).packed
|
||||
|
||||
|
||||
def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes:
|
||||
"""Encode a RADIUS value of type combo-ip"""
|
||||
return ip_address(addr).packed
|
||||
|
||||
|
||||
def encode_ascend_binary(string: str) -> bytes:
|
||||
"""
|
||||
struct_format: List of type=value pairs sperated by spaces.
|
||||
|
||||
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
|
||||
|
||||
Type:
|
||||
family ipv4(default) or ipv6
|
||||
action discard(default) or accept
|
||||
direction in(default) or out
|
||||
src source prefix (default ignore)
|
||||
dst destination prefix (default ignore)
|
||||
proto protocol number / next-header number (default ignore)
|
||||
sport source port (default ignore)
|
||||
dport destination port (default ignore)
|
||||
sportq source port qualifier (default 0)
|
||||
dportq destination port qualifier (default 0)
|
||||
|
||||
Source/Destination Port Qualifier:
|
||||
0 no compare
|
||||
1 less than
|
||||
2 equal to
|
||||
3 greater than
|
||||
4 not equal to
|
||||
"""
|
||||
|
||||
terms = {
|
||||
"family": b"\x01",
|
||||
"action": b"\x00",
|
||||
"direction": b"\x01",
|
||||
"src": b"\x00\x00\x00\x00",
|
||||
"dst": b"\x00\x00\x00\x00",
|
||||
"srcl": b"\x00",
|
||||
"dstl": b"\x00",
|
||||
"proto": b"\x00",
|
||||
"sport": b"\x00\x00",
|
||||
"dport": b"\x00\x00",
|
||||
"sportq": b"\x00",
|
||||
"dportq": b"\x00",
|
||||
}
|
||||
|
||||
for term in string.split(" "):
|
||||
key, value = term.split("=")
|
||||
if key == "family" and value == "ipv6":
|
||||
terms[key] = b"\x03"
|
||||
if terms["src"] == b"\x00\x00\x00\x00":
|
||||
terms["src"] = 16 * b"\x00"
|
||||
if terms["dst"] == b"\x00\x00\x00\x00":
|
||||
terms["dst"] = 16 * b"\x00"
|
||||
elif key == "action" and value == "accept":
|
||||
terms[key] = b"\x01"
|
||||
elif key == "direction" and value == "out":
|
||||
terms[key] = b"\x00"
|
||||
elif key in ("src", "dst"):
|
||||
address = ip_network(value)
|
||||
terms[key] = address.network_address.packed
|
||||
terms[key + "l"] = struct.pack("B", address.prefixlen)
|
||||
elif key in ("sport", "dport"):
|
||||
terms[key] = struct.pack("!H", int(value))
|
||||
elif key in ("sportq", "dportq", "proto"):
|
||||
terms[key] = struct.pack("B", int(value))
|
||||
|
||||
trailer = 8 * b"\x00"
|
||||
|
||||
result = b"".join(
|
||||
(
|
||||
terms["family"],
|
||||
terms["action"],
|
||||
terms["direction"],
|
||||
b"\x00",
|
||||
terms["src"],
|
||||
terms["dst"],
|
||||
terms["srcl"],
|
||||
terms["dstl"],
|
||||
terms["proto"],
|
||||
b"\x00",
|
||||
terms["sport"],
|
||||
terms["dport"],
|
||||
terms["sportq"],
|
||||
terms["dportq"],
|
||||
b"\x00\x00",
|
||||
trailer,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def encode_integer(num: Union[int, str], struct_format="!I") -> bytes:
|
||||
"""Encode a RADIUS value of some type integer"""
|
||||
return struct.pack(struct_format, int(num))
|
||||
|
||||
|
||||
def encode_date(num: Union[int, str]) -> bytes:
|
||||
"""Encode a RADIUS value of type date"""
|
||||
return struct.pack("!I", int(num))
|
||||
|
||||
|
||||
def decode_string(string: bytes) -> Union[str, bytes]:
|
||||
"""Decode a RADIUS value of type string"""
|
||||
try:
|
||||
return string.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return string
|
||||
|
||||
|
||||
def decode_octets(string: bytes) -> bytes:
|
||||
"""Decode a RADIUS value of type octet"""
|
||||
return string
|
||||
|
||||
|
||||
def decode_address(addr: bytes) -> IPv4Address:
|
||||
"""Decode a RADIUS value of type ipaddr"""
|
||||
return IPv4Address(addr)
|
||||
|
||||
|
||||
def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
|
||||
"""Decode a RADIUS value of type ipv6prefix"""
|
||||
addr = addr + b"\x00" * (18 - len(addr))
|
||||
prefix = addr[:2]
|
||||
addr = addr[2:]
|
||||
return IPv6Network((prefix, addr))
|
||||
|
||||
|
||||
def decode_ipv6_address(addr: bytes) -> IPv6Address:
|
||||
"""Decode a RADIUS value of type ipv6addr"""
|
||||
addr = addr + b"\x00" * (16 - len(addr))
|
||||
return IPv6Address(addr)
|
||||
|
||||
|
||||
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
|
||||
"""Decode a RADIUS value of type combo-ip"""
|
||||
return ip_address(addr).packed
|
||||
|
||||
|
||||
def decode_ascend_binary(string):
|
||||
"""Decode a RADIUS value of type abinary"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def decode_integer(num: bytes, struct_format="!I") -> int:
|
||||
"""Decode a RADIUS value of some integer type"""
|
||||
return (struct.unpack(struct_format, num))[0]
|
||||
|
||||
|
||||
def decode_date(num):
|
||||
"""Decode a RADIUS value of type date"""
|
||||
return (struct.unpack("!I", num))[0]
|
||||
|
||||
|
||||
ENCODE_MAP = {
|
||||
"string": encode_string,
|
||||
"octets": encode_octets,
|
||||
"integer": encode_integer,
|
||||
"ipaddr": encode_address,
|
||||
"ipv6prefix": encode_ipv6_prefix,
|
||||
"ipv6addr": encode_ipv6_address,
|
||||
"abinary": encode_ascend_binary,
|
||||
"signed": lambda value: encode_integer(value, "!i"),
|
||||
"short": lambda value: encode_integer(value, "!H"),
|
||||
"byte": lambda value: encode_integer(value, "!B"),
|
||||
"integer64": lambda value: encode_integer(value, '!Q'),
|
||||
"date": encode_date,
|
||||
}
|
||||
|
||||
|
||||
def encode_attr(datatype, value):
|
||||
"""Encode a RADIUS attribute"""
|
||||
try:
|
||||
return ENCODE_MAP[datatype](value)
|
||||
except KeyError:
|
||||
raise ValueError(f"Unknown attribute type {datatype}")
|
||||
|
||||
|
||||
DECODE_MAP = {
|
||||
"string": decode_string,
|
||||
"octets": decode_octets,
|
||||
"integer": decode_integer,
|
||||
"ipaddr": decode_address,
|
||||
"ipv6prefix": decode_ipv6_prefix,
|
||||
"ipv6addr": decode_ipv6_address,
|
||||
"abinary": decode_ascend_binary,
|
||||
"signed": lambda value: decode_integer(value, "!i"),
|
||||
"short": lambda value: decode_integer(value, "!H"),
|
||||
"byte": lambda value: decode_integer(value, "!B"),
|
||||
"integer64": lambda value: decode_integer(value, "!Q"),
|
||||
"date": decode_date,
|
||||
}
|
||||
|
||||
|
||||
def decode_attr(datatype, value):
|
||||
"""Decode a RADIUS attribute"""
|
||||
try:
|
||||
return DECODE_MAP[datatype](value)
|
||||
except KeyError:
|
||||
raise ValueError(f"Unknown attribute type {datatype}")
|
||||
234
src/pyrad3/utils.py
Normal file
234
src/pyrad3/utils.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright 2020 Istvan Ruzman
|
||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||
|
||||
"""Collection of functions to deal with RADIUS packet en- and decoding."""
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import List, Union
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import struct
|
||||
|
||||
from pyrad3.dictionary import Dictionary
|
||||
|
||||
RANDOM_GENERATOR = secrets.SystemRandom()
|
||||
MD5 = hashlib.md5
|
||||
|
||||
|
||||
class PacketError(Exception):
|
||||
"""Exception for Invalid Packets"""
|
||||
|
||||
|
||||
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
|
||||
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
|
||||
|
||||
|
||||
def parse_header(raw_packet: bytes) -> Header:
|
||||
"""Parse the Header of a RADIUS Packet."""
|
||||
try:
|
||||
header = struct.unpack("!BBH16s", raw_packet)
|
||||
except struct.error:
|
||||
raise PacketError("Packet header is corrupt")
|
||||
|
||||
length = header[3]
|
||||
if len(raw_packet) != length:
|
||||
raise PacketError(
|
||||
f"RADIUS Packet ({len(raw_packet)}) has an invalid length ({length})"
|
||||
)
|
||||
if length > 4096:
|
||||
raise PacketError(f"Packet length is too big ({length})")
|
||||
return Header(*header)
|
||||
|
||||
|
||||
def parse_attributes(
|
||||
rad_dict: Dictionary, raw_packet: bytes
|
||||
) -> List[Attribute]:
|
||||
"""Parse the Attributes in a RADIUS Packet.
|
||||
|
||||
This function skips the Header. The Header must be parsed and verified
|
||||
separately.
|
||||
"""
|
||||
attributes = []
|
||||
packet = raw_packet[20:]
|
||||
|
||||
while packet:
|
||||
try:
|
||||
(key, length) = struct.unpack("!BB", packet[0:2])
|
||||
except struct.error:
|
||||
raise PacketError("Attribute header is corrupt")
|
||||
if length < 2:
|
||||
raise PacketError(f"Attribute length ({length}) is too small")
|
||||
|
||||
value = packet[2:length]
|
||||
offset = len(raw_packet) - len(packet) + length
|
||||
if key == 26:
|
||||
try:
|
||||
attributes.extend(
|
||||
parse_vendor_attributes(rad_dict, offset, value)
|
||||
)
|
||||
except (PacketError, IndexError):
|
||||
attributes.append(
|
||||
Attribute(
|
||||
name="Unknown-Vendor-Attribute",
|
||||
pos=offset,
|
||||
type="octets",
|
||||
length=int(packet[1]),
|
||||
value=packet[2:],
|
||||
)
|
||||
)
|
||||
else:
|
||||
key = parse_key(rad_dict, key)
|
||||
attributes.extend(parse_value(rad_dict, key, offset, value))
|
||||
packet = packet[length:]
|
||||
|
||||
return attributes
|
||||
|
||||
|
||||
def parse_vendor_attributes(
|
||||
rad_dict: Dictionary, offset: int, vendor_value: bytes
|
||||
) -> List[Attribute]:
|
||||
"""Parse A Vendor Attribute"""
|
||||
if len(vendor_value) < 4:
|
||||
raise PacketError
|
||||
vendor_id = int.from_bytes(vendor_value[:4], "big")
|
||||
vendor_dict = rad_dict.vendor_by_id[vendor_id]
|
||||
vendor_name = vendor_dict.name
|
||||
|
||||
attributes = []
|
||||
vendor_tlv = vendor_value[4:]
|
||||
while vendor_tlv:
|
||||
try:
|
||||
(key, length) = struct.unpack("!BB", vendor_tlv[0:2])
|
||||
except struct.error:
|
||||
attribute = [
|
||||
Attribute(
|
||||
name=f"Unknown-{vendor_name}-Attribute",
|
||||
pos=offset - len(vendor_value),
|
||||
type="octets",
|
||||
length=len(vendor_value) - 4,
|
||||
value=vendor_value,
|
||||
)
|
||||
]
|
||||
else:
|
||||
offset = offset - len(vendor_tlv) + length
|
||||
key = parse_key(vendor_dict, key)
|
||||
attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
|
||||
attributes.extend(attribute)
|
||||
vendor_tlv = vendor_tlv[length:]
|
||||
return attributes
|
||||
|
||||
|
||||
def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
|
||||
"""Parse the key in the Dictionary Context"""
|
||||
try:
|
||||
return rad_dict.attrs[key_id].name
|
||||
except KeyError:
|
||||
return key_id
|
||||
|
||||
|
||||
def parse_value(
|
||||
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
|
||||
) -> List[Attribute]:
|
||||
"""Parse the Value in the given Key/Dictionary Context"""
|
||||
return []
|
||||
|
||||
|
||||
def calculate_authenticator(
|
||||
secret: bytes, authenticator: bytes, raw_packet: bytes
|
||||
) -> bytes:
|
||||
"""Calculate the Authenticator for the RADIUS Packet"""
|
||||
return MD5(
|
||||
raw_packet[0:4] + authenticator + raw_packet[20:] + secret
|
||||
).digest()
|
||||
|
||||
|
||||
def validate_pap_password(
|
||||
secret: bytes,
|
||||
authenticator: bytes,
|
||||
obfuscated_password: bytes,
|
||||
plaintext_password: bytes,
|
||||
) -> bool:
|
||||
"""Check if the plaintext and the RADIUS passwords match
|
||||
This function does not "decrypt" the received password.
|
||||
"""
|
||||
obf_pass = password_encode(secret, authenticator, plaintext_password)
|
||||
return obfuscated_password == obf_pass
|
||||
|
||||
|
||||
def password_encode(
|
||||
secret: bytes, authenticator: bytes, password: bytes
|
||||
) -> bytes:
|
||||
"""Obfuscate the plaintext Password for RADIUS"""
|
||||
password += b"\x00" * (16 - (len(password) % 16))
|
||||
return obfuscation_algorithm(secret, authenticator, password)
|
||||
|
||||
|
||||
def password_decode(
|
||||
secret: bytes, authenticator: bytes, obfuscated_password: bytes
|
||||
) -> str:
|
||||
"""Reverse the RADIUS obfuscation on a given password
|
||||
|
||||
The password password is padded with \\x00 to a 16 byte boundary. The padding will
|
||||
be removed by this function.
|
||||
If the original password had some trailing \\x00 it will get lost. Therefore it is
|
||||
not recommended to use (trailing) \\x00 in passwords.
|
||||
"""
|
||||
deobfuscated = obfuscation_algorithm(
|
||||
secret, authenticator, obfuscated_password
|
||||
)
|
||||
return deobfuscated.rstrip(b"\x00").decode("utf-8")
|
||||
|
||||
|
||||
def obfuscation_algorithm(
|
||||
secret: bytes, authenticator: bytes, password: bytes
|
||||
) -> bytes:
|
||||
"""Obfuscate the plaintext password.
|
||||
|
||||
This function does not deal with the padding (which the
|
||||
RADIUS Protocol requires.)
|
||||
The User has to pad the password themself, or better use
|
||||
the `password_encode` or `password_decode` function.
|
||||
"""
|
||||
result = b""
|
||||
buf = password
|
||||
last = authenticator
|
||||
|
||||
while buf:
|
||||
cur_hash = MD5(secret + last).digest()
|
||||
for cbuf, chash in zip(buf, cur_hash):
|
||||
result += bytes([cbuf ^ chash])
|
||||
(last, buf) = (buf[:16], buf[16:])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_chap_password(
|
||||
chap_id: bytes,
|
||||
challenge: bytes,
|
||||
chap_password: bytes,
|
||||
plaintext_password: bytes,
|
||||
) -> bool:
|
||||
"""Validate the CHAP password against the given plaintext password"""
|
||||
return (
|
||||
chap_password == MD5(chap_id + plaintext_password + challenge).digest()
|
||||
)
|
||||
|
||||
|
||||
def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes:
|
||||
"""Salt Encrypt the given value"""
|
||||
# The highest bit MUST be 1
|
||||
random_value = RANDOM_GENERATOR.randrange(32768, 65535)
|
||||
salt = struct.pack("!H", random_value)
|
||||
|
||||
salted_auth = authenticator + salt
|
||||
|
||||
return obfuscation_algorithm(secret, salted_auth, value)
|
||||
|
||||
|
||||
def salt_decrypt(
|
||||
secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes
|
||||
) -> bytes:
|
||||
"""Decrypt the given value"""
|
||||
salted_auth = authenticator + salt
|
||||
return obfuscation_algorithm(secret, salted_auth, encrypted_value)
|
||||
Reference in New Issue
Block a user