Make the different linting tools happy
This commit is contained in:
@@ -43,4 +43,4 @@ __url__ = "http://pyrad.readthedocs.io/en/latest/?badge=latest"
|
|||||||
__copyright__ = "Copyright 2020 Istvan Ruzman"
|
__copyright__ = "Copyright 2020 Istvan Ruzman"
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
__all__ = ["client", "dictionary", "packet", "server", "tools", "utils"]
|
__all__ = ["client", "code", "dictionary", "packet", "tools", "utils"]
|
||||||
|
|||||||
@@ -125,9 +125,13 @@ class Client(host.Host):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# timed out: try the next attempt after increasing the acct delay time
|
# timed out: try the next attempt after increasing the acct delay time
|
||||||
if packet.code == P.Code.AccountingRequest:
|
try:
|
||||||
|
# Pretend we've got an Acct Packet, we'll fail gracefully if it
|
||||||
|
# isn't an Acct Packet
|
||||||
packet = cast(P.AcctPacket, packet)
|
packet = cast(P.AcctPacket, packet)
|
||||||
packet.increase_acct_delay_time(self.timeout)
|
packet.increase_acct_delay_time(self.timeout)
|
||||||
raw_packet = packet.serialize()
|
raw_packet = packet.serialize()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
raise Timeout
|
raise Timeout
|
||||||
|
|||||||
35
src/pyrad3/code.py
Normal file
35
src/pyrad3/code.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
|
||||||
|
"""Valid RADIUS codes (registered in IANA)
|
||||||
|
|
||||||
|
Currently not all RADIUS codes are contained, because
|
||||||
|
we don't support them (yet).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class Code(IntEnum):
|
||||||
|
"""Valid RADIUS codes (registered in IANA)"""
|
||||||
|
|
||||||
|
AccessRequest = 1
|
||||||
|
AccessAccept = 2
|
||||||
|
AccessReject = 3
|
||||||
|
AccountingRequest = 4
|
||||||
|
AccountingResponse = 5
|
||||||
|
AccountingInterim = 6
|
||||||
|
PasswordRequest = 7
|
||||||
|
PasswordAck = 8
|
||||||
|
PasswordReject = 9
|
||||||
|
AccountingMessage = 10
|
||||||
|
AccessChallenge = 11
|
||||||
|
StatusServer = 12
|
||||||
|
StatusClient = 13
|
||||||
|
DisconnectRequest = 40
|
||||||
|
DisconnectACK = 41
|
||||||
|
DisconnectNAK = 42
|
||||||
|
CoARequest = 43
|
||||||
|
CoAACK = 44
|
||||||
|
CoANAK = 45
|
||||||
@@ -7,7 +7,7 @@ from pyrad3.dictionary import Dictionary
|
|||||||
from pyrad3 import packet
|
from pyrad3 import packet
|
||||||
|
|
||||||
|
|
||||||
class Host: # pylint: disable=too-many-arguments
|
class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes
|
||||||
"""Interface Class for RADIUS Clients and Servers"""
|
"""Interface Class for RADIUS Clients and Servers"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -19,6 +19,7 @@ class Host: # pylint: disable=too-many-arguments
|
|||||||
coaport: int = 3799,
|
coaport: int = 3799,
|
||||||
timeout: float = 30,
|
timeout: float = 30,
|
||||||
retries: int = 3,
|
retries: int = 3,
|
||||||
|
message_authenticator: bool = False,
|
||||||
):
|
):
|
||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.dictionary = radius_dict
|
self.dictionary = radius_dict
|
||||||
@@ -30,6 +31,8 @@ class Host: # pylint: disable=too-many-arguments
|
|||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.retries = retries
|
self.retries = retries
|
||||||
|
|
||||||
|
self.message_authenticator = message_authenticator
|
||||||
|
|
||||||
def create_packet(self, **kwargs):
|
def create_packet(self, **kwargs):
|
||||||
"""Create a generic RADIUS Packet"""
|
"""Create a generic RADIUS Packet"""
|
||||||
return packet.Packet(self, **kwargs)
|
return packet.Packet(self, **kwargs)
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
# Copyright 2020 Istvan Ruzman
|
# Copyright 2020 Istvan Ruzman
|
||||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""Class for RADIUS Packet"""
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import IntEnum
|
|
||||||
from secrets import token_bytes
|
from secrets import token_bytes
|
||||||
from typing import Any, Dict, Optional, Sequence
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
from pyrad3.host import Host
|
from pyrad3.host import Host
|
||||||
|
from pyrad3.code import Code
|
||||||
from pyrad3.utils import (
|
from pyrad3.utils import (
|
||||||
PacketError,
|
PacketError,
|
||||||
Attribute,
|
Attribute,
|
||||||
@@ -23,29 +25,12 @@ from pyrad3.utils import (
|
|||||||
HMAC = hmac.new
|
HMAC = hmac.new
|
||||||
|
|
||||||
|
|
||||||
# Packet codes
|
|
||||||
class Code(IntEnum):
|
|
||||||
AccessRequest = 1
|
|
||||||
AccessAccept = 2
|
|
||||||
AccessReject = 3
|
|
||||||
AccountingRequest = 4
|
|
||||||
AccountingResponse = 5
|
|
||||||
AccessChallenge = 11
|
|
||||||
StatusServer = 12
|
|
||||||
StatusClient = 13
|
|
||||||
DisconnectRequest = 40
|
|
||||||
DisconnectACK = 41
|
|
||||||
DisconnectNAK = 42
|
|
||||||
CoARequest = 43
|
|
||||||
CoAACK = 44
|
|
||||||
CoANAK = 45
|
|
||||||
|
|
||||||
|
|
||||||
class AuthError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Packet(OrderedDict):
|
class Packet(OrderedDict):
|
||||||
|
"""Generic RADIUS Packet
|
||||||
|
|
||||||
|
Usually this should be used as a Super-Class for specific RADIUS Packets.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: Host,
|
host: Host,
|
||||||
@@ -57,7 +42,7 @@ class Packet(OrderedDict):
|
|||||||
):
|
):
|
||||||
super().__init__(**attributes)
|
super().__init__(**attributes)
|
||||||
self.code = code
|
self.code = code
|
||||||
self.id = radius_id
|
self.identifier = radius_id
|
||||||
self.host = host
|
self.host = host
|
||||||
self.request = request
|
self.request = request
|
||||||
self.ordered_attributes: Sequence[Attribute] = []
|
self.ordered_attributes: Sequence[Attribute] = []
|
||||||
@@ -66,6 +51,7 @@ class Packet(OrderedDict):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
|
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
|
||||||
|
"""Parse the given bytearray to a RADIUS Packet"""
|
||||||
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
|
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
|
||||||
|
|
||||||
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
|
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
|
||||||
@@ -85,6 +71,7 @@ class Packet(OrderedDict):
|
|||||||
return parsed_packet
|
return parsed_packet
|
||||||
|
|
||||||
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
|
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
|
||||||
|
"""Parse a bytearray """
|
||||||
self.verify_reply(raw_packet)
|
self.verify_reply(raw_packet)
|
||||||
reply = Packet.from_raw(self.host, raw_packet)
|
reply = Packet.from_raw(self.host, raw_packet)
|
||||||
reply.request = self
|
reply.request = self
|
||||||
@@ -99,12 +86,11 @@ class Packet(OrderedDict):
|
|||||||
def send(self):
|
def send(self):
|
||||||
"""Send the packet to the Client/Server.
|
"""Send the packet to the Client/Server.
|
||||||
"""
|
"""
|
||||||
self.host._send_packet(self)
|
self.host._send_packet(self) # pylint: disable=protected-access
|
||||||
|
|
||||||
def verify_reply(self, raw_reply: bytes):
|
def verify_reply(self, raw_reply: bytes):
|
||||||
"""Verify the reply to this packet.
|
"""Verify the reply to this packet"""
|
||||||
"""
|
if self.identifier != raw_reply[1]:
|
||||||
if self.id != raw_reply[1]:
|
|
||||||
raise PacketError("Response has a wrong id")
|
raise PacketError("Response has a wrong id")
|
||||||
|
|
||||||
# self.authenticator MUST be set, this packet got send so by definitation
|
# self.authenticator MUST be set, this packet got send so by definitation
|
||||||
@@ -119,6 +105,7 @@ class Packet(OrderedDict):
|
|||||||
raise PacketError("Reply Packet has a wrong authenticator")
|
raise PacketError("Reply Packet has a wrong authenticator")
|
||||||
|
|
||||||
def validate_message_authenticator(self):
|
def validate_message_authenticator(self):
|
||||||
|
"""Validate the Message-Authenticator within the given RADIUS Packet"""
|
||||||
message_authenticator = self["Message-Authenticator"]
|
message_authenticator = self["Message-Authenticator"]
|
||||||
if isinstance(list, message_authenticator):
|
if isinstance(list, message_authenticator):
|
||||||
# There are multiple Message Authenticators, but a packet MUST NOT have
|
# There are multiple Message Authenticators, but a packet MUST NOT have
|
||||||
@@ -129,6 +116,7 @@ class Packet(OrderedDict):
|
|||||||
return message_authenticator == generated
|
return message_authenticator == generated
|
||||||
|
|
||||||
def _generate_message_authenticator(self, ma_attr: Attribute):
|
def _generate_message_authenticator(self, ma_attr: Attribute):
|
||||||
|
"""Calculate the Message-Authenticator for the given Packet"""
|
||||||
assert self.authenticator is not None
|
assert self.authenticator is not None
|
||||||
assert self.request is not None
|
assert self.request is not None
|
||||||
assert self.request.authenticator is not None
|
assert self.request.authenticator is not None
|
||||||
@@ -159,6 +147,7 @@ class Packet(OrderedDict):
|
|||||||
return hmac_builder.digest()
|
return hmac_builder.digest()
|
||||||
|
|
||||||
def add_message_authenticator(self):
|
def add_message_authenticator(self):
|
||||||
|
"""Add a Message-Authenticator to the RADIUS packet"""
|
||||||
self._encode_packet()
|
self._encode_packet()
|
||||||
self._generate_message_authenticator(self)
|
self._generate_message_authenticator(self)
|
||||||
try:
|
try:
|
||||||
@@ -173,9 +162,17 @@ class Packet(OrderedDict):
|
|||||||
self[index:] = generated
|
self[index:] = generated
|
||||||
|
|
||||||
def refresh_message_authenticator(self):
|
def refresh_message_authenticator(self):
|
||||||
|
"""Refresh an existing message-Authenticator
|
||||||
|
|
||||||
|
This method is equivalent to add_message_authenticator, but
|
||||||
|
the name can provide more context information.
|
||||||
|
"""
|
||||||
self.add_message_authenticator()
|
self.add_message_authenticator()
|
||||||
|
|
||||||
def find_first_attribute(self, attr_type_name: str) -> Attribute:
|
def find_first_attribute(
|
||||||
|
self, attr_type_name: Union[str, int, Tuple[int, ...]]
|
||||||
|
) -> Attribute:
|
||||||
|
"""Find the first attribute with the given name or code"""
|
||||||
for attr in self.ordered_attributes:
|
for attr in self.ordered_attributes:
|
||||||
if attr.type == attr_type_name:
|
if attr.type == attr_type_name:
|
||||||
return attr.type
|
return attr.type
|
||||||
@@ -184,8 +181,15 @@ class Packet(OrderedDict):
|
|||||||
def _encode_packet(self):
|
def _encode_packet(self):
|
||||||
self.raw_packet = None
|
self.raw_packet = None
|
||||||
|
|
||||||
|
def serialize(self) -> bytes:
|
||||||
|
"""Serialize the Packet to the RADIUS Line Format"""
|
||||||
|
# TODO: This is not an abstract method
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AuthPacket(Packet): # pylint: disable=abstract-method
|
||||||
|
"""Generic RADIUS Authentication Packet"""
|
||||||
|
|
||||||
class AuthPacket(Packet):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: Host,
|
host: Host,
|
||||||
@@ -202,9 +206,10 @@ class AuthPacket(Packet):
|
|||||||
self.authenticator = token_bytes(16)
|
self.authenticator = token_bytes(16)
|
||||||
|
|
||||||
def create_accept(self, **attributes):
|
def create_accept(self, **attributes):
|
||||||
|
"""Create an Access-Accept for a given Access-Request"""
|
||||||
return AuthPacket(
|
return AuthPacket(
|
||||||
self.host,
|
self.host,
|
||||||
self.id,
|
self.identifier,
|
||||||
self.auth_type,
|
self.auth_type,
|
||||||
request=self,
|
request=self,
|
||||||
code=Code.AccessAccept,
|
code=Code.AccessAccept,
|
||||||
@@ -212,9 +217,10 @@ class AuthPacket(Packet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_reject(self, **attributes):
|
def create_reject(self, **attributes):
|
||||||
|
"""Create an Access-Reject for a given Access-Request"""
|
||||||
return AuthPacket(
|
return AuthPacket(
|
||||||
self.host,
|
self.host,
|
||||||
self.id,
|
self.identifier,
|
||||||
self.auth_type,
|
self.auth_type,
|
||||||
request=self,
|
request=self,
|
||||||
code=Code.AccessReject,
|
code=Code.AccessReject,
|
||||||
@@ -222,9 +228,10 @@ class AuthPacket(Packet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_challange(self, **attributes):
|
def create_challange(self, **attributes):
|
||||||
|
"""Create an Access-Challange for a given Access-Request"""
|
||||||
return AuthPacket(
|
return AuthPacket(
|
||||||
self.host,
|
self.host,
|
||||||
self.id,
|
self.identifier,
|
||||||
self.auth_type,
|
self.auth_type,
|
||||||
request=self,
|
request=self,
|
||||||
code=Code.AccessChallenge,
|
code=Code.AccessChallenge,
|
||||||
@@ -232,6 +239,7 @@ class AuthPacket(Packet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def validate_password(self, password: bytes) -> bool:
|
def validate_password(self, password: bytes) -> bool:
|
||||||
|
"""Validate a password of an Access-Request"""
|
||||||
try:
|
try:
|
||||||
return self.validate_pap(password)
|
return self.validate_pap(password)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -240,6 +248,7 @@ class AuthPacket(Packet):
|
|||||||
return self.validate_chap(password)
|
return self.validate_chap(password)
|
||||||
|
|
||||||
def validate_pap(self, password: bytes) -> bool:
|
def validate_pap(self, password: bytes) -> bool:
|
||||||
|
"""Validate a PAP-Password of an Access-Request"""
|
||||||
packet_password = self["User-Password"]
|
packet_password = self["User-Password"]
|
||||||
return validate_pap_password(
|
return validate_pap_password(
|
||||||
self.host.secret,
|
self.host.secret,
|
||||||
@@ -249,6 +258,7 @@ class AuthPacket(Packet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def validate_chap(self, password: bytes) -> bool:
|
def validate_chap(self, password: bytes) -> bool:
|
||||||
|
"""Validate the CHAP-Password of an Access-Request"""
|
||||||
packet_password = self["Chap-Password"]
|
packet_password = self["Chap-Password"]
|
||||||
chap_id = packet_password[:1]
|
chap_id = packet_password[:1]
|
||||||
chap_password = packet_password[1:]
|
chap_password = packet_password[1:]
|
||||||
@@ -261,7 +271,9 @@ class AuthPacket(Packet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AcctPacket(Packet):
|
class AcctPacket(Packet): # pylint: disable=abstract-method
|
||||||
|
"""Generic RADIUS Accounting Packet"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: Host,
|
host: Host,
|
||||||
@@ -274,22 +286,36 @@ class AcctPacket(Packet):
|
|||||||
super().__init__(host, code, radius_id, request=request, **attributes)
|
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||||
|
|
||||||
def create_response(self, **attributes):
|
def create_response(self, **attributes):
|
||||||
|
"""Create an Accounting-Response to a given Accounting-Request"""
|
||||||
return AcctPacket(
|
return AcctPacket(
|
||||||
self.host,
|
self.host,
|
||||||
self.id,
|
self.identifier,
|
||||||
code=Code.AccountingResponse,
|
code=Code.AccountingResponse,
|
||||||
request=self,
|
request=self,
|
||||||
**attributes
|
**attributes
|
||||||
)
|
)
|
||||||
|
|
||||||
def increase_acct_delay_time(self, delay_time: float):
|
def increase_acct_delay_time(self, delay_time: float):
|
||||||
|
"""Increase the Accounting Delay Time
|
||||||
|
|
||||||
|
This method automatically adjust the Authenticator
|
||||||
|
and the Message-Authenticator (if present)
|
||||||
|
|
||||||
|
No check for the Accounting Codes are made, an
|
||||||
|
Accounting-Respones should not contain the
|
||||||
|
Acct-Delay-Time Attribute.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
self["Acct-Delay-Time"] += int(delay_time)
|
self["Acct-Delay-Time"] += int(delay_time)
|
||||||
|
if self.host.message_authenticator:
|
||||||
|
self.refresh_message_authenticator()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CoAPacket(Packet):
|
class CoAPacket(Packet): # pylint: disable=abstract-method
|
||||||
|
"""Generic RADIUS CoA Packet"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: Host,
|
host: Host,
|
||||||
@@ -302,11 +328,21 @@ class CoAPacket(Packet):
|
|||||||
super().__init__(host, code, radius_id, request=request, **attributes)
|
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||||
|
|
||||||
def create_ack(self, **attributes):
|
def create_ack(self, **attributes):
|
||||||
|
"""Create a RADIUS Packet of type CoA-Ack"""
|
||||||
return CoAPacket(
|
return CoAPacket(
|
||||||
self.host, self.id, code=Code.CoAACK, request=self, **attributes
|
self.host,
|
||||||
|
self.identifier,
|
||||||
|
code=Code.CoAACK,
|
||||||
|
request=self,
|
||||||
|
**attributes
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_nack(self, **attributes):
|
def create_nack(self, **attributes):
|
||||||
|
"""Create a RADIUS Packet of type CoA-Nack"""
|
||||||
return CoAPacket(
|
return CoAPacket(
|
||||||
self.host, self.id, code=Code.CoANAK, request=self, **attributes
|
self.host,
|
||||||
|
self.identifier,
|
||||||
|
code=Code.CoANAK,
|
||||||
|
request=self,
|
||||||
|
**attributes
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import secrets
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
from pyrad3.dictionary import Dictionary
|
from pyrad3.dictionary import Dictionary
|
||||||
|
from pyrad3.code import Code
|
||||||
|
|
||||||
RANDOM_GENERATOR = secrets.SystemRandom()
|
RANDOM_GENERATOR = secrets.SystemRandom()
|
||||||
MD5 = hashlib.md5
|
MD5 = hashlib.md5
|
||||||
@@ -38,6 +39,11 @@ def parse_header(raw_packet: bytes) -> Header:
|
|||||||
)
|
)
|
||||||
if length > 4096:
|
if length > 4096:
|
||||||
raise PacketError(f"Packet length is too big ({length})")
|
raise PacketError(f"Packet length is too big ({length})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
Code(header[1])
|
||||||
|
except ValueError:
|
||||||
|
PacketError(f"Unknown RADIUS Code {header[1]}")
|
||||||
return Header(*header)
|
return Header(*header)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
26
tests/test_parse_header.py
Normal file
26
tests/test_parse_header.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from pyrad3 import utils
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize("header", [
|
||||||
|
# b""])
|
||||||
|
# def test_valid_header(header):
|
||||||
|
# utils.parse_header(header)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"header",
|
||||||
|
[
|
||||||
|
b"\1\0" + struct.pack("!H", 5000) + 4996 * b"\0",
|
||||||
|
b"\1\0" + struct.pack("!H", 100),
|
||||||
|
b"\0\0" + struct.pack("!H", 20) + 16 * b"\0",
|
||||||
|
b"",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_invalid_header(header):
|
||||||
|
with pytest.raises(utils.PacketError):
|
||||||
|
utils.parse_header(header)
|
||||||
Reference in New Issue
Block a user