Make the different linting tools happy

This commit is contained in:
Istvan Ruzman
2020-08-12 10:15:14 +02:00
parent 6d8dd18601
commit d7ff0be1aa
8 changed files with 154 additions and 42 deletions

View File

@@ -43,4 +43,4 @@ __url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest"
__copyright__ = "Copyright 2020 Istvan Ruzman"
__version__ = "0.1.0"
__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"]
__all__ = ["client", "code", "dictionary", "packet", "tools", "utils"]

View File

@@ -125,9 +125,13 @@ class Client(host.Host):
pass
# timed out: try the next attempt after increasing the acct delay time
if packet.code == P.Code.AccountingRequest:
try:
# Pretend we've got an Acct Packet, we'll fail gracefully if it
# isn't an Acct Packet
packet = cast(P.AcctPacket, packet)
packet.increase_acct_delay_time(self.timeout)
raw_packet = packet.serialize()
except AttributeError:
pass
raise Timeout

35
src/pyrad3/code.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Valid RADIUS codes (registered in IANA)
Currently not all RADIUS codes are contained, because
we don't support them (yet).
"""
from enum import IntEnum
class Code(IntEnum):
"""Valid RADIUS codes (registered in IANA)"""
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccountingInterim = 6
PasswordRequest = 7
PasswordAck = 8
PasswordReject = 9
AccountingMessage = 10
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45

View File

@@ -7,7 +7,7 @@ from pyrad3.dictionary import Dictionary
from pyrad3 import packet
class Host: # pylint: disable=too-many-arguments
class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes
"""Interface Class for RADIUS Clients and Servers"""
def __init__(
@@ -19,6 +19,7 @@ class Host: # pylint: disable=too-many-arguments
coaport: int = 3799,
timeout: float = 30,
retries: int = 3,
message_authenticator: bool = False,
):
self.secret = secret
self.dictionary = radius_dict
@@ -30,6 +31,8 @@ class Host: # pylint: disable=too-many-arguments
self.timeout = timeout
self.retries = retries
self.message_authenticator = message_authenticator
def create_packet(self, **kwargs):
"""Create a generic RADIUS Packet"""
return packet.Packet(self, **kwargs)

View File

@@ -1,15 +1,17 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Class for RADIUS Packet"""
from collections import OrderedDict
from enum import IntEnum
from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import hashlib
import hmac
from pyrad3.host import Host
from pyrad3.code import Code
from pyrad3.utils import (
PacketError,
Attribute,
@@ -23,29 +25,12 @@ from pyrad3.utils import (
HMAC = hmac.new
# Packet codes
class Code(IntEnum):
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
class AuthError(Exception):
pass
class Packet(OrderedDict):
"""Generic RADIUS Packet
Usually this should be used as a Super-Class for specific RADIUS Packets.
"""
def __init__(
self,
host: Host,
@@ -57,7 +42,7 @@ class Packet(OrderedDict):
):
super().__init__(**attributes)
self.code = code
self.id = radius_id
self.identifier = radius_id
self.host = host
self.request = request
self.ordered_attributes: Sequence[Attribute] = []
@@ -66,6 +51,7 @@ class Packet(OrderedDict):
@staticmethod
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
"""Parse the given bytearray to a RADIUS Packet"""
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
@@ -85,6 +71,7 @@ class Packet(OrderedDict):
return parsed_packet
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
"""Parse a bytearray """
self.verify_reply(raw_packet)
reply = Packet.from_raw(self.host, raw_packet)
reply.request = self
@@ -99,12 +86,11 @@ class Packet(OrderedDict):
def send(self):
"""Send the packet to the Client/Server.
"""
self.host._send_packet(self)
self.host._send_packet(self) # pylint: disable=protected-access
def verify_reply(self, raw_reply: bytes):
"""Verify the reply to this packet.
"""
if self.id != raw_reply[1]:
"""Verify the reply to this packet"""
if self.identifier != raw_reply[1]:
raise PacketError("Response has a wrong id")
# self.authenticator MUST be set, this packet got send so by definitation
@@ -119,6 +105,7 @@ class Packet(OrderedDict):
raise PacketError("Reply Packet has a wrong authenticator")
def validate_message_authenticator(self):
"""Validate the Message-Authenticator within the given RADIUS Packet"""
message_authenticator = self["Message-Authenticator"]
if isinstance(list, message_authenticator):
# There are multiple Message Authenticators, but a packet MUST NOT have
@@ -129,6 +116,7 @@ class Packet(OrderedDict):
return message_authenticator == generated
def _generate_message_authenticator(self, ma_attr: Attribute):
"""Calculate the Message-Authenticator for the given Packet"""
assert self.authenticator is not None
assert self.request is not None
assert self.request.authenticator is not None
@@ -159,6 +147,7 @@ class Packet(OrderedDict):
return hmac_builder.digest()
def add_message_authenticator(self):
"""Add a Message-Authenticator to the RADIUS packet"""
self._encode_packet()
self._generate_message_authenticator(self)
try:
@@ -173,9 +162,17 @@ class Packet(OrderedDict):
self[index:] = generated
def refresh_message_authenticator(self):
"""Refresh an existing message-Authenticator
This method is equivalent to add_message_authenticator, but
the name can provide more context information.
"""
self.add_message_authenticator()
def find_first_attribute(self, attr_type_name: str) -> Attribute:
def find_first_attribute(
self, attr_type_name: Union[str, int, Tuple[int, ...]]
) -> Attribute:
"""Find the first attribute with the given name or code"""
for attr in self.ordered_attributes:
if attr.type == attr_type_name:
return attr.type
@@ -184,8 +181,15 @@ class Packet(OrderedDict):
def _encode_packet(self):
self.raw_packet = None
def serialize(self) -> bytes:
"""Serialize the Packet to the RADIUS Line Format"""
# TODO: This is not an abstract method
raise NotImplementedError
class AuthPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS Authentication Packet"""
class AuthPacket(Packet):
def __init__(
self,
host: Host,
@@ -202,9 +206,10 @@ class AuthPacket(Packet):
self.authenticator = token_bytes(16)
def create_accept(self, **attributes):
"""Create an Access-Accept for a given Access-Request"""
return AuthPacket(
self.host,
self.id,
self.identifier,
self.auth_type,
request=self,
code=Code.AccessAccept,
@@ -212,9 +217,10 @@ class AuthPacket(Packet):
)
def create_reject(self, **attributes):
"""Create an Access-Reject for a given Access-Request"""
return AuthPacket(
self.host,
self.id,
self.identifier,
self.auth_type,
request=self,
code=Code.AccessReject,
@@ -222,9 +228,10 @@ class AuthPacket(Packet):
)
def create_challange(self, **attributes):
"""Create an Access-Challange for a given Access-Request"""
return AuthPacket(
self.host,
self.id,
self.identifier,
self.auth_type,
request=self,
code=Code.AccessChallenge,
@@ -232,6 +239,7 @@ class AuthPacket(Packet):
)
def validate_password(self, password: bytes) -> bool:
"""Validate a password of an Access-Request"""
try:
return self.validate_pap(password)
except KeyError:
@@ -240,6 +248,7 @@ class AuthPacket(Packet):
return self.validate_chap(password)
def validate_pap(self, password: bytes) -> bool:
"""Validate a PAP-Password of an Access-Request"""
packet_password = self["User-Password"]
return validate_pap_password(
self.host.secret,
@@ -249,6 +258,7 @@ class AuthPacket(Packet):
)
def validate_chap(self, password: bytes) -> bool:
"""Validate the CHAP-Password of an Access-Request"""
packet_password = self["Chap-Password"]
chap_id = packet_password[:1]
chap_password = packet_password[1:]
@@ -261,7 +271,9 @@ class AuthPacket(Packet):
)
class AcctPacket(Packet):
class AcctPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS Accounting Packet"""
def __init__(
self,
host: Host,
@@ -274,22 +286,36 @@ class AcctPacket(Packet):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_response(self, **attributes):
"""Create an Accounting-Response to a given Accounting-Request"""
return AcctPacket(
self.host,
self.id,
self.identifier,
code=Code.AccountingResponse,
request=self,
**attributes
)
def increase_acct_delay_time(self, delay_time: float):
"""Increase the Accounting Delay Time
This method automatically adjust the Authenticator
and the Message-Authenticator (if present)
No check for the Accounting Codes are made, an
Accounting-Respones should not contain the
Acct-Delay-Time Attribute.
"""
try:
self["Acct-Delay-Time"] += int(delay_time)
if self.host.message_authenticator:
self.refresh_message_authenticator()
except KeyError:
pass
class CoAPacket(Packet):
class CoAPacket(Packet): # pylint: disable=abstract-method
"""Generic RADIUS CoA Packet"""
def __init__(
self,
host: Host,
@@ -302,11 +328,21 @@ class CoAPacket(Packet):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_ack(self, **attributes):
"""Create a RADIUS Packet of type CoA-Ack"""
return CoAPacket(
self.host, self.id, code=Code.CoAACK, request=self, **attributes
self.host,
self.identifier,
code=Code.CoAACK,
request=self,
**attributes
)
def create_nack(self, **attributes):
"""Create a RADIUS Packet of type CoA-Nack"""
return CoAPacket(
self.host, self.id, code=Code.CoANAK, request=self, **attributes
self.host,
self.identifier,
code=Code.CoANAK,
request=self,
**attributes
)

View File

@@ -11,6 +11,7 @@ import secrets
import struct
from pyrad3.dictionary import Dictionary
from pyrad3.code import Code
RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5
@@ -38,6 +39,11 @@ def parse_header(raw_packet: bytes) -> Header:
)
if length > 4096:
raise PacketError(f"Packet length is too big ({length})")
try:
Code(header[1])
except ValueError:
PacketError(f"Unknown RADIUS Code {header[1]}")
return Header(*header)