diff --git a/src/pyrad3/utils.py b/src/pyrad3/utils.py index 53d7af4..86961f2 100644 --- a/src/pyrad3/utils.py +++ b/src/pyrad3/utils.py @@ -30,8 +30,8 @@ PreParsedAttributes = List[Tuple[Tuple[int, ...], bytes, int]] SpecialTlvDescription = Tuple[Tuple[int, ...], bytes, int] -def parse_header(raw_packet: bytes) -> Header: - """Parse the Header of a RADIUS Packet.""" +def decode_header(raw_packet: bytes) -> Header: + """Decode the Header of a RADIUS Packet.""" try: header = struct.unpack("!BBH16s", raw_packet) except struct.error as exc: @@ -52,18 +52,18 @@ def parse_header(raw_packet: bytes) -> Header: return Header(*header) -def parse_attributes( +def decode_attributes( rad_dict: Dictionary, raw_packet: bytes ) -> List[Attribute]: - """Parse the Attributes in a RADIUS Packet. + """Decode the Attributes in a RADIUS Packet. - This function skips the Header. The Header must be parsed and verified + This function skips the Header. The Header must be decoded and verified separately. """ attributes = [] packet = raw_packet[20:] # Skip RADIUS Header - for key, value, offset in pre_parse_attributes(rad_dict, packet): + for key, value, offset in pre_decode_attributes(rad_dict, packet): attr_def = rad_dict.attrindex.get(key) length = len(value) dec_value: Any = value # to silence mypy @@ -97,7 +97,7 @@ def parse_attributes( return attributes -def pre_parse_attributes( # pylint: disable=too-many-branches +def pre_decode_attributes( # pylint: disable=too-many-branches rad_dict: Dictionary, packet_body: bytes ) -> PreParsedAttributes: """Find Attributes location and keystack""" @@ -183,14 +183,18 @@ def decode_extended( key: int, value: bytes, offset: int ) -> SpecialTlvDescription: """Decode an Attribute of type extended""" - raise NotImplementedError + key = (key, value[0]) + value = value[1:length-2] + return (key, value, 3) def decode_longextended( key: int, value: bytes, offset: int ) -> SpecialTlvDescription: """Decode an Attribute of type long-extended""" - raise NotImplementedError + key = (key, value[0]) + value = value[2:length-3] + return (key, value, 4) def decode_concat(key: int, value: bytes, offset: int) -> SpecialTlvDescription: diff --git a/tests/test_parse_header.py b/tests/test_parse_header.py index 867b868..5549546 100644 --- a/tests/test_parse_header.py +++ b/tests/test_parse_header.py @@ -131,46 +131,46 @@ def radius_dictionary(): ) def test_invalid_header(header): with pytest.raises(utils.PacketError): - utils.parse_header(header) + utils.decode_header(header) @pytest.mark.parametrize("attr_bytes, expected", TEST_ATTRIBUTES) -def test_parse_attribute_rfc(radius_dictionary, attr_bytes, expected): +def test_decode_attribute_rfc(radius_dictionary, attr_bytes, expected): raw_packet = bytes(20) + attr_bytes - attrs = utils.parse_attributes(radius_dictionary, raw_packet) + attrs = utils.decode_attributes(radius_dictionary, raw_packet) assert len(attrs) == 1 assert attrs[0].value == expected assert attrs[0].tag == 0 @pytest.mark.parametrize("attr_bytes, expected", TEST_ATTRIBUTES) -def test_parse_attribute_vsa(radius_dictionary, attr_bytes, expected): +def test_decode_attribute_vsa(radius_dictionary, attr_bytes, expected): vsa_length = (6 + len(attr_bytes)).to_bytes(1, "big") raw_packet = ( bytes(20) + b"\x1a" + vsa_length + b"\x00\x00\x04\xd2" + attr_bytes ) - attrs = utils.parse_attributes(radius_dictionary, raw_packet) + attrs = utils.decode_attributes(radius_dictionary, raw_packet) assert len(attrs) == 1 assert attrs[0].value == expected assert attrs[0].tag == 0 @pytest.mark.parametrize("attr_bytes, expected", TAGGED_ATTRIBUTES) -def test_parse_attribute_rfc_tagged(radius_dictionary, attr_bytes, expected): +def test_decode_attribute_rfc_tagged(radius_dictionary, attr_bytes, expected): raw_packet = bytes(20) + attr_bytes - attrs = utils.parse_attributes(radius_dictionary, raw_packet) + attrs = utils.decode_attributes(radius_dictionary, raw_packet) assert len(attrs) == 1 assert attrs[0].value == expected assert attrs[0].tag == 1 @pytest.mark.parametrize("attr_bytes, expected", TAGGED_ATTRIBUTES) -def test_parse_attribute_vsa_tagged(radius_dictionary, attr_bytes, expected): +def test_decode_attribute_vsa_tagged(radius_dictionary, attr_bytes, expected): vsa_length = (6 + len(attr_bytes)).to_bytes(1, "big") raw_packet = ( bytes(20) + b"\x1a" + vsa_length + b"\x00\x00\x04\xd2" + attr_bytes ) - attrs = utils.parse_attributes(radius_dictionary, raw_packet) + attrs = utils.decode_attributes(radius_dictionary, raw_packet) assert len(attrs) == 1 assert attrs[0].value == expected assert attrs[0].tag == 1