fix: circular import for type checking

This commit is contained in:
Istvan Ruzman
2022-02-21 16:14:44 +01:00
parent 1b3cfe8f1c
commit ca67c58ea5

View File

@@ -1,15 +1,16 @@
# Copyright 2020 Istvan Ruzman # Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0 # SPDX-License-Identifier: MIT OR Apache-2.0
from __future__ import annotations
"""Class for RADIUS Packet""" """Class for RADIUS Packet"""
import hashlib import hashlib
import hmac import hmac
import time
from collections import OrderedDict from collections import OrderedDict
from secrets import token_bytes from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
from pyrad3.host import Host
from pyrad3.types import Code from pyrad3.types import Code
from pyrad3.utils import ( from pyrad3.utils import (
Attribute, Attribute,
@@ -21,6 +22,9 @@ from pyrad3.utils import (
validate_pap_password, validate_pap_password,
) )
if TYPE_CHECKING:
from pyrad3.host import Host
HMAC = hmac.new HMAC = hmac.new
@@ -51,7 +55,7 @@ class Packet(OrderedDict):
@staticmethod @staticmethod
def from_raw(host: Host, raw_packet: bytearray) -> "Packet": def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
"""Decode the given bytearray to a RADIUS Packet""" """Decode the given bytearray to a RADIUS Packet"""
(code, radius_id, _length, authenticator) = decode_header(raw_packet) (code, radius_id, _, authenticator) = decode_header(raw_packet)
ordered_attrs = decode_attributes(host.dictionary, raw_packet) ordered_attrs = decode_attributes(host.dictionary, raw_packet)
@@ -70,7 +74,7 @@ class Packet(OrderedDict):
return decoded_packet return decoded_packet
def from_raw_reply(self, raw_packet: bytearray) -> "Packet": def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
"""Parse a bytearray """ """Parse a bytearray"""
self.verify_reply(raw_packet) self.verify_reply(raw_packet)
reply = Packet.from_raw(self.host, raw_packet) reply = Packet.from_raw(self.host, raw_packet)
reply.request = self reply.request = self
@@ -251,9 +255,10 @@ class AuthPacket(Packet): # pylint: disable=abstract-method
def validate_pap(self, password: bytes) -> bool: def validate_pap(self, password: bytes) -> bool:
"""Validate a PAP-Password of an Access-Request""" """Validate a PAP-Password of an Access-Request"""
packet_password = self["User-Password"] packet_password = self["User-Password"]
assert self.authenticator is not None
return validate_pap_password( return validate_pap_password(
self.host.secret, self.host.secret,
self.authenticator, # type: ignore self.authenticator,
packet_password, packet_password,
password, password,
) )
@@ -267,11 +272,12 @@ class AuthPacket(Packet): # pylint: disable=abstract-method
challenge = self["Chap-Challenge"] challenge = self["Chap-Challenge"]
except KeyError: except KeyError:
challenge = self.authenticator challenge = self.authenticator
assert challenge is not None
return validate_chap_password( return validate_chap_password(
chap_id, chap_id,
challenge, challenge,
chap_password, chap_password,
password, # type: ignore password,
) )
@@ -285,8 +291,11 @@ class AcctPacket(Packet): # pylint: disable=abstract-method
*, *,
code: Code = Code.AccountingRequest, code: Code = Code.AccountingRequest,
request: Optional[Packet] = None, request: Optional[Packet] = None,
with_event_time: bool = True,
**attributes, **attributes,
): ):
if with_event_time:
attributes["Event-Timestamp"] = int(time.time())
super().__init__(host, code, radius_id, request=request, **attributes) super().__init__(host, code, radius_id, request=request, **attributes)
def create_response(self, **attributes): def create_response(self, **attributes):