Not so great value decoding code

This commit is contained in:
Istvan Ruzman
2020-09-03 13:32:34 +02:00
parent 9091f2eb2c
commit f508e1b614
9 changed files with 191 additions and 155 deletions

View File

@@ -38,6 +38,7 @@ python = "^3.7"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = { version = "^19.3.10b0", allow-prereleases = true } black = { version = "^19.3.10b0", allow-prereleases = true }
bandit = "^1.6"
pytest = "^5.4" pytest = "^5.4"
pytest-black = "^0.3.10" pytest-black = "^0.3.10"
pytest-cov = "^2.10" pytest-cov = "^2.10"
@@ -55,7 +56,7 @@ include = '\.py'
[tool.isort] [tool.isort]
combine_as_imports = true combine_as_imports = true
include_trailing_comma = true include_trailing_comma = true
line_length = 88 line_length = 80
multi_line_output = 3 multi_line_output = 3
use_parentheses = true use_parentheses = true
@@ -76,6 +77,7 @@ deps =
poetry poetry
commands = commands =
poetry install -v poetry install -v
poetry run bandit -c bandit.yaml -r src/pyrad3
poetry run pytest --black --isort --pylint --pylint-jobs=4 --mypy --flake8 --cov=pyrad3 poetry run pytest --black --isort --pylint --pylint-jobs=4 --mypy --flake8 --cov=pyrad3
""" """

View File

@@ -3,34 +3,10 @@
"""Python RADIUS client code. """Python RADIUS client code.
pyrad is an implementation of a RADIUS client as described in RFC2865. pyrad3 is framework for RADIUS as described in RFC2865.
It takes care of all the details like building RADIUS packets, sending It takes care of all the details like building RADIUS packets, sending
them and decoding responses. them and decoding responses.
Here is an example of doing a authentication request::
import pyrad.packet
from pyrad.client import Client
from pyrad.dictionary import Dictionary
srv = Client(server="radius.my.domain", secret="s3cr3t",
dict = Dictionary("dicts/dictionary", "dictionary.acc"))
req = srv.CreatePacket(code=pyrad.packet.AccessRequest,
User_Name = "wichert", NAS_Identifier="localhost")
req["User-Password"] = req.PwCrypt("password")
reply = srv.SendPacket(req)
if reply.code = =pyrad.packet.AccessAccept:
print "access accepted"
else:
print "access denied"
print "Attributes returned by server:"
for key, value in reply.items():
print f'{key}: {value}')
This package contains four modules: This package contains four modules:
- client: RADIUS client code - client: RADIUS client code

View File

