diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..3ef210f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy-pytest] +ignore_missing_imports = True diff --git a/src/pyrad3/__init__.py b/src/pyrad3/__init__.py index 3f8b558..52843fa 100644 --- a/src/pyrad3/__init__.py +++ b/src/pyrad3/__init__.py @@ -43,4 +43,4 @@ __url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest" __copyright__ = "Copyright 2020 Istvan Ruzman" __version__ = "0.1.0" -__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"] +__all__ = ["client", "code", "dictionary", "packet", "tools", "utils"] diff --git a/src/pyrad3/client.py b/src/pyrad3/client.py index c0cebd3..1cfd3c2 100644 --- a/src/pyrad3/client.py +++ b/src/pyrad3/client.py @@ -125,9 +125,13 @@ class Client(host.Host): pass # timed out: try the next attempt after increasing the acct delay time - if packet.code == P.Code.AccountingRequest: + try: + # Pretend we've got an Acct Packet, we'll fail gracefully if it + # isn't an Acct Packet packet = cast(P.AcctPacket, packet) packet.increase_acct_delay_time(self.timeout) raw_packet = packet.serialize() + except AttributeError: + pass raise Timeout diff --git a/src/pyrad3/code.py b/src/pyrad3/code.py new file mode 100644 index 0000000..74432e3 --- /dev/null +++ b/src/pyrad3/code.py @@ -0,0 +1,35 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + + +"""Valid RADIUS codes (registered in IANA) + +Currently not all RADIUS codes are contained, because +we don't support them (yet). +""" + +from enum import IntEnum + + +class Code(IntEnum): + """Valid RADIUS codes (registered in IANA)""" + + AccessRequest = 1 + AccessAccept = 2 + AccessReject = 3 + AccountingRequest = 4 + AccountingResponse = 5 + AccountingInterim = 6 + PasswordRequest = 7 + PasswordAck = 8 + PasswordReject = 9 + AccountingMessage = 10 + AccessChallenge = 11 + StatusServer = 12 + StatusClient = 13 + DisconnectRequest = 40 + DisconnectACK = 41 + DisconnectNAK = 42 + CoARequest = 43 + CoAACK = 44 + CoANAK = 45 diff --git a/src/pyrad3/host.py b/src/pyrad3/host.py index c9f2081..2c2a747 100644 --- a/src/pyrad3/host.py +++ b/src/pyrad3/host.py @@ -7,7 +7,7 @@ from pyrad3.dictionary import Dictionary from pyrad3 import packet -class Host: # pylint: disable=too-many-arguments +class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes """Interface Class for RADIUS Clients and Servers""" def __init__( @@ -19,6 +19,7 @@ class Host: # pylint: disable=too-many-arguments coaport: int = 3799, timeout: float = 30, retries: int = 3, + message_authenticator: bool = False, ): self.secret = secret self.dictionary = radius_dict @@ -30,6 +31,8 @@ class Host: # pylint: disable=too-many-arguments self.timeout = timeout self.retries = retries + self.message_authenticator = message_authenticator + def create_packet(self, **kwargs): """Create a generic RADIUS Packet""" return packet.Packet(self, **kwargs) diff --git a/src/pyrad3/packet.py b/src/pyrad3/packet.py index 98d57a6..a48d446 100644 --- a/src/pyrad3/packet.py +++ b/src/pyrad3/packet.py @@ -1,15 +1,17 @@ # Copyright 2020 Istvan Ruzman # SPDX-License-Identifier: MIT OR Apache-2.0 +"""Class for RADIUS Packet""" + from collections import OrderedDict -from enum import IntEnum from secrets import token_bytes -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, Optional, Sequence, Tuple, Union import hashlib import hmac from pyrad3.host import Host +from pyrad3.code import Code from pyrad3.utils import ( PacketError, Attribute, @@ -23,29 +25,12 @@ from pyrad3.utils import ( HMAC = hmac.new -# Packet codes -class Code(IntEnum): - AccessRequest = 1 - AccessAccept = 2 - AccessReject = 3 - AccountingRequest = 4 - AccountingResponse = 5 - AccessChallenge = 11 - StatusServer = 12 - StatusClient = 13 - DisconnectRequest = 40 - DisconnectACK = 41 - DisconnectNAK = 42 - CoARequest = 43 - CoAACK = 44 - CoANAK = 45 - - -class AuthError(Exception): - pass - - class Packet(OrderedDict): + """Generic RADIUS Packet + + Usually this should be used as a Super-Class for specific RADIUS Packets. + """ + def __init__( self, host: Host, @@ -57,7 +42,7 @@ class Packet(OrderedDict): ): super().__init__(**attributes) self.code = code - self.id = radius_id + self.identifier = radius_id self.host = host self.request = request self.ordered_attributes: Sequence[Attribute] = [] @@ -66,6 +51,7 @@ class Packet(OrderedDict): @staticmethod def from_raw(host: Host, raw_packet: bytearray) -> "Packet": + """Parse the given bytearray to a RADIUS Packet""" (code, radius_id, _length, authenticator) = parse_header(raw_packet) ordered_attrs = parse_attributes(host.dictionary, raw_packet) @@ -85,6 +71,7 @@ class Packet(OrderedDict): return parsed_packet def from_raw_reply(self, raw_packet: bytearray) -> "Packet": + """Parse a bytearray """ self.verify_reply(raw_packet) reply = Packet.from_raw(self.host, raw_packet) reply.request = self @@ -99,12 +86,11 @@ class Packet(OrderedDict): def send(self): """Send the packet to the Client/Server. """ - self.host._send_packet(self) + self.host._send_packet(self) # pylint: disable=protected-access def verify_reply(self, raw_reply: bytes): - """Verify the reply to this packet. - """ - if self.id != raw_reply[1]: + """Verify the reply to this packet""" + if self.identifier != raw_reply[1]: raise PacketError("Response has a wrong id") # self.authenticator MUST be set, this packet got send so by definitation @@ -119,6 +105,7 @@ class Packet(OrderedDict): raise PacketError("Reply Packet has a wrong authenticator") def validate_message_authenticator(self): + """Validate the Message-Authenticator within the given RADIUS Packet""" message_authenticator = self["Message-Authenticator"] if isinstance(list, message_authenticator): # There are multiple Message Authenticators, but a packet MUST NOT have @@ -129,6 +116,7 @@ class Packet(OrderedDict): return message_authenticator == generated def _generate_message_authenticator(self, ma_attr: Attribute): + """Calculate the Message-Authenticator for the given Packet""" assert self.authenticator is not None assert self.request is not None assert self.request.authenticator is not None @@ -159,6 +147,7 @@ class Packet(OrderedDict): return hmac_builder.digest() def add_message_authenticator(self): + """Add a Message-Authenticator to the RADIUS packet""" self._encode_packet() self._generate_message_authenticator(self) try: @@ -173,9 +162,17 @@ class Packet(OrderedDict): self[index:] = generated def refresh_message_authenticator(self): + """Refresh an existing message-Authenticator + + This method is equivalent to add_message_authenticator, but + the name can provide more context information. + """ self.add_message_authenticator() - def find_first_attribute(self, attr_type_name: str) -> Attribute: + def find_first_attribute( + self, attr_type_name: Union[str, int, Tuple[int, ...]] + ) -> Attribute: + """Find the first attribute with the given name or code""" for attr in self.ordered_attributes: if attr.type == attr_type_name: return attr.type @@ -184,8 +181,15 @@ class Packet(OrderedDict): def _encode_packet(self): self.raw_packet = None + def serialize(self) -> bytes: + """Serialize the Packet to the RADIUS Line Format""" + # TODO: This is not an abstract method + raise NotImplementedError + + +class AuthPacket(Packet): # pylint: disable=abstract-method + """Generic RADIUS Authentication Packet""" -class AuthPacket(Packet): def __init__( self, host: Host, @@ -202,9 +206,10 @@ class AuthPacket(Packet): self.authenticator = token_bytes(16) def create_accept(self, **attributes): + """Create an Access-Accept for a given Access-Request""" return AuthPacket( self.host, - self.id, + self.identifier, self.auth_type, request=self, code=Code.AccessAccept, @@ -212,9 +217,10 @@ class AuthPacket(Packet): ) def create_reject(self, **attributes): + """Create an Access-Reject for a given Access-Request""" return AuthPacket( self.host, - self.id, + self.identifier, self.auth_type, request=self, code=Code.AccessReject, @@ -222,9 +228,10 @@ class AuthPacket(Packet): ) def create_challange(self, **attributes): + """Create an Access-Challange for a given Access-Request""" return AuthPacket( self.host, - self.id, + self.identifier, self.auth_type, request=self, code=Code.AccessChallenge, @@ -232,6 +239,7 @@ class AuthPacket(Packet): ) def validate_password(self, password: bytes) -> bool: + """Validate a password of an Access-Request""" try: return self.validate_pap(password) except KeyError: @@ -240,6 +248,7 @@ class AuthPacket(Packet): return self.validate_chap(password) def validate_pap(self, password: bytes) -> bool: + """Validate a PAP-Password of an Access-Request""" packet_password = self["User-Password"] return validate_pap_password( self.host.secret, @@ -249,6 +258,7 @@ class AuthPacket(Packet): ) def validate_chap(self, password: bytes) -> bool: + """Validate the CHAP-Password of an Access-Request""" packet_password = self["Chap-Password"] chap_id = packet_password[:1] chap_password = packet_password[1:] @@ -261,7 +271,9 @@ class AuthPacket(Packet): ) -class AcctPacket(Packet): +class AcctPacket(Packet): # pylint: disable=abstract-method + """Generic RADIUS Accounting Packet""" + def __init__( self, host: Host, @@ -274,22 +286,36 @@ class AcctPacket(Packet): super().__init__(host, code, radius_id, request=request, **attributes) def create_response(self, **attributes): + """Create an Accounting-Response to a given Accounting-Request""" return AcctPacket( self.host, - self.id, + self.identifier, code=Code.AccountingResponse, request=self, **attributes ) def increase_acct_delay_time(self, delay_time: float): + """Increase the Accounting Delay Time + + This method automatically adjust the Authenticator + and the Message-Authenticator (if present) + + No check for the Accounting Codes are made, an + Accounting-Respones should not contain the + Acct-Delay-Time Attribute. + """ try: self["Acct-Delay-Time"] += int(delay_time) + if self.host.message_authenticator: + self.refresh_message_authenticator() except KeyError: pass -class CoAPacket(Packet): +class CoAPacket(Packet): # pylint: disable=abstract-method + """Generic RADIUS CoA Packet""" + def __init__( self, host: Host, @@ -302,11 +328,21 @@ class CoAPacket(Packet): super().__init__(host, code, radius_id, request=request, **attributes) def create_ack(self, **attributes): + """Create a RADIUS Packet of type CoA-Ack""" return CoAPacket( - self.host, self.id, code=Code.CoAACK, request=self, **attributes + self.host, + self.identifier, + code=Code.CoAACK, + request=self, + **attributes ) def create_nack(self, **attributes): + """Create a RADIUS Packet of type CoA-Nack""" return CoAPacket( - self.host, self.id, code=Code.CoANAK, request=self, **attributes + self.host, + self.identifier, + code=Code.CoANAK, + request=self, + **attributes ) diff --git a/src/pyrad3/utils.py b/src/pyrad3/utils.py index b7a12cc..288ebd3 100644 --- a/src/pyrad3/utils.py +++ b/src/pyrad3/utils.py @@ -11,6 +11,7 @@ import secrets import struct from pyrad3.dictionary import Dictionary +from pyrad3.code import Code RANDOM_GENERATOR = secrets.SystemRandom() MD5 = hashlib.md5 @@ -38,6 +39,11 @@ def parse_header(raw_packet: bytes) -> Header: ) if length > 4096: raise PacketError(f"Packet length is too big ({length})") + + try: + Code(header[1]) + except ValueError: + PacketError(f"Unknown RADIUS Code {header[1]}") return Header(*header) diff --git a/tests/test_parse_header.py b/tests/test_parse_header.py new file mode 100644 index 0000000..ed69494 --- /dev/null +++ b/tests/test_parse_header.py @@ -0,0 +1,26 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +import struct + +from pyrad3 import utils +import pytest + +# @pytest.mark.parametrize("header", [ +# b""]) +# def test_valid_header(header): +# utils.parse_header(header) + + +@pytest.mark.parametrize( + "header", + [ + b"\1\0" + struct.pack("!H", 5000) + 4996 * b"\0", + b"\1\0" + struct.pack("!H", 100), + b"\0\0" + struct.pack("!H", 20) + 16 * b"\0", + b"", + ], +) +def test_invalid_header(header): + with pytest.raises(utils.PacketError): + utils.parse_header(header)