From 9d7d40806a35ddc10c202d002e369981bc1cee4a Mon Sep 17 00:00:00 2001 From: Istvan Ruzman Date: Fri, 7 Aug 2020 10:02:34 +0200 Subject: [PATCH] save progress --- src/pyrad3/client.py | 8 ++++--- src/pyrad3/dictionary.py | 28 +++++++++++------------ src/pyrad3/packet.py | 6 +++++ src/pyrad3/tools.py | 14 ++++++++---- src/pyrad3/utils.py | 4 ++-- tests/test_dictionary.py | 48 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 85 insertions(+), 23 deletions(-) diff --git a/src/pyrad3/client.py b/src/pyrad3/client.py index ae924f4..c0cebd3 100644 --- a/src/pyrad3/client.py +++ b/src/pyrad3/client.py @@ -3,7 +3,7 @@ """Implementation of a simple but extensible RADIUS Client""" -from typing import Optional, Union +from typing import cast, Optional, Union from ipaddress import IPv4Address, IPv6Address import select @@ -11,6 +11,7 @@ import socket import time import pyrad3.packet as P +from pyrad3.dictionary import Dictionary from pyrad3 import host SUPPORTED_SEND_TYPES = [ @@ -41,7 +42,7 @@ class Client(host.Host): self, server: Union[str, IPv4Address, IPv6Address], secret: bytes, - radius_dictionary: dict, + radius_dictionary: Dictionary, interface: Optional[str], **kwargs, ): @@ -124,7 +125,8 @@ class Client(host.Host): pass # timed out: try the next attempt after increasing the acct delay time - if packet.code == packet.AccountingRequest: + if packet.code == P.Code.AccountingRequest: + packet = cast(P.AcctPacket, packet) packet.increase_acct_delay_time(self.timeout) raw_packet = packet.serialize() diff --git a/src/pyrad3/dictionary.py b/src/pyrad3/dictionary.py index 54a960d..b3eb23b 100644 --- a/src/pyrad3/dictionary.py +++ b/src/pyrad3/dictionary.py @@ -9,14 +9,14 @@ Classes and Types to parse and represent a RADIUS dictionary. from enum import IntEnum, Enum, auto from dataclasses import dataclass from os.path import dirname, isabs, join, normpath -from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union +from typing import cast, Dict, Generator, IO, List, Optional, Sequence, Tuple, Union import logging LOG = logging.getLogger(__name__) -INTEGER_TYPES = { +INTEGER_TYPES: Dict[str, Tuple[int, int]] = { "byte": (0, 255), "short": (0, 2 ** 16 - 1), "signed": (-(2 ** 31), 2 ** 31 - 1), @@ -88,11 +88,11 @@ class Attribute: # pylint: disable=too-many-instance-attributes name: str code: int datatype: Datatype + values: Dict[Union[int, str], Union[int, str]] has_tag: bool = False encrypt: Encrypt = Encrypt(0) is_sub_attr: bool = False # vendor = Dictionary - values: Dict[Union[int, str], Union[int, str]] = None @dataclass @@ -180,7 +180,7 @@ class Dictionary: def __init__(self, dictionary: str, __dictio: Optional[IO] = None): self.vendor: Dict[int, Vendor] = {} self.vendor_lookup_id_by_name: Dict[str, int] = {} - self.attrindex: Dict[Union[int, str], Attribute] = {} + self.attrindex: Dict[Union[str, int, Tuple[int, ...]], Attribute] = {} self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {}) self.cur_vendor = self.rfc_vendor if __dictio is not None: @@ -193,7 +193,7 @@ class Dictionary: self, reader: Generator[Tuple[int, List[str]], None, None] ): """Read and parse a (Free)RADIUS dictionary.""" - self.filestack = [] + self.filestack: List[str] = [] for line_num, tokens in reader: key = tokens[0] if key == "ATTRIBUTE": @@ -387,9 +387,9 @@ class Dictionary: ) has_tag, encrypt = self._parse_attribute_flags(tokens, line_num) - name, code, datatype = tokens[1:4] + name, attr_code, datatype = tokens[1:4] - if datatype == "concat" and self.cur_vendor != self.rfc_vendor: + if datatype in {"concat", "extended", "evs", "long-extended" } and self.cur_vendor != self.rfc_vendor: raise ParseError( filename, 'vendor attributes are not allowed to have the datatype "concat"', @@ -397,10 +397,10 @@ class Dictionary: ) try: - codes = _parse_attribute_code(code) + codes = _parse_attribute_code(attr_code) except ValueError: raise ParseError( - filename, f'invalid attribute code {code}""', line_num + filename, f'invalid attribute code {attr_code}""', line_num ) for code in codes: @@ -434,17 +434,17 @@ class Dictionary: name, codes[-1], attribute_type, + {}, has_tag, encrypt, len(codes) > 1, - {}, ) - attrcode = codes[0] if len(codes) == 1 else tuple(codes) + attrcode: Union[int, Tuple[int, ...]] = codes[0] if len(codes) == 1 else tuple(codes) self.cur_vendor.attrs[attrcode] = attribute if self.cur_vendor != self.rfc_vendor: - codes = tuple([26] + codes) + codes = [26] + codes attrcode = codes[0] if len(codes) == 1 else tuple(codes) self.attrindex[attrcode] = attribute self.attrindex[name] = attribute @@ -459,8 +459,8 @@ class Dictionary: line_num, ) - (attr_name, key, value) = tokens[1:] - value = _parse_number(value) + (attr_name, key, vvalue) = tokens[1:] + value = _parse_number(vvalue) attribute = self.attrindex[attr_name] try: diff --git a/src/pyrad3/packet.py b/src/pyrad3/packet.py index 76d3764..24329d1 100644 --- a/src/pyrad3/packet.py +++ b/src/pyrad3/packet.py @@ -281,6 +281,12 @@ class AcctPacket(Packet): **attributes ) + def increase_acct_delay_time(self, delay_time: float): + try: + self['Acct-Delay-Time'] += int(delay_time) + except KeyError: + pass + class CoAPacket(Packet): def __init__( diff --git a/src/pyrad3/tools.py b/src/pyrad3/tools.py index 4cd40e4..71df498 100644 --- a/src/pyrad3/tools.py +++ b/src/pyrad3/tools.py @@ -4,7 +4,7 @@ """Collections of functions to en- and decode RADIUS Attributes""" from typing import Union -from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address +from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network, ip_address import struct @@ -30,10 +30,16 @@ def encode_address(addr: Union[str, IPv4Address]) -> bytes: return IPv4Address(addr).packed -def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes: +def encode_network(network: Union[str, IPv4Network]) -> bytes: + """Encode a RADIUS value of type ipv4prefix""" + address = IPv4Network(network) + return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed + + +def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes: """Encode a RADIUS value of type ipv6prefix""" - address = IPv6Network(addr) - return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed + address = IPv6Network(network) + return struct.pack("2B", 0, address.prefixlen) + address.network_address.packed.rstrip(b'\0') def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes: diff --git a/src/pyrad3/utils.py b/src/pyrad3/utils.py index 7ef2395..876c14f 100644 --- a/src/pyrad3/utils.py +++ b/src/pyrad3/utils.py @@ -92,7 +92,7 @@ def parse_vendor_attributes( if len(vendor_value) < 4: raise PacketError vendor_id = int.from_bytes(vendor_value[:4], "big") - vendor_dict = rad_dict.vendor_by_id[vendor_id] + vendor_dict = rad_dict.vendor[vendor_id] vendor_name = vendor_dict.name attributes = [] @@ -122,7 +122,7 @@ def parse_vendor_attributes( def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]: """Parse the key in the Dictionary Context""" try: - return rad_dict.attrs[key_id].name + return rad_dict.attrindex[key_id].name except KeyError: return key_id diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index ac56ec1..5fe58f5 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -188,3 +188,51 @@ def test_value_number_out_of_limit(value_num, attr_type): ) with pytest.raises(ParseError): Dictionary("", dictionary) + + + +@pytest.mark.parametrize("datatype", [ + "string", "octets", "abinary", "byte", "short", + "integer", "signed", "integer64", "ipaddr", + "ipv4prefix", "ipv6addr", "ipv6prefix", "combo-ip", + "ifid", "ether", "concat", "tlv", "extended", + "long-extended", "evs", + ]) +def test_all_datatypes_rfc_space(datatype): + dictionary = StringIO( + f"ATTRIBUTE TEST-ATTRIBUTE 1 {datatype}\n" + ) + Dictionary("", dictionary) + + +@pytest.mark.parametrize("datatype", [ + "string", "octets", "abinary", "byte", "short", + "integer", "signed", "integer64", "ipaddr", + "ipv4prefix", "ipv6addr", "ipv6prefix", "combo-ip", + "ifid", "ether", "tlv", + ]) +def test_valid_datatypes_in_vendor_space(datatype): + dictionary = StringIO( + "VENDOR TEST 1234\n" + "BEGIN-VENDOR TEST\n" + f"ATTRIBUTE TEST-ATTRIBUTE 1 {datatype}\n" + "END-VENDOR TEST\n" + ) + Dictionary("", dictionary) + + +@pytest.mark.parametrize("datatype", [ + "concat", "extended", "long-extended", "evs", + ]) +def test_invalid_datatypes_in_vendor_space(datatype): + dictionary = StringIO( + "VENDOR TEST 1234\n" + "BEGIN-VENDOR TEST\n" + f"ATTRIBUTE TEST-ATTRIBUTE 1 {datatype}\n" + "END-VENDOR TEST\n" + ) + with pytest.raises(ParseError): + Dictionary("", dictionary) + + +