@@ -9,8 +9,8 @@ import time
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
from typing import Optional, Union, cast from typing import Optional, Union, cast
import pyrad3.host as H
import pyrad3.packet as P import pyrad3.packet as P
from pyrad3 import host
from pyrad3.dictionary import Dictionary from pyrad3.dictionary import Dictionary
SUPPORTED_SEND_TYPES = [ SUPPORTED_SEND_TYPES = [
@@ -34,7 +34,7 @@ class UnsupportedPacketType(Exception):
"""Exception for received packets""" """Exception for received packets"""
class Client(host.Host): class Client(H.Host):
"""A simple and extensible RADIUS Client.""" """A simple and extensible RADIUS Client."""
def __init__( def __init__(
@@ -55,6 +55,8 @@ class Client(host.Host):
"""Bind the Address to some socket""" """Bind the Address to some socket"""
self._socket_close() self._socket_close()
self._socket_open() self._socket_open()
# This should be always true, if there is no socket,
assert self._socket is not None
self._socket.bind(addr) self._socket.bind(addr)
def _socket_open(self): def _socket_open(self):
@@ -62,7 +64,7 @@ class Client(host.Host):
if self._socket is not None: if self._socket is not None:
return return
try: try:
family = socket.getaddrinfo(self.server, "www")[0][0] family = socket.getaddrinfo(str(self.server), "www")[0][0]
except socket.gaierror: except socket.gaierror:
family = socket.AF_INET family = socket.AF_INET
self._socket = socket.socket(family, socket.SOCK_DGRAM) self._socket = socket.socket(family, socket.SOCK_DGRAM)
@@ -74,15 +76,14 @@ class Client(host.Host):
self._socket.setsockopt( self._socket.setsockopt(
socket.SOL_SOCKET, socket.SOL_SOCKET,
socket.SO_BINDTODEVICE, socket.SO_BINDTODEVICE,
self.interface, self.interface.encode("utf-8"),
len(self.interface),
) )
self._poll = select.poll() self._poll = select.poll()
self._poll.register(self._socket, select.POLLIN) self._poll.register(self._socket, select.POLLIN)
def _socket_close(self): def _socket_close(self):
"""Close the Client socket""" """Close the Client socket"""
if self._socket is not None: if self._socket is not None and self._poll is not None:
self._poll.unregister(self._socket) self._poll.unregister(self._socket)
self._socket.close() self._socket.close()
self._socket = None self._socket = None

View File

@@ -122,10 +122,10 @@ def dict_parser(
if first_tok == "$INCLUDE": if first_tok == "$INCLUDE":
try: try:
inner_filename = tokens[1] inner_filename = tokens[1]
except IndexError: except IndexError as exc:
raise ParseError( raise ParseError(
filename, "$INCLUDE is missing a filename", line_num, filename, "$INCLUDE is missing a filename", line_num,
) ) from exc
if not isabs(tokens[1]): if not isabs(tokens[1]):
path = dirname(filename) path = dirname(filename)
inner_filename = normpath(join(path, inner_filename)) inner_filename = normpath(join(path, inner_filename))
@@ -269,12 +269,12 @@ class Dictionary:
continuation = True continuation = True
except IndexError: except IndexError:
pass pass
except ValueError: except ValueError as exc:
raise ParseError( raise ParseError(
filename, filename,
f"Syntax error in specification for vendor {vendor_name}", f"Syntax error in specification for vendor {vendor_name}",
line_num, line_num,
) ) from exc
except IndexError: except IndexError:
# no format definition # no format definition
t_len, l_len = 1, 1 t_len, l_len = 1, 1
@@ -301,12 +301,12 @@ class Dictionary:
try: try:
vendor_id = self.vendor_lookup_id_by_name[tokens[1]] vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
self.cur_vendor = self.vendor[vendor_id] self.cur_vendor = self.vendor[vendor_id]
except KeyError: except KeyError as exc:
raise ParseError( raise ParseError(
filename, filename,
f"Unknown vendor {tokens[1]} in begin-vendor statement", f"Unknown vendor {tokens[1]} in begin-vendor statement",
line_num, line_num,
) ) from exc
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int): def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries.""" """Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
@@ -357,12 +357,12 @@ class Dictionary:
if value == "0": if value == "0":
raise ValueError raise ValueError
encrypt = Encrypt(int(value)) # type: ignore encrypt = Encrypt(int(value)) # type: ignore
except (ValueError, TypeError): except (ValueError, TypeError) as exc:
raise ParseError( raise ParseError(
filename, filename,
f"Illegal attribute encryption {value}", f"Illegal attribute encryption {value}",
line_num, line_num,
) ) from exc
else: else:
raise ParseError( raise ParseError(
filename, "Unknown attribute flag {key}", line_num filename, "Unknown attribute flag {key}", line_num
@@ -377,10 +377,10 @@ class Dictionary:
for code in attr_code.split("."): for code in attr_code.split("."):
try: try:
code_num = _parse_number(code) code_num = _parse_number(code)
except ValueError: except ValueError as exc:
raise ParseError( raise ParseError(
filename, f'invalid attribute code {attr_code}""', line_num filename, f'invalid attribute code {attr_code}""', line_num
) ) from exc
if 2 ** (8 * tlength) <= code_num: if 2 ** (8 * tlength) <= code_num:
raise ParseError( raise ParseError(
filename, filename,
@@ -430,8 +430,10 @@ class Dictionary:
base_datatype = datatype.split("[")[0].replace("-", "") base_datatype = datatype.split("[")[0].replace("-", "")
try: try:
attribute_type = Datatype[base_datatype] attribute_type = Datatype[base_datatype]
except KeyError: except KeyError as exc:
raise ParseError(filename, f"Illegal type: {datatype}", line_num) raise ParseError(
filename, f"Illegal type: {datatype}", line_num
) from exc
attribute = Attribute( attribute = Attribute(
name, name,
@@ -457,7 +459,9 @@ class Dictionary:
) )
else: else:
LOG.info("Register Attribute %s", attribute.name) LOG.info("Register Attribute %s", attribute.name)
attrcode = codes[0] if len(codes) == 1 else tuple(codes) attrcode: Union[int, Tuple[int, ...]] = codes[0] if len(
codes
) == 1 else tuple(codes)
self.attrindex[attrcode] = attribute self.attrindex[attrcode] = attribute
self.attrindex[name] = attribute self.attrindex[name] = attribute
@@ -477,19 +481,19 @@ class Dictionary:
# quick and dirty way to make floats values an error # quick and dirty way to make floats values an error
raise ValueError raise ValueError
value = _parse_number(vvalue) value = _parse_number(vvalue)
except ValueError: except ValueError as exc:
raise ParseError( raise ParseError(
filename, f"Invalid number {vvalue} for VALUE {key}", line_num filename, f"Invalid number {vvalue} for VALUE {key}", line_num
) ) from exc
try: try:
attribute = self.attrindex[attr_name] attribute = self.attrindex[attr_name]
except KeyError: except KeyError as exc:
raise ParseError( raise ParseError(
filename, filename,
f"ATTRIBUTE {attr_name} has not been defined yet", f"ATTRIBUTE {attr_name} has not been defined yet",
line_num, line_num,
) ) from exc
try: try:
datatype = str(attribute.datatype).split(".")[1] datatype = str(attribute.datatype).split(".")[1]
lmin, lmax = INTEGER_TYPES[datatype] lmin, lmax = INTEGER_TYPES[datatype]
@@ -499,13 +503,13 @@ class Dictionary:
f"VALUE {key}({value}) is not in the limit of type {datatype}", f"VALUE {key}({value}) is not in the limit of type {datatype}",
line_num, line_num,
) )
except KeyError: except KeyError as exc:
raise ParseError( raise ParseError(
filename, filename,
f"only attributes with integer typed datatypes can have" f"only attributes with integer typed datatypes can have"
f"value definitions {attribute.datatype}", f"value definitions {attribute.datatype}",
line_num, line_num,
) ) from exc
attribute.values[value] = key attribute.values[value] = key
attribute.values[key] = value attribute.values[key] = value

View File

@@ -48,3 +48,6 @@ class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes
def create_coa_packet(self, **kwargs): def create_coa_packet(self, **kwargs):
"""Create an CoA packet (requset per default)""" """Create an CoA packet (requset per default)"""
return packet.CoAPacket(self, **kwargs) return packet.CoAPacket(self, **kwargs)
def _send_packet(self, _packet: packet.Packet):
raise NotImplementedError

View File

@@ -77,9 +77,11 @@ class Packet(OrderedDict):
try: try:
if not reply.validate_message_authenticator(): if not reply.validate_message_authenticator():
raise PacketError("Packet has a wrong message authenticator") raise PacketError("Packet has a wrong message authenticator")
except KeyError: except KeyError as exc:
if "EAP-Message" in reply: if "EAP-Message" in reply:
raise PacketError("Packet is missing a message authenticator") raise PacketError(
"Packet is missing a message authenticator"
) from exc
return reply return reply
def send(self): def send(self):
@@ -147,8 +149,9 @@ class Packet(OrderedDict):
def add_message_authenticator(self): def add_message_authenticator(self):
"""Add a Message-Authenticator to the RADIUS packet""" """Add a Message-Authenticator to the RADIUS packet"""
self["Message-Authenticator"] = b"\x00" * 16
self._encode_packet() self._encode_packet()
self._generate_message_authenticator(self) self._generate_message_authenticator(self["Message-Authenticator"])
try: try:
# quick lookup before we iterate over the whole packet # quick lookup before we iterate over the whole packet
_ = self["Message-Authenticator"] _ = self["Message-Authenticator"]

View File

@@ -12,7 +12,13 @@ from ipaddress import (
ip_address, ip_address,
ip_network, ip_network,
) )
from typing import Optional, Union from typing import Any, Callable, Dict, Union
from pyrad3.dictionary import Datatype
def _from_bytes(value: bytes) -> int:
return int.from_bytes(value, byteorder="big")
def encode_string(string: str) -> bytes: def encode_string(string: str) -> bytes:
@@ -24,14 +30,10 @@ def encode_string(string: str) -> bytes:
return string return string
def encode_octets(string: bytes, explen: Optional[int]) -> bytes: def encode_octets(string: bytes) -> bytes:
"""Encode a RADIUS value of type octet""" """Encode a RADIUS value of type octet"""
if len(string) > 253: if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters") raise ValueError("Can only encode strings of <= 253 characters")
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
return string return string
@@ -171,12 +173,8 @@ def decode_string(string: bytes) -> Union[str, bytes]:
return string return string
def decode_octets(string: bytes, explen: Optional[int] = None) -> bytes: def decode_octets(string: bytes) -> bytes:
"""Decode a RADIUS value of type octet""" """Decode a RADIUS value of type octets"""
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
return string return string
@@ -187,9 +185,9 @@ def decode_ipv4_address(addr: bytes) -> IPv4Address:
def decode_ipv4_prefix(addr: bytes) -> IPv4Network: def decode_ipv4_prefix(addr: bytes) -> IPv4Network:
"""Decode a RADIUS value of type ipv6prefix""" """Decode a RADIUS value of type ipv6prefix"""
prefix = addr[:1] prefix = _from_bytes(addr[:1])
addr = addr[1:] addr = addr[1:]
return IPv4Network((prefix, addr)) return IPv4Network((addr, prefix))
def decode_ipv6_address(addr: bytes) -> IPv6Address: def decode_ipv6_address(addr: bytes) -> IPv6Address:
@@ -201,9 +199,10 @@ def decode_ipv6_address(addr: bytes) -> IPv6Address:
def decode_ipv6_prefix(addr: bytes) -> IPv6Network: def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
"""Decode a RADIUS value of type ipv6prefix""" """Decode a RADIUS value of type ipv6prefix"""
addr = addr + b"\x00" * (18 - len(addr)) addr = addr + b"\x00" * (18 - len(addr))
prefix = addr[:2] # ignoring the reserved field at addr[0]
addr = addr[2:] prefix = _from_bytes(addr[1:2])
return IPv6Network((prefix, addr)) addr = addr[2:] + b"\x00" * (16 - len(addr[2:]))
return IPv6Network((addr, prefix))
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]: def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
@@ -226,55 +225,57 @@ def decode_date(num: bytes) -> int: # TODO: type
return (struct.unpack("!I", num))[0] return (struct.unpack("!I", num))[0]
ENCODE_MAP = { ENCODE_MAP: Dict[Datatype, Callable[[Any], bytes]] = {
"string": encode_string, Datatype.string: encode_string,
"octets": encode_octets, Datatype.octets: encode_octets,
"ipaddr": encode_ipv4_address, Datatype.ipaddr: encode_ipv4_address,
"ipv4prefix": encode_ipv4_prefix, Datatype.ipv4prefix: encode_ipv4_prefix,
"ipv6addr": encode_ipv6_address, Datatype.ipv6addr: encode_ipv6_address,
"ipv6prefix": encode_ipv6_prefix, Datatype.ipv6prefix: encode_ipv6_prefix,
"comboip": encode_combo_ip, Datatype.comboip: encode_combo_ip,
"ifid": lambda value: encode_octets(value, 8), # TODO: length check (8)
"abinary": encode_ascend_binary, Datatype.ifid: encode_octets,
"byte": lambda value: encode_integer(value, "!B"), Datatype.abinary: encode_ascend_binary,
"short": lambda value: encode_integer(value, "!H"), Datatype.byte: lambda value: encode_integer(value, "!B"),
"signed": lambda value: encode_integer(value, "!i"), Datatype.short: lambda value: encode_integer(value, "!H"),
"integer": encode_integer, Datatype.signed: lambda value: encode_integer(value, "!i"),
"integer64": lambda value: encode_integer(value, "!Q"), Datatype.integer: encode_integer,
"date": encode_date, Datatype.integer64: lambda value: encode_integer(value, "!Q"),
Datatype.date: encode_date,
} }
def encode_attr(datatype, value): def encode_attr(datatype: Datatype, value: bytes) -> bytes:
"""Encode a RADIUS attribute""" """Encode a RADIUS attribute"""
try: try:
return ENCODE_MAP[datatype](value) return ENCODE_MAP[datatype](value)
except KeyError: except KeyError as exc:
raise ValueError(f"Unknown attribute type {datatype}") raise ValueError(f"Unknown attribute type {datatype}") from exc
DECODE_MAP = { DECODE_MAP: Dict[Datatype, Callable[[bytes], Any]] = {
"string": decode_string, Datatype.string: decode_string,
"octets": decode_octets, Datatype.octets: decode_octets,
"ipaddr": decode_ipv4_address, Datatype.ipaddr: decode_ipv4_address,
"ipv4prefix": decode_ipv4_prefix, Datatype.ipv4prefix: decode_ipv4_prefix,
"ipv6addr": decode_ipv6_address, Datatype.ipv6addr: decode_ipv6_address,
"ipv6prefix": decode_ipv6_prefix, Datatype.ipv6prefix: decode_ipv6_prefix,
"comboip": decode_combo_ip, Datatype.comboip: decode_combo_ip,
"ifid": lambda value: decode_octets(value, 8), # TODO: length check (8)
"abinary": decode_ascend_binary, Datatype.ifid: decode_octets,
"byte": lambda value: decode_integer(value, "!B"), Datatype.abinary: decode_ascend_binary,
"short": lambda value: decode_integer(value, "!H"), Datatype.byte: lambda value: decode_integer(value, "!B"),
"signed": lambda value: decode_integer(value, "!i"), Datatype.short: lambda value: decode_integer(value, "!H"),
"integer": decode_integer, Datatype.signed: lambda value: decode_integer(value, "!i"),
"integer64": lambda value: decode_integer(value, "!Q"), Datatype.integer: decode_integer,
"date": decode_date, Datatype.integer64: lambda value: decode_integer(value, "!Q"),
Datatype.date: decode_date,
} }
def decode_attr(datatype, value): def decode_attr(datatype: Datatype, value: bytes):
"""Decode a RADIUS attribute""" """Decode a RADIUS attribute"""
try: try:
return DECODE_MAP[datatype](value) return DECODE_MAP[datatype](value)
except KeyError: except KeyError as exc:
raise ValueError(f"Unknown attribute type {datatype}") raise ValueError(f"Unknown attribute type {datatype}") from exc

View File

@@ -10,7 +10,8 @@ from collections import namedtuple
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from pyrad3.code import Code from pyrad3.code import Code
from pyrad3.dictionary import Dictionary from pyrad3.dictionary import Attribute as DictAttr, Datatype, Dictionary
from pyrad3.tools import decode_attr
RANDOM_GENERATOR = secrets.SystemRandom() RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5 MD5 = hashlib.md5
@@ -23,13 +24,15 @@ class PacketError(Exception):
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"]) Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"]) Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
PreParsedAttribute = List[Tuple[int, int, DictAttr]]
def parse_header(raw_packet: bytes) -> Header: def parse_header(raw_packet: bytes) -> Header:
"""Parse the Header of a RADIUS Packet.""" """Parse the Header of a RADIUS Packet."""
try: try:
header = struct.unpack("!BBH16s", raw_packet) header = struct.unpack("!BBH16s", raw_packet)
except struct.error: except struct.error as exc:
raise PacketError("Packet header is corrupt") raise PacketError("Packet header is corrupt") from exc
length = header[3] length = header[3]
if len(raw_packet) != length: if len(raw_packet) != length:
@@ -41,8 +44,8 @@ def parse_header(raw_packet: bytes) -> Header:
try: try:
Code(header[1]) Code(header[1])
except ValueError: except ValueError as exc:
PacketError(f"Unknown RADIUS Code {header[1]}") raise PacketError(f"Unknown RADIUS Code {header[1]}") from exc
return Header(*header) return Header(*header)
@@ -60,8 +63,8 @@ def parse_attributes(
while packet: while packet:
try: try:
(key, length) = struct.unpack("!BB", packet[0:2]) (key, length) = struct.unpack("!BB", packet[0:2])
except struct.error: except struct.error as exc:
raise PacketError("Attribute header is corrupt") raise PacketError("Attribute header is corrupt") from exc
if length < 2: if length < 2:
raise PacketError(f"Attribute length ({length}) is too small") raise PacketError(f"Attribute length ({length}) is too small")
@@ -77,14 +80,26 @@ def parse_attributes(
Attribute( Attribute(
name="Unknown-Attribute", name="Unknown-Attribute",
pos=offset, pos=offset,
type="octets", type=Datatype.octets,
length=int(packet[1]), length=int(packet[1]),
value=packet[2:], value=packet[2:],
) )
) )
else: else:
key = parse_key(rad_dict, key) try:
attributes.extend(parse_value(rad_dict, key, offset, value)) attr_def = rad_dict.attrindex[key]
except KeyError as exc:
# TODO
raise exc
attributes.append(
Attribute(
name=attr_def.name,
pos=offset,
type=attr_def.datatype,
length=length - 2,
value=decode_attr(attr_def.datatype, packet[2:length],),
)
)
packet = packet[length:] packet = packet[length:]
return attributes return attributes
@@ -97,33 +112,78 @@ def parse_vendor_attributes(
if len(vendor_value) < 4: if len(vendor_value) < 4:
raise PacketError raise PacketError
vendor_id = int.from_bytes(vendor_value[:4], "big") vendor_id = int.from_bytes(vendor_value[:4], "big")
vendor_dict = rad_dict.vendor[vendor_id]
vendor_prefix = [26, vendor_id] vendor_prefix = [26, vendor_id]
attributes = [] attributes = []
vendor_tlv = vendor_value[4:] vendor_tlv = vendor_value[4:]
offset += 4
while vendor_tlv: while vendor_tlv:
try: try:
(key, length) = struct.unpack("!BB", vendor_tlv[0:2]) (key, attr_length) = struct.unpack("!BB", vendor_tlv[0:2])
except struct.error: print(attr_length)
attribute = [ keystack = vendor_prefix + [key]
attr_def = rad_dict.attrindex[tuple(keystack)]
except (struct.error, KeyError) as exc:
attributes.append(
Attribute( Attribute(
name="Unknown-Attribute", name="Unknown-Attribute",
pos=offset - len(vendor_value), pos=offset - len(vendor_value),
type="octets", type=Datatype.octets,
length=len(vendor_value) - 4, length=len(vendor_value) - 4,
value=vendor_value, value=vendor_value,
) )
] )
if exc is struct.error:
break
else: else:
offset = offset - len(vendor_tlv) + length if attr_def.datatype == Datatype.tlv:
key = parse_key(rad_dict, tuple(vendor_prefix + [key])) attr_pos = _parse_tlv(rad_dict, vendor_tlv, 2, keystack)
attribute = parse_value(vendor_dict, key, offset, vendor_tlv) else:
attributes.extend(attribute) attr_pos = [(2, attr_length, attr_def)]
vendor_tlv = vendor_tlv[length:]
parsed_attributes: List[Attribute] = []
for local_offset, length, attr_def in attr_pos:
length -= 2
parsed_attributes.append(
Attribute(
name=attr_def.name,
pos=offset + local_offset,
type=attr_def.datatype,
length=length,
value=decode_attr(
attr_def.datatype,
vendor_tlv[local_offset : local_offset + length],
),
)
)
offset = offset + attr_length
attributes.extend(parsed_attributes)
vendor_tlv = vendor_tlv[attr_length:]
return attributes return attributes
def _parse_tlv(
rad_dict: Dictionary, block: bytes, local_offset: int, key_stack: List[int]
) -> PreParsedAttribute:
# get a list of flattened radius attributes
ret = []
while block:
(key, length) = struct.unpack("!BB", block[0:2])
attr_def = rad_dict.attrindex[tuple(key_stack)]
if attr_def.datatype == Datatype.tlv:
key_stack.append(key)
block = block[:length]
ret.extend(
_parse_tlv(rad_dict, block[2:], local_offset + 2, key_stack)
)
key_stack.pop()
else:
ret.append((local_offset + 2, length, attr_def))
local_offset += length
return ret
def parse_key( def parse_key(
rad_dict: Dictionary, key_id: Union[int, Tuple[int, ...]] rad_dict: Dictionary, key_id: Union[int, Tuple[int, ...]]
) -> Union[str, int, Tuple[int, ...]]: ) -> Union[str, int, Tuple[int, ...]]:
@@ -134,18 +194,6 @@ def parse_key(
return key_id return key_id
def parse_value(*_):
"""Parse the Value in the given Key/Dictionary Context"""
raise NotImplementedError
# def parse_value(
# rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
# ) -> List[Attribute]:
# """Parse the Value in the given Key/Dictionary Context"""
# return []
def calculate_authenticator( def calculate_authenticator(
secret: bytes, authenticator: bytes, raw_packet: bytes secret: bytes, authenticator: bytes, raw_packet: bytes
) -> bytes: ) -> bytes:

View File

@@ -64,7 +64,7 @@ def num_tlv(num_type, num, length, expected=None):
num_tlv(b"\x08\x06", 0xFFFF, 4), # rfc signed num_tlv(b"\x08\x06", 0xFFFF, 4), # rfc signed
num_tlv(b"\x08\x06", 0x10000, 4), # rfc signed num_tlv(b"\x08\x06", 0x10000, 4), # rfc signed
num_tlv(b"\x08\x06", 0xFFFFFFFF, 4, -1), # rfc signed num_tlv(b"\x08\x06", 0xFFFFFFFF, 4, -1), # rfc signed
num_tlv(b"\x08\x06", 0x80000000, 4, -268435458), # rfc signed num_tlv(b"\x08\x06", 0x80000000, 4, -2147483648), # rfc signed
num_tlv(b"\x08\x06", 0x7FFFFFFF, 4, 2147483647), # rfc signed num_tlv(b"\x08\x06", 0x7FFFFFFF, 4, 2147483647), # rfc signed
num_tlv(b"\x09\x0A", 0, 8), # rfc integer64 num_tlv(b"\x09\x0A", 0, 8), # rfc integer64
num_tlv(b"\x09\x0A", 0xFF, 8), # rfc integer64 num_tlv(b"\x09\x0A", 0xFF, 8), # rfc integer64
@@ -74,22 +74,18 @@ def num_tlv(num_type, num, length, expected=None):
num_tlv(b"\x09\x0A", 0xFFFFFFFF, 8), # rfc integer64 num_tlv(b"\x09\x0A", 0xFFFFFFFF, 8), # rfc integer64
num_tlv(b"\x09\x0A", 0x100000000, 8), # rfc integer64 num_tlv(b"\x09\x0A", 0x100000000, 8), # rfc integer64
num_tlv(b"\x09\x0A", 0xFFFFFFFFFFFFFFFF, 8), # rfc integer64 num_tlv(b"\x09\x0A", 0xFFFFFFFFFFFFFFFF, 8), # rfc integer64
(b"\x0a\x06\xc0\xa8\x01\x08", IPv4Address("192.168.1.1")), (b"\x0a\x06\xc0\xa8\x01\x08", IPv4Address("192.168.1.8")),
(b"\x0b\x07\x10\xc4\xa8\x00\x00", IPv4Network("192.168.0.0/16")), (b"\x0b\x07\x10\xc0\xa8\x00\x00", IPv4Network("192.168.0.0/16")),
( (
b"\x0c\x12\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", b"\x0c\x12\x20\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
IPv6Address("2003::1"), IPv6Address("2003::1"),
), ),
( (
b"\x0c\x13@\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", b"\x0d\x14\x00\x40\x20\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
IPv6Network("2003::0/64"),
),
(b"\x0c\x04@\x03", IPv6Network("2003::0/64")),
(b"\x0a\x06\xc0\xa8\x01\x08", IPv4Address("192.168.1.1")),
(
b"\x0a\x13@\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
IPv6Network("2003::0/64"), IPv6Network("2003::0/64"),
), ),
(b"\x0c\x04\x20\x03", IPv6Address("2003::0")),
(b"\x0d\x06\x00\x40\x20\x03", IPv6Network("2003::0/64")),
], ],
) )
def test_parse_attribute_rfc_and_vsa(radius_dictionary, attr_bytes, expected): def test_parse_attribute_rfc_and_vsa(radius_dictionary, attr_bytes, expected):
@@ -98,8 +94,10 @@ def test_parse_attribute_rfc_and_vsa(radius_dictionary, attr_bytes, expected):
assert len(attrs) == 1 assert len(attrs) == 1
assert attrs[0].value == expected assert attrs[0].value == expected
vsa_length = (4 + len(attr_bytes)).to_bytes(1, "big") vsa_length = (6 + len(attr_bytes)).to_bytes(1, "big")
raw_packet = bytes(20) + b"\x1a" + vsa_length + "\x04\xd2" + attr_bytes raw_packet = (
bytes(20) + b"\x1a" + vsa_length + b"\x00\x00\x04\xd2" + attr_bytes
)
attrs = utils.parse_attributes(radius_dictionary, raw_packet) attrs = utils.parse_attributes(radius_dictionary, raw_packet)
assert len(attrs) == 1 assert len(attrs) == 1
assert attrs[0].value == expected assert attrs[0].value == expected