Make the different linting tools happy

This commit is contained in:
Istvan Ruzman
2020-08-12 10:15:14 +02:00
parent 6d8dd18601
commit d7ff0be1aa
8 changed files with 154 additions and 42 deletions

2
mypy.ini Normal file
View File

@@ -0,0 +1,2 @@
[mypy-pytest]
ignore_missing_imports = True

View File

@@ -43,4 +43,4 @@ __url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest"
__copyright__ = "Copyright 2020 Istvan Ruzman" __copyright__ = "Copyright 2020 Istvan Ruzman"
__version__ = "0.1.0" __version__ = "0.1.0"
__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"] __all__ = ["client", "code", "dictionary", "packet", "tools", "utils"]

View File

@@ -125,9 +125,13 @@ class Client(host.Host):
pass pass
# timed out: try the next attempt after increasing the acct delay time # timed out: try the next attempt after increasing the acct delay time
if packet.code == P.Code.AccountingRequest: try:
# Pretend we've got an Acct Packet, we'll fail gracefully if it
# isn't an Acct Packet
packet = cast(P.AcctPacket, packet) packet = cast(P.AcctPacket, packet)
packet.increase_acct_delay_time(self.timeout) packet.increase_acct_delay_time(self.timeout)
raw_packet = packet.serialize() raw_packet = packet.serialize()
except AttributeError:
pass
raise Timeout raise Timeout

35
src/pyrad3/code.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Valid RADIUS codes (registered in IANA)
Currently not all RADIUS codes are contained, because
we don't support them (yet).
"""
from enum import IntEnum
class Code(IntEnum):
"""Valid RADIUS codes (registered in IANA)"""
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccountingInterim = 6
PasswordRequest = 7
PasswordAck = 8
PasswordReject = 9
AccountingMessage = 10
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45

View File

@@ -7,7 +7,7 @@ from pyrad3.dictionary import Dictionary
from pyrad3 import packet from pyrad3 import packet
class Host: # pylint: disable=too-many-arguments class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes
"""Interface Class for RADIUS Clients and Servers""" """Interface Class for RADIUS Clients and Servers"""
def __init__( def __init__(
@@ -19,6 +19,7 @@ class Host: # pylint: disable=too-many-arguments
coaport: int = 3799, coaport: int = 3799,
timeout: float = 30, timeout: float = 30,
retries: int = 3, retries: int = 3,
message_authenticator: bool = False,
): ):
self.secret = secret self.secret = secret
self.dictionary = radius_dict self.dictionary = radius_dict
@@ -30,6 +31,8 @@ class Host: # pylint: disable=too-many-arguments
self.timeout = timeout self.timeout = timeout
self.retries = retries self.retries = retries
self.message_authenticator = message_authenticator
def create_packet(self, **kwargs): def create_packet(self, **kwargs):
"""Create a generic RADIUS Packet""" """Create a generic RADIUS Packet"""
return packet.Packet(self, **kwargs) return packet.Packet(self, **kwargs)

View File

@@ -1,15 +1,17 @@
# Copyright 2020 Istvan Ruzman # Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0 # SPDX-License-Identifier: MIT OR Apache-2.0
"""Class for RADIUS Packet"""
from collections import OrderedDict from collections import OrderedDict
from enum import IntEnum
from secrets import token_bytes from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence from typing import Any, Dict, Optional, Sequence, Tuple, Union
import hashlib import hashlib
import hmac import hmac
from pyrad3.host import Host from pyrad3.host import Host
from pyrad3.code import Code
from pyrad3.utils import ( from pyrad3.utils import (
PacketError, PacketError,
Attribute, Attribute,
@@ -23,29 +25,12 @@ from pyrad3.utils import (
HMAC = hmac.new HMAC = hmac.new
# Packet codes
class Code(IntEnum):
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
class AuthError(Exception):
pass
class Packet(OrderedDict): class Packet(OrderedDict):
"""Generic RADIUS Packet
Usually this should be used as a Super-Class for specific RADIUS Packets.
"""
def __init__( def __init__(
self, self,
host: Host, host: Host,
@@ -57,7 +42,7 @@ class Packet(OrderedDict):
): ):
super().__init__(**attributes) super().__init__(**attributes)
self.code = code self.code = code
self.id = radius_id self.identifier = radius_id
self.host = host self.host = host
self.request = request self.request = request
self.ordered_attributes: Sequence[Attribute] = [] self.ordered_attributes: Sequence[Attribute] = []
@@ -66,6 +51,7 @@ class Packet(OrderedDict):
@staticmethod @staticmethod
def from_raw(host: Host, raw_packet: bytearray) -> "Packet": def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
"""Parse the given bytearray to a RADIUS Packet"""
(code, radius_id, _length, authenticator) = parse_header(raw_packet) (code, radius_id, _length, authenticator) = parse_header(raw_packet)
ordered_attrs = parse_attributes(host.dictionary, raw_packet) ordered_attrs = parse_attributes(host.dictionary, raw_packet)
@@ -85,6 +71,7 @@ class Packet(OrderedDict):
return parsed_packet return parsed_packet
def from_raw_reply(self, raw_packet: bytearray) -> "Packet": def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
"""Parse a bytearray """
self.verify_reply(raw_packet) self.verify_reply(raw_packet)
reply = Packet.from_raw(self.host, raw_packet) reply = Packet.from_raw(self.host, raw_packet)
reply.request = self reply.request = self
@@ -99,12 +86,11 @@ class Packet(OrderedDict):
def send(self): def send(self):
"""Send the packet to the Client/Server. """Send the packet to the Client/Server.
""" """
self.host._send_packet(self) self.host._send_packet(self) # pylint: disable=protected-access
def verify_reply(self, raw_reply: bytes): def verify_reply(self, raw_reply: bytes):
"""Verify the reply to this packet. """Verify the reply to this packet"""
""" if self.identifier != raw_reply[1]:
if self.id != raw_reply[1]:
raise PacketError("Response has a wrong id") raise PacketError("Response has a wrong id")
# self.authenticator MUST be set, this packet got send so by definitation # self.authenticator MUST be set, this packet got send so by definitation
@@ -119,6 +105,7 @@ class Packet(OrderedDict):
raise PacketError("Reply Packet has a wrong authenticator") raise PacketError("Reply Packet has a wrong authenticator")
def validate_message_authenticator(self): def validate_message_authenticator(self):
"""Validate the Message-Authenticator within the given RADIUS Packet"""
message_authenticator = self["Message-Authenticator"] message_authenticator = self["Message-Authenticator"]
if isinstance(list, message_authenticator): if isinstance(list, message_authenticator):
# There are multiple Message Authenticators, but a packet MUST NOT have # There are multiple Message Authenticators, but a packet MUST NOT have
@@ -129,6 +116,7 @@ class Packet(OrderedDict):
return message_authenticator == generated return message_authenticator == generated
def _generate_message_authenticator(self, ma_attr: Attribute): def _generate_message_authenticator(self, ma_attr: Attribute):
"""Calculate the Message-Authenticator for the given Packet"""
assert self.authenticator is not None assert self.authenticator is not None
assert self.request is not None assert self.request is not None
assert self.request.authenticator is not None assert self.request.authenticator is not None
@@ -159,6 +147,7 @@ class Packet(OrderedDict):
return hmac_builder.digest() return hmac_builder.digest()
def add_message_authenticator(self): def add_message_authenticator(self):
"""Add a Message-Authenticator to the RADIUS packet"""
self._encode_packet() self._encode_packet()
self._generate_message_authenticator(self) self._generate_message_authenticator(self)
try: try:
@@ -173,9 +162,17 @@ class Packet(OrderedDict):
self[index:] = generated self[index:] = generated
def refresh_message_authenticator(self): def refresh_message_authenticator(self):
"""Refresh an existing message-Authenticator
This method is equivalent to add_message_authenticator, but
the name can provide more context information.
"""
self.add_message_authenticator() self.add_message_authenticator()
def find_first_attribute(self, attr_type_name: str) -> Attribute: def find_first_attribute(
self, attr_type_name: Union[str, int, Tuple[int, ...]]
) -> Attribute:
"""Find the first attribute with the given name or code"""
for attr in self.ordered_attributes: for attr in self.ordered_attributes:
if attr.type == attr_type_name: if attr.type == attr_type_name:
return attr.type return attr.type
@@ -184,8 +181,15 @@ class Packet(OrderedDict):
def _encode_packet(self): def _encode_packet(self):
self.raw_packet = None self.raw_packet = None
def serialize(self) -> bytes:
"""Serialize the Packet to the RADIUS Line Format"""
# TODO: This is not an abstract method
raise NotImplementedError
class AuthPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS Authentication Packet"""
class AuthPacket(Packet):
def __init__( def __init__(
self, self,
host: Host, host: Host,
@@ -202,9 +206,10 @@ class AuthPacket(Packet):
self.authenticator = token_bytes(16) self.authenticator = token_bytes(16)
def create_accept(self, **attributes): def create_accept(self, **attributes):
"""Create an Access-Accept for a given Access-Request"""
return AuthPacket( return AuthPacket(
self.host, self.host,
self.id, self.identifier,
self.auth_type, self.auth_type,
request=self, request=self,
code=Code.AccessAccept, code=Code.AccessAccept,
@@ -212,9 +217,10 @@ class AuthPacket(Packet):
) )
def create_reject(self, **attributes): def create_reject(self, **attributes):
"""Create an Access-Reject for a given Access-Request"""
return AuthPacket( return AuthPacket(
self.host, self.host,
self.id, self.identifier,
self.auth_type, self.auth_type,
request=self, request=self,
code=Code.AccessReject, code=Code.AccessReject,
@@ -222,9 +228,10 @@ class AuthPacket(Packet):
) )
def create_challange(self, **attributes): def create_challange(self, **attributes):
"""Create an Access-Challange for a given Access-Request"""
return AuthPacket( return AuthPacket(
self.host, self.host,
self.id, self.identifier,
self.auth_type, self.auth_type,
request=self, request=self,
code=Code.AccessChallenge, code=Code.AccessChallenge,
@@ -232,6 +239,7 @@ class AuthPacket(Packet):
) )
def validate_password(self, password: bytes) -> bool: def validate_password(self, password: bytes) -> bool:
"""Validate a password of an Access-Request"""
try: try:
return self.validate_pap(password) return self.validate_pap(password)
except KeyError: except KeyError:
@@ -240,6 +248,7 @@ class AuthPacket(Packet):
return self.validate_chap(password) return self.validate_chap(password)
def validate_pap(self, password: bytes) -> bool: def validate_pap(self, password: bytes) -> bool:
"""Validate a PAP-Password of an Access-Request"""
packet_password = self["User-Password"] packet_password = self["User-Password"]
return validate_pap_password( return validate_pap_password(
self.host.secret, self.host.secret,
@@ -249,6 +258,7 @@ class AuthPacket(Packet):
) )
def validate_chap(self, password: bytes) -> bool: def validate_chap(self, password: bytes) -> bool:
"""Validate the CHAP-Password of an Access-Request"""
packet_password = self["Chap-Password"] packet_password = self["Chap-Password"]
chap_id = packet_password[:1] chap_id = packet_password[:1]
chap_password = packet_password[1:] chap_password = packet_password[1:]
@@ -261,7 +271,9 @@ class AuthPacket(Packet):
) )
class AcctPacket(Packet): class AcctPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS Accounting Packet"""
def __init__( def __init__(
self, self,
host: Host, host: Host,
@@ -274,22 +286,36 @@ class AcctPacket(Packet):
super().__init__(host, code, radius_id, request=request, **attributes) super().__init__(host, code, radius_id, request=request, **attributes)
def create_response(self, **attributes): def create_response(self, **attributes):
"""Create an Accounting-Response to a given Accounting-Request"""
return AcctPacket( return AcctPacket(
self.host, self.host,
self.id, self.identifier,
code=Code.AccountingResponse, code=Code.AccountingResponse,
request=self, request=self,
**attributes **attributes
) )
def increase_acct_delay_time(self, delay_time: float): def increase_acct_delay_time(self, delay_time: float):
"""Increase the Accounting Delay Time
This method automatically adjust the Authenticator
and the Message-Authenticator (if present)
No check for the Accounting Codes are made, an
Accounting-Respones should not contain the
Acct-Delay-Time Attribute.
"""
try: try:
self["Acct-Delay-Time"] += int(delay_time) self["Acct-Delay-Time"] += int(delay_time)
if self.host.message_authenticator:
self.refresh_message_authenticator()
except KeyError: except KeyError:
pass pass
class CoAPacket(Packet): class CoAPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS CoA Packet"""
def __init__( def __init__(
self, self,
host: Host, host: Host,
@@ -302,11 +328,21 @@ class CoAPacket(Packet):
super().__init__(host, code, radius_id, request=request, **attributes) super().__init__(host, code, radius_id, request=request, **attributes)
def create_ack(self, **attributes): def create_ack(self, **attributes):
"""Create a RADIUS Packet of type CoA-Ack"""
return CoAPacket( return CoAPacket(
self.host, self.id, code=Code.CoAACK, request=self, **attributes self.host,
self.identifier,
code=Code.CoAACK,
request=self,
**attributes
) )
def create_nack(self, **attributes): def create_nack(self, **attributes):
"""Create a RADIUS Packet of type CoA-Nack"""
return CoAPacket( return CoAPacket(
self.host, self.id, code=Code.CoANAK, request=self, **attributes self.host,
self.identifier,
code=Code.CoANAK,
request=self,
**attributes
) )

View File

@@ -11,6 +11,7 @@ import secrets
import struct import struct
from pyrad3.dictionary import Dictionary from pyrad3.dictionary import Dictionary
from pyrad3.code import Code
RANDOM_GENERATOR = secrets.SystemRandom() RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5 MD5 = hashlib.md5
@@ -38,6 +39,11 @@ def parse_header(raw_packet: bytes) -> Header:
) )
if length > 4096: if length > 4096:
raise PacketError(f"Packet length is too big ({length})") raise PacketError(f"Packet length is too big ({length})")
try:
Code(header[1])
except ValueError:
PacketError(f"Unknown RADIUS Code {header[1]}")
return Header(*header) return Header(*header)

View File

@@ -0,0 +1,26 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
import struct
from pyrad3 import utils
import pytest
# @pytest.mark.parametrize("header", [
# b""])
# def test_valid_header(header):
# utils.parse_header(header)
@pytest.mark.parametrize(
"header",
[
b"\1\0" + struct.pack("!H", 5000) + 4996 * b"\0",
b"\1\0" + struct.pack("!H", 100),
b"\0\0" + struct.pack("!H", 20) + 16 * b"\0",
b"",
],
)
def test_invalid_header(header):
with pytest.raises(utils.PacketError):
utils.parse_header(header)