From fd16436c3c8373ff6f912f2330d1ee8b419c31ec Mon Sep 17 00:00:00 2001 From: Istvan Ruzman Date: Thu, 6 Aug 2020 18:04:24 +0200 Subject: [PATCH] safe progress --- .flake8 | 1 - default.nix | 7 +- example/acct.py | 54 +- example/auth.py | 45 +- example/auth_async.py | 91 +- example/client-coa.py | 30 +- example/coa.py | 26 +- example/server.py | 48 +- example/server_async.py | 74 +- example/status.py | 27 +- pyproject.toml | 15 +- pyrad3/client.py | 221 ----- pyrad3/curved.py | 81 -- pyrad3/dictfile.py | 116 --- pyrad3/dictionary.py | 393 --------- pyrad3/host.py | 100 --- pyrad3/new_host.py | 29 - pyrad3/new_packet.py | 143 --- pyrad3/packet.py | 869 ------------------- pyrad3/tests/__init__.py | 5 - pyrad3/tests/data/chap | 6 - pyrad3/tests/data/dictfiletest | 5 - pyrad3/tests/data/full | 34 - pyrad3/tests/data/simple | 16 - pyrad3/tests/data/tunnelPassword | 3 - pyrad3/tests/mock.py | 141 --- pyrad3/tests/testBidict.py | 56 -- pyrad3/tests/testClient.py | 184 ---- pyrad3/tests/testDictionary.py | 332 ------- pyrad3/tests/testHost.py | 87 -- pyrad3/tests/testPacket.py | 535 ------------ pyrad3/tests/testProxy.py | 99 --- pyrad3/tests/testServer.py | 329 ------- pyrad3/tests/testTools.py | 119 --- pyrad3/tools.py | 226 ----- pyrad3/utils.py | 57 -- shell.nix | 19 +- {pyrad3 => src/pyrad3}/__init__.py | 8 +- {pyrad3 => src/pyrad3}/bidict.py | 2 +- pyrad3/new_client.py => src/pyrad3/client.py | 69 +- {pyrad3 => src/pyrad3}/client_async.py | 204 +++-- src/pyrad3/dictionary.py | 483 +++++++++++ src/pyrad3/host.py | 46 + src/pyrad3/packet.py | 305 +++++++ {pyrad3 => src/pyrad3}/proxy.py | 17 +- {pyrad3 => src/pyrad3}/server.py | 67 +- {pyrad3 => src/pyrad3}/server_async.py | 262 ++++-- src/pyrad3/tools.py | 243 ++++++ src/pyrad3/utils.py | 234 +++++ tests/dictionaries/mutual_recursive | 1 + tests/dictionaries/other_mutual_recursive | 1 + tests/dictionaries/self_recursive | 1 + tests/test_dictionary.py | 190 ++++ 53 files changed, 2167 insertions(+), 4589 deletions(-) delete mode 100644 pyrad3/client.py delete mode 100644 pyrad3/curved.py delete mode 100644 pyrad3/dictfile.py delete mode 100644 pyrad3/dictionary.py delete mode 100644 pyrad3/host.py delete mode 100644 pyrad3/new_host.py delete mode 100644 pyrad3/new_packet.py delete mode 100644 pyrad3/packet.py delete mode 100644 pyrad3/tests/__init__.py delete mode 100644 pyrad3/tests/data/chap delete mode 100644 pyrad3/tests/data/dictfiletest delete mode 100644 pyrad3/tests/data/full delete mode 100644 pyrad3/tests/data/simple delete mode 100644 pyrad3/tests/data/tunnelPassword delete mode 100644 pyrad3/tests/mock.py delete mode 100644 pyrad3/tests/testBidict.py delete mode 100644 pyrad3/tests/testClient.py delete mode 100644 pyrad3/tests/testDictionary.py delete mode 100644 pyrad3/tests/testHost.py delete mode 100644 pyrad3/tests/testPacket.py delete mode 100644 pyrad3/tests/testProxy.py delete mode 100644 pyrad3/tests/testServer.py delete mode 100644 pyrad3/tests/testTools.py delete mode 100644 pyrad3/tools.py delete mode 100644 pyrad3/utils.py rename {pyrad3 => src/pyrad3}/__init__.py (84%) rename {pyrad3 => src/pyrad3}/bidict.py (98%) rename pyrad3/new_client.py => src/pyrad3/client.py (55%) rename {pyrad3 => src/pyrad3}/client_async.py (68%) create mode 100644 src/pyrad3/dictionary.py create mode 100644 src/pyrad3/host.py create mode 100644 src/pyrad3/packet.py rename {pyrad3 => src/pyrad3}/proxy.py (84%) rename {pyrad3 => src/pyrad3}/server.py (86%) rename {pyrad3 => src/pyrad3}/server_async.py (53%) create mode 100644 src/pyrad3/tools.py create mode 100644 src/pyrad3/utils.py create mode 100644 tests/dictionaries/mutual_recursive create mode 100644 tests/dictionaries/other_mutual_recursive create mode 100644 tests/dictionaries/self_recursive create mode 100644 tests/test_dictionary.py diff --git a/.flake8 b/.flake8 index 4c2af06..7da1f96 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,2 @@ [flake8] -max-complexity = 10 max-line-length = 100 diff --git a/default.nix b/default.nix index 29bd1d5..ad0f35a 100644 --- a/default.nix +++ b/default.nix @@ -2,10 +2,13 @@ python.pkgs.buildPythonPackage rec { pname = "pyrad"; - version = "3.0-alpha"; + version = "1.0-alpha"; - buildInputs = with python.pkgs; [ netaddr six ]; + buildInputs = with python.pkgs; [ ]; checkInputs = with python.pkgs; [ + black + pytest + # pytest-cov ]; } diff --git a/example/acct.py b/example/acct.py index 5a22622..d68601f 100755 --- a/example/acct.py +++ b/example/acct.py @@ -15,41 +15,47 @@ def send_accounting_packet(srv, req): try: srv.SendPacket(req) except pyrad.client.Timeout: - print('RADIUS server does not reply') + print("RADIUS server does not reply") sys.exit(1) except socket.error as error: - print('Network error: ' + error[1]) + print("Network error: " + error[1]) sys.exit(1) def main(path_to_dictionary): - srv = Client(server='127.0.0.1', - secret=b'Kah3choteereethiejeimaeziecumi', - dict=Dictionary(path_to_dictionary)) + srv = Client( + server="127.0.0.1", + secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary(path_to_dictionary), + ) - req = srv.CreateAcctPacket(**{ - 'User-Name': 'wichert', - 'NAS-IP-Address': '192.168.1.10', - 'NAS-Port': 0, - 'NAS-Identifier': 'trillian', - 'Called-Station-Id': '00-04-5F-00-0F-D1', - 'Calling-Station-Id': '00-01-24-80-B3-9C', - 'Framed-IP-Address': '10.0.0.100', - }) + req = srv.CreateAcctPacket( + **{ + "User-Name": "wichert", + "NAS-IP-Address": "192.168.1.10", + "NAS-Port": 0, + "NAS-Identifier": "trillian", + "Called-Station-Id": "00-04-5F-00-0F-D1", + "Calling-Station-Id": "00-01-24-80-B3-9C", + "Framed-IP-Address": "10.0.0.100", + } + ) - print('Sending accounting start packet') - req['Acct-Status-Type'] = 'Start' + print("Sending accounting start packet") + req["Acct-Status-Type"] = "Start" send_accounting_packet(srv, req) - print('Sending accounting stop packet') - req['Acct-Status-Type'] = 'Stop' - req['Acct-Input-Octets'] = random.randrange(2**10, 2**30) - req['Acct-Output-Octets'] = random.randrange(2**10, 2**30) - req['Acct-Session-Time'] = random.randrange(120, 3600) - req['Acct-Terminate-Cause'] = random.choice(['User-Request', 'Idle-Timeout']) + print("Sending accounting stop packet") + req["Acct-Status-Type"] = "Stop" + req["Acct-Input-Octets"] = random.randrange(2 ** 10, 2 ** 30) + req["Acct-Output-Octets"] = random.randrange(2 ** 10, 2 ** 30) + req["Acct-Session-Time"] = random.randrange(120, 3600) + req["Acct-Terminate-Cause"] = random.choice( + ["User-Request", "Idle-Timeout"] + ) send_accounting_packet(srv, req) -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/example/auth.py b/example/auth.py index 6e2c89c..92a3939 100755 --- a/example/auth.py +++ b/example/auth.py @@ -11,43 +11,46 @@ from pyrad.dictionary import Dictionary def main(path_to_dictionary): - srv = Client(server='127.0.0.1', - secret=b'Kah3choteereethiejeimaeziecumi', - dict=Dictionary(path_to_dictionary)) + srv = Client( + server="127.0.0.1", + secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary(path_to_dictionary), + ) req = srv.CreateAuthPacket( code=pyrad.packet.AccessRequest, **{ - 'User-Name': 'wichert', - 'NAS-IP-Address': '192.168.1.10', - 'NAS-Port': 0, - 'Service-Type': 'Login-User', - 'NAS-Identifier': 'trillian', - 'Called-Station-Id': '00-04-5F-00-0F-D1', - 'Calling-Station-Id': '00-01-24-80-B3-9C', - 'Framed-IP-Address': '10.0.0.100', - }) + "User-Name": "wichert", + "NAS-IP-Address": "192.168.1.10", + "NAS-Port": 0, + "Service-Type": "Login-User", + "NAS-Identifier": "trillian", + "Called-Station-Id": "00-04-5F-00-0F-D1", + "Calling-Station-Id": "00-01-24-80-B3-9C", + "Framed-IP-Address": "10.0.0.100", + }, + ) try: - print('Sending authentication request') + print("Sending authentication request") reply = srv.SendPacket(req) except pyrad.client.Timeout: - print('RADIUS server does not reply') + print("RADIUS server does not reply") sys.exit(1) except socket.error as error: - print('Network error: ' + error[1]) + print("Network error: " + error[1]) sys.exit(1) if reply.code == pyrad.packet.AccessAccept: - print('Access accepted') + print("Access accepted") else: - print('Access denied') + print("Access denied") - print('Attributes returned by server:') + print("Attributes returned by server:") for key, value in reply.items(): - print(f'{key} {value}') + print(f"{key} {value}") -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/example/auth_async.py b/example/auth_async.py index b6c04e3..7141fed 100755 --- a/example/auth_async.py +++ b/example/auth_async.py @@ -10,49 +10,58 @@ from pyrad.client_async import ClientAsync from pyrad.dictionary import Dictionary from pyrad.packet import AccessAccept -logging.basicConfig(level='DEBUG', - format='%(asctime)s [%(levelname)-8s] %(message)s') +logging.basicConfig( + level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s" +) def create_request(client, user): - return client.CreateAuthPacket(**{ - 'User-Name': user, - 'NAS-IP-Address': '192.168.1.10', - 'NAS-Port': 0, - 'Service-Type': 'Login-User', - 'NAS-Identifier': 'trillian', - 'Called-Station-Id': '00-04-5F-00-0F-D1', - 'Calling-Station-Id': '00-01-24-80-B3-9C', - 'Framed-IP-Address': '10.0.0.100', - }) + return client.CreateAuthPacket( + **{ + "User-Name": user, + "NAS-IP-Address": "192.168.1.10", + "NAS-Port": 0, + "Service-Type": "Login-User", + "NAS-Identifier": "trillian", + "Called-Station-Id": "00-04-5F-00-0F-D1", + "Calling-Station-Id": "00-01-24-80-B3-9C", + "Framed-IP-Address": "10.0.0.100", + } + ) def print_reply(reply): if reply.code == AccessAccept: - print('Access accepted') + print("Access accepted") else: - print('Access denied') + print("Access denied") - print('Attributes returned by server:') + print("Attributes returned by server:") for key, value in reply.items(): - print(f'{key}: {value}') + print(f"{key}: {value}") def initialize_transport(loop, client): loop.run_until_complete( asyncio.ensure_future( - client.initialize_transports(enable_auth=True, - local_addr='127.0.0.1', - local_auth_port=8000, - enable_acct=True, - enable_coa=True))) + client.initialize_transports( + enable_auth=True, + local_addr="127.0.0.1", + local_auth_port=8000, + enable_acct=True, + enable_coa=True, + ) + ) + ) def main(path_to_dictionary): - client = ClientAsync(server='localhost', - secret=b'Kah3choteereethiejeimaeziecumi', - timeout=4, - dict=Dictionary(path_to_dictionary)) + client = ClientAsync( + server="localhost", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=4, + dict=Dictionary(path_to_dictionary), + ) loop = asyncio.get_event_loop() @@ -62,41 +71,41 @@ def main(path_to_dictionary): requests = [] for i in range(255): - req = create_request(client, f'user{i}') + req = create_request(client, f"user{i}") future = client.SendPacket(req) requests.append(future) # Send auth requests asynchronously to the server - loop.run_until_complete(asyncio.ensure_future( - asyncio.gather( - *requests, - return_exceptions=True + loop.run_until_complete( + asyncio.ensure_future( + asyncio.gather(*requests, return_exceptions=True) ) - - )) + ) for future in requests: if future.exception(): - print('EXCEPTION ', future.exception()) + print("EXCEPTION ", future.exception()) else: reply = future.result() print_reply(reply) # Close transports - loop.run_until_complete(asyncio.ensure_future( - client.deinitialize_transports())) - print('END') + loop.run_until_complete( + asyncio.ensure_future(client.deinitialize_transports()) + ) + print("END") except Exception as exc: - print('Error: ', exc) + print("Error: ", exc) traceback.print_exc() # Close transports - loop.run_until_complete(asyncio.ensure_future( - client.deinitialize_transports())) + loop.run_until_complete( + asyncio.ensure_future(client.deinitialize_transports()) + ) loop.close() -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/example/client-coa.py b/example/client-coa.py index 2a41927..a41a035 100755 --- a/example/client-coa.py +++ b/example/client-coa.py @@ -14,21 +14,21 @@ from pyrad.server import Server, RemoteHost def print_attributes(packet): - print('Attributes') + print("Attributes") for key, value in packet.items(): - print(f'{key}: {value}') + print(f"{key}: {value}") class FakeCoA(Server): def HandleCoaPacket(self, packet): - '''Accounting packet handler. + """Accounting packet handler. Function that is called when a valid accounting packet has been received. :param packet: packet to process :type packet: Packet class instance - ''' - print('Received a coa request %d' % packet.code) + """ + print("Received a coa request %d" % packet.code) print_attributes(packet) reply = self.CreateReplyPacket(packet) @@ -38,7 +38,7 @@ class FakeCoA(Server): self.SendReplyPacket(packet.fd, reply) def HandleDisconnectPacket(self, packet): - print('Received a disconnect request %d' % packet.code) + print("Received a disconnect request %d" % packet.code) print_attributes(packet) reply = self.CreateReplyPacket(packet) @@ -52,27 +52,27 @@ def main(path_to_dictionary, coa_port): # create server/coa only and read dictionary # bind and listen only on 127.0.0.1:argv[1] coa = FakeCoA( - addresses=['127.0.0.1'], + addresses=["127.0.0.1"], dict=Dictionary(path_to_dictionary), coaport=coa_port, auth_enabled=False, acct_enabled=False, - coa_enabled=True) + coa_enabled=True, + ) # add peers (address, secret, name) - coa.hosts['127.0.0.1'] = RemoteHost( - '127.0.0.1', - b'Kah3choteereethiejeimaeziecumi', - 'localhost') + coa.hosts["127.0.0.1"] = RemoteHost( + "127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost" + ) # start coa.Run() -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) != 2: - print('usage: client-coa.py {portnumber}') + print("usage: client-coa.py {portnumber}") sys.exit(1) - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary, int(sys.argv[1])) diff --git a/example/coa.py b/example/coa.py index 5082ed4..f0ef313 100755 --- a/example/coa.py +++ b/example/coa.py @@ -11,27 +11,29 @@ from pyrad.dictionary import Dictionary def main(path_to_dictionary, coa_type, nas_identifier): # create coa client - client = Client(server='127.0.0.1', - secret=b'Kah3choteereethiejeimaeziecumi', - dict=Dictionary(path_to_dictionary)) + client = Client( + server="127.0.0.1", + secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary(path_to_dictionary), + ) # set coa timeout client.timeout = 30 # create coa request packet attributes = { - 'Acct-Session-Id': '1337', - 'NAS-Identifier': nas_identifier, + "Acct-Session-Id": "1337", + "NAS-Identifier": nas_identifier, } - if coa_type == 'coa': + if coa_type == "coa": # create coa request request = client.CreateCoAPacket(**attributes) - elif coa_type == 'dis': + elif coa_type == "dis": # create disconnect request request = client.CreateCoAPacket( - code=pyrad.packet.DisconnectRequest, - **attributes) + code=pyrad.packet.DisconnectRequest, **attributes + ) else: sys.exit(1) @@ -41,11 +43,11 @@ def main(path_to_dictionary, coa_type, nas_identifier): print(result.code) -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) != 3: - print('usage: coa.py {coa|dis} daemon-1234') + print("usage: coa.py {coa|dis} daemon-1234") sys.exit(1) - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary, sys.argv[1], sys.argv[2]) diff --git a/example/server.py b/example/server.py index 342b93f..3de61f9 100755 --- a/example/server.py +++ b/example/server.py @@ -8,46 +8,52 @@ import pyrad.packet from pyrad import server from pyrad.dictionary import Dictionary -logging.basicConfig(filename='pyrad.log', level='DEBUG', - format='%(asctime)s [%(levelname)-8s] %(message)s') +logging.basicConfig( + filename="pyrad.log", + level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s", +) def print_attributes(packet): - print('Attributes') + print("Attributes") for key, value in packet.items(): - print(f'{key}: {value}') + print(f"{key}: {value}") class FakeServer(server.Server): def HandleAuthPacket(self, packet): - print('Received an authentication request') + print("Received an authentication request") print_attributes(packet) - reply = self.CreateReplyPacket(packet, **{ - 'Service-Type': 'Framed-User', - 'Framed-IP-Address': '192.168.0.1', - 'Framed-IPv6-Prefix': 'fc66::/64' - }) + reply = self.CreateReplyPacket( + packet, + **{ + "Service-Type": "Framed-User", + "Framed-IP-Address": "192.168.0.1", + "Framed-IPv6-Prefix": "fc66::/64", + }, + ) reply.code = pyrad.packet.AccessAccept self.SendReplyPacket(packet.fd, reply) def HandleAcctPacket(self, packet): - print('Received an accounting request') + print("Received an accounting request") print_attributes(packet) reply = self.CreateReplyPacket(packet) self.SendReplyPacket(packet.fd, reply) def HandleCoaPacket(self, packet): - print('Received an coa request') + print("Received an coa request") print_attributes(packet) reply = self.CreateReplyPacket(packet) self.SendReplyPacket(packet.fd, reply) def HandleDisconnectPacket(self, packet): - print('Received an disconnect request') + print("Received an disconnect request") print_attributes(packet) reply = self.CreateReplyPacket(packet) @@ -58,20 +64,18 @@ class FakeServer(server.Server): def main(path_to_dictionary): # create server and read dictionary - srv = FakeServer(dict=Dictionary(path_to_dictionary), - coa_enabled=True) + srv = FakeServer(dict=Dictionary(path_to_dictionary), coa_enabled=True) # add clients (address, secret, name) - srv.hosts['127.0.0.1'] = server.RemoteHost( - '127.0.0.1', - b'Kah3choteereethiejeimaeziecumi', - 'localhost') - srv.BindToAddress('0.0.0.0') + srv.hosts["127.0.0.1"] = server.RemoteHost( + "127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost" + ) + srv.BindToAddress("0.0.0.0") # start server srv.Run() -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/example/server_async.py b/example/server_async.py index 3ed39f0..2f918a3 100755 --- a/example/server_async.py +++ b/example/server_async.py @@ -12,57 +12,67 @@ from pyrad.server import RemoteHost try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except: pass -logging.basicConfig(level='DEBUG', - format='%(asctime)s [%(levelname)-8s] %(message)s') +logging.basicConfig( + level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s" +) def print_attributes(packet): - print('Attributes returned by server:') + print("Attributes returned by server:") for key, value in packet.items(): - print(f'{key}: {value}') + print(f"{key}: {value}") class FakeServer(ServerAsync): def __init__(self, loop, dictionary): - ServerAsync.__init__(self, loop=loop, dictionary=dictionary, - enable_pkt_verify=True, debug=True) + ServerAsync.__init__( + self, + loop=loop, + dictionary=dictionary, + enable_pkt_verify=True, + debug=True, + ) def handle_auth_packet(self, protocol, packet, addr): - print('Received an authentication request with id ', packet.id) - print('Authenticator ', packet.authenticator.hex()) - print('Secret ', packet.secret) + print("Received an authentication request with id ", packet.id) + print("Authenticator ", packet.authenticator.hex()) + print("Secret ", packet.secret) print_attributes(packet) - reply = self.CreateReplyPacket(packet, **{ - 'Service-Type': 'Framed-User', - 'Framed-IP-Address': '192.168.0.1', - 'Framed-IPv6-Prefix': 'fc66::/64' - }) + reply = self.CreateReplyPacket( + packet, + **{ + "Service-Type": "Framed-User", + "Framed-IP-Address": "192.168.0.1", + "Framed-IPv6-Prefix": "fc66::/64", + }, + ) reply.code = AccessAccept protocol.send_response(reply, addr) def handle_acct_packet(self, protocol, packet, addr): - print('Received an accounting request') + print("Received an accounting request") print_attributes(packet) reply = self.CreateReplyPacket(packet) protocol.send_response(reply, addr) def handle_coa_packet(self, protocol, packet, addr): - print('Received an coa request') + print("Received an coa request") print_attributes(packet) reply = self.CreateReplyPacket(packet) protocol.send_response(reply, addr) def handle_disconnect_packet(self, protocol, packet, addr): - print('Received an disconnect request') + print("Received an disconnect request") print_attributes(packet) reply = self.CreateReplyPacket(packet) @@ -77,17 +87,19 @@ def main(path_to_dictionary): server = FakeServer(loop=loop, dictionary=Dictionary(path_to_dictionary)) # add clients (address, secret, name) - server.hosts['127.0.0.1'] = RemoteHost('127.0.0.1', - b'Kah3choteereethiejeimaeziecumi', - 'localhost') + server.hosts["127.0.0.1"] = RemoteHost( + "127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost" + ) try: # Initialize transports loop.run_until_complete( asyncio.ensure_future( - server.initialize_transports(enable_auth=True, - enable_acct=True, - enable_coa=True))) + server.initialize_transports( + enable_auth=True, enable_acct=True, enable_coa=True + ) + ) + ) try: # start server @@ -96,20 +108,22 @@ def main(path_to_dictionary): pass # Close transports - loop.run_until_complete(asyncio.ensure_future( - server.deinitialize_transports())) + loop.run_until_complete( + asyncio.ensure_future(server.deinitialize_transports()) + ) except Exception as exc: - print('Error: ', exc) + print("Error: ", exc) traceback.print_exc() # Close transports - loop.run_until_complete(asyncio.ensure_future( - server.deinitialize_transports())) + loop.run_until_complete( + asyncio.ensure_future(server.deinitialize_transports()) + ) loop.close() -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/example/status.py b/example/status.py index c1dfc2d..6322a02 100755 --- a/example/status.py +++ b/example/status.py @@ -11,32 +11,33 @@ from pyrad.dictionary import Dictionary def main(path_to_dictionary): - srv = Client(server='localhost', - authport=18121, - secret=b'test', - dict=Dictionary(path_to_dictionary)) + srv = Client( + server="localhost", + authport=18121, + secret=b"test", + dict=Dictionary(path_to_dictionary), + ) req = srv.CreateAuthPacket( - code=pyrad.packet.StatusServer, - FreeRADIUS_Statistics_Type='All', + code=pyrad.packet.StatusServer, FreeRADIUS_Statistics_Type="All", ) req.add_message_authenticator() try: - print('Sending FreeRADIUS status request') + print("Sending FreeRADIUS status request") reply = srv.SendPacket(req) except pyrad.client.Timeout: - print('RADIUS server does not reply') + print("RADIUS server does not reply") sys.exit(1) except socket.error as error: - print('Network error: ' + error[1]) + print("Network error: " + error[1]) sys.exit(1) - print('Attributes returned by server:') + print("Attributes returned by server:") for key, value in reply.items(): - print(f'{key}: {value}') + print(f"{key}: {value}") -if __name__ == '__main__': - dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary') +if __name__ == "__main__": + dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary") main(dictionary) diff --git a/pyproject.toml b/pyproject.toml index 72c9eb7..e1a3fea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,13 +21,11 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: System :: Systems Administration :: Authentication/Directory" - ] packages = [ - { include = "pyrad3"}, + { include = "src/pyrad3"}, ] include = [ - "CHANGELOG.md", "LICENSE-APACHE", "LICENSE-MIT", "README.md", @@ -40,9 +38,16 @@ repository = "https://github.com/pyradius/pyrad3" [tool.poetry.dev-dependencies] pytest = "^5.4" -pytest-cov = "^2.5" +pytest-black = "^0.30" +pytest-cov = "^2.10" +pytest-mypy = "^0.6" +pytest-pylint = "^0.17" [tool.black] -line-length = 100 +line-length = 80 include = '\.py' + + +[tool.pylint.messages_control] +disable = "bad-continuation" diff --git a/pyrad3/client.py b/pyrad3/client.py deleted file mode 100644 index 951cafb..0000000 --- a/pyrad3/client.py +++ /dev/null @@ -1,221 +0,0 @@ -# client.py -# -# Copyright 2002-2007 Wichert Akkerman - -__docformat__ = "epytext en" - -import hashlib -import select -import socket -import time -import struct -from pyrad import host -from pyrad import packet - -EAP_CODE_REQUEST = 1 -EAP_CODE_RESPONSE = 2 -EAP_TYPE_IDENTITY = 1 - - -class Timeout(Exception): - """Simple exception class which is raised when a timeout occurs - while waiting for a RADIUS server to respond.""" - - -class Client(host.Host): - """Basic RADIUS client. - This class implements a basic RADIUS client. It can send requests - to a RADIUS server, taking care of timeouts and retries, and - validate its replies. - - :ivar retries: number of times to retry sending a RADIUS request - :type retries: integer - :ivar timeout: number of seconds to wait for an answer - :type timeout: float - """ - def __init__(self, server, authport=1812, acctport=1813, - coaport=3799, secret=b'', dict=None, retries=3, - timeout=5): - - """Constructor. - - :param server: hostname or IP address of RADIUS server - :type server: string - :param authport: port to use for authentication packets - :type authport: integer - :param acctport: port to use for accounting packets - :type acctport: integer - :param coaport: port to use for CoA packets - :type coaport: integer - :param secret: RADIUS secret - :type secret: string - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary - """ - host.Host.__init__(self, authport, acctport, coaport, dict) - - self.server = server - self.secret = secret - self._socket = None - self.retries = retries - self.timeout = timeout - self._poll = select.poll() - - def bind(self, addr): - """Bind socket to an address. - Binding the socket used for communicating to an address can be - usefull when working on a machine with multiple addresses. - - :param addr: network address (hostname or IP) and port to bind to - :type addr: host,port tuple - """ - self._close_socket() - self._socket_open() - self._socket.bind(addr) - - def _socket_open(self): - try: - family = socket.getaddrinfo(self.server, 'www')[0][0] - except: - family = socket.AF_INET - if not self._socket: - self._socket = socket.socket(family, - socket.SOCK_DGRAM) - self._socket.setsockopt(socket.SOL_SOCKET, - socket.SO_REUSEADDR, 1) - self._poll.register(self._socket, select.POLLIN) - - def _close_socket(self): - if self._socket: - self._poll.unregister(self._socket) - self._socket.close() - self._socket = None - - def CreateAuthPacket(self, **args): - """Create a new RADIUS packet. - This utility function creates a new RADIUS packet which can - be used to communicate with the RADIUS server this client - talks to. This is initializing the new packet with the - dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.AuthPacket - """ - return host.Host.CreateAuthPacket(self, secret=self.secret, **args) - - def CreateAcctPacket(self, **args): - """Create a new RADIUS packet. - This utility function creates a new RADIUS packet which can - be used to communicate with the RADIUS server this client - talks to. This is initializing the new packet with the - dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.Packet - """ - return host.Host.CreateAcctPacket(self, secret=self.secret, **args) - - def CreateCoAPacket(self, **args): - """Create a new RADIUS packet. - This utility function creates a new RADIUS packet which can - be used to communicate with the RADIUS server this client - talks to. This is initializing the new packet with the - dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.Packet - """ - return host.Host.CreateCoAPacket(self, secret=self.secret, **args) - - def _send_packet(self, pkt, port): - """Send a packet to a RADIUS server. - - :param pkt: the packet to send - :type pkt: pyrad.packet.Packet - :param port: UDP port to send packet to - :type port: integer - :return: the reply packet received - :rtype: pyrad.packet.Packet - :raise Timeout: RADIUS server does not reply - """ - self._socket_open() - - for attempt in range(self.retries): - if attempt and pkt.code == packet.AccountingRequest: - if "Acct-Delay-Time" in pkt: - pkt["Acct-Delay-Time"] = \ - pkt["Acct-Delay-Time"][0] + self.timeout - else: - pkt["Acct-Delay-Time"] = self.timeout - - now = time.time() - waitto = now + self.timeout - - self._socket.sendto(pkt.RequestPacket(), (self.server, port)) - - while now < waitto: - ready = self._poll.poll((waitto - now) * 1000) - - if ready: - rawreply = self._socket.recv(4096) - else: - now = time.time() - continue - - try: - return pkt.VerifyReply(rawreply) - except packet.PacketError: - # TODO: report or error out maybe? - pass - - now = time.time() - - raise Timeout - - def SendPacket(self, pkt): - """Send a packet to a RADIUS server. - - :param pkt: the packet to send - :type pkt: pyrad.packet.Packet - :return: the reply packet received - :rtype: pyrad.packet.Packet - :raise Timeout: RADIUS server does not reply - """ - if isinstance(pkt, packet.AuthPacket): - if pkt.auth_type == 'eap-md5': - # Creating EAP-Identity - password = pkt[2][0] if 2 in pkt else pkt[1][0] - pkt[79] = [struct.pack('!BBHB%ds' % len(password), - EAP_CODE_RESPONSE, - packet.CurrentID, - len(password) + 5, - EAP_TYPE_IDENTITY, - password)] - reply = self._send_packet(pkt, self.authport) - if (reply - and reply.code == packet.AccessChallenge - and pkt.auth_type == 'eap-md5' - ): - # Got an Access-Challenge - _eap_code, eap_id, _eap_size, _eap_type, eap_md5 = struct.unpack( - '!BBHB%ds' % (len(reply[79][0]) - 5), reply[79][0] - ) - # Sending back an EAP-Type-MD5-Challenge - # Thank god for http://www.secdev.org/python/eapy.py - client_pw = pkt[2][0] if 2 in pkt else pkt[1][0] - md5_challenge = hashlib.md5( - struct.pack('!B', eap_id) + client_pw + eap_md5[1:] - ).digest() - pkt[79] = [ - struct.pack('!BBHBB', 2, eap_id, len(md5_challenge) + 6, - 4, len(md5_challenge)) + md5_challenge - ] - # Copy over Challenge-State - pkt[24] = reply[24] - reply = self._send_packet(pkt, self.authport) - return reply - - if isinstance(pkt, packet.CoAPacket): - return self._send_packet(pkt, self.coaport) - - return self._send_packet(pkt, self.acctport) diff --git a/pyrad3/curved.py b/pyrad3/curved.py deleted file mode 100644 index 2a32c7a..0000000 --- a/pyrad3/curved.py +++ /dev/null @@ -1,81 +0,0 @@ -# curved.py -# -# Copyright 2002 Wichert Akkerman - -"""Twisted integration code -""" - -__docformat__ = 'epytext en' - -from twisted.internet import protocol -from twisted.internet import reactor -from twisted.python import log -import sys -from pyrad import dictionary -from pyrad import host -from pyrad import packet - - -class PacketError(Exception): - """Exception class for bogus packets - - PacketError exceptions are only used inside the Server class to - abort processing of a packet. - """ - - -class RADIUS(host.Host, protocol.DatagramProtocol): - def __init__(self, hosts={}, dict=dictionary.Dictionary()): - host.Host.__init__(self, dict=dict) - self.hosts = hosts - - def processPacket(self, pkt): - pass - - def createPacket(self, **kwargs): - raise NotImplementedError('Attempted to use a pure base class') - - def datagramReceived(self, datagram, source): - host, port = source - try: - pkt = self.CreatePacket(packet=datagram) - except packet.PacketError as err: - log.msg('Dropping invalid packet: ' + str(err)) - return - - if host not in self.hosts: - log.msg('Dropping packet from unknown host ' + host) - return - - pkt.source = (host, port) - try: - self.processPacket(pkt) - except PacketError as err: - log.msg('Dropping packet from %s: %s' % (host, str(err))) - - -class RADIUSAccess(RADIUS): - def createPacket(self, **kwargs): - self.CreateAuthPacket(**kwargs) - - def processPacket(self, pkt): - if pkt.code != packet.AccessRequest: - raise PacketError( - 'non-AccessRequest packet on authentication socket') - - -class RADIUSAccounting(RADIUS): - def createPacket(self, **kwargs): - self.CreateAcctPacket(**kwargs) - - def processPacket(self, pkt): - if pkt.code != packet.AccountingRequest: - raise PacketError( - 'non-AccountingRequest packet on authentication socket') - - -if __name__ == '__main__': - log.startLogging(sys.stdout, 0) - reactor.listenUDP(1812, RADIUSAccess()) - reactor.listenUDP(1813, RADIUSAccounting()) - reactor.run() diff --git a/pyrad3/dictfile.py b/pyrad3/dictfile.py deleted file mode 100644 index cc90d57..0000000 --- a/pyrad3/dictfile.py +++ /dev/null @@ -1,116 +0,0 @@ -# dictfile.py -# -# Copyright 2009 Kristoffer Gronlund - -""" Dictionary File - -Implements an iterable file format that handles the -RADIUS $INCLUDE directives behind the scene. -""" - -import os - - -class _Node(): - """Dictionary file node - - A single dictionary file. - """ - __slots__ = ('name', 'lines', 'current', 'length', 'dir') - - def __init__(self, fd, name, parentdir): - self.lines = fd.readlines() - self.length = len(self.lines) - self.current = 0 - self.name = os.path.basename(name) - path = os.path.dirname(name) - if os.path.isabs(path): - self.dir = path - else: - self.dir = os.path.join(parentdir, path) - - def next(self): - if self.current >= self.length: - return None - self.current += 1 - return self.lines[self.current - 1] - - -class DictFile(): - """Dictionary file class - - An iterable file type that handles $INCLUDE - directives internally. - """ - __slots__ = ['stack'] - - def __init__(self, fil): - """ - @param fil: a dictionary file to parse - @type fil: string or file - """ - self.stack = [] - self.__read_node(fil) - - def __read_node(self, fil): - node = None - parentdir = self.__cur_dir() - if isinstance(fil, str): - fname = None - if os.path.isabs(fil): - fname = fil - else: - fname = os.path.join(parentdir, fil) - fd = open(fname, "rt") - node = _Node(fd, fil, parentdir) - fd.close() - else: - node = _Node(fil, '', parentdir) - self.stack.append(node) - - def __cur_dir(self): - if self.stack: - return self.stack[-1].dir - return os.path.realpath(os.curdir) - - @staticmethod - def __get_include(line): - line = line.split("#", 1)[0].strip() - tokens = line.split() - if tokens and tokens[0].upper() == '$INCLUDE': - return " ".join(tokens[1:]) - return None - - def line(self): - """Returns line number of current file - """ - try: - return self.stack[-1].current - except (AttributeError, IndexError): - return -1 - - def file(self): - """Returns name of current file - """ - try: - return self.stack[-1].name - except (AttributeError, IndexError): - return '' - - def __iter__(self): - return self - - def __next__(self): - while self.stack: - line = self.stack[-1].next() - if line is None: - self.stack.pop() - else: - inc = DictFile.__get_include(line) - if inc: - self.__read_node(inc) - else: - return line - raise StopIteration - - next = __next__ # BBB for python <3 diff --git a/pyrad3/dictionary.py b/pyrad3/dictionary.py deleted file mode 100644 index e09dd86..0000000 --- a/pyrad3/dictionary.py +++ /dev/null @@ -1,393 +0,0 @@ -# dictionary.py -# -# Copyright 2002,2005,2007,2016 Wichert Akkerman -""" -RADIUS uses dictionaries to define the attributes that can -be used in packets. The Dictionary class stores the attribute -definitions from one or more dictionary files. - -Dictionary files are textfiles with one command per line. -Comments are specified by starting with a # character, and empty -lines are ignored. - -The commands supported are:: - - ATTRIBUTE [] - specify an attribute and its type - - VALUE - specify a value attribute - - VENDOR - specify a vendor ID - - BEGIN-VENDOR - begin definition of vendor attributes - - END-VENDOR - end definition of vendor attributes - - -The datatypes currently supported are: - -+---------------+----------------------------------------------+ -| type | description | -+===============+==============================================+ -| string | ASCII string | -+---------------+----------------------------------------------+ -| ipaddr | IPv4 address | -+---------------+----------------------------------------------+ -| date | 32 bits UNIX | -+---------------+----------------------------------------------+ -| octets | arbitrary binary data | -+---------------+----------------------------------------------+ -| abinary | ascend binary data | -+---------------+----------------------------------------------+ -| ipv6addr | 16 octets in network byte order | -+---------------+----------------------------------------------+ -| ipv6prefix | 18 octets in network byte order | -+---------------+----------------------------------------------+ -| integer | 32 bits unsigned number | -+---------------+----------------------------------------------+ -| signed | 32 bits signed number | -+---------------+----------------------------------------------+ -| short | 16 bits unsigned number | -+---------------+----------------------------------------------+ -| byte | 8 bits unsigned number | -+---------------+----------------------------------------------+ -| tlv | Nested tag-length-value | -+---------------+----------------------------------------------+ -| integer64 | 64 bits unsigned number | -+---------------+----------------------------------------------+ - -These datatypes are parsed but not supported: - -+---------------+----------------------------------------------+ -| type | description | -+===============+==============================================+ -| ifid | 8 octets in network byte order | -+---------------+----------------------------------------------+ -| ether | 6 octets of hh:hh:hh:hh:hh:hh | -| | where 'h' is hex digits, upper or lowercase. | -+---------------+----------------------------------------------+ -""" -from copy import copy - -from pyrad import bidict -from pyrad import tools -from pyrad import dictfile - -__docformat__ = 'epytext en' - - -DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets', - 'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte', - 'signed', 'ifid', 'ether', 'tlv', 'integer64']) - - -class ParseError(Exception): - """Dictionary parser exceptions. - - :ivar msg: Error message - :type msg: string - :ivar linenumber: Line number on which the error occurred - :type linenumber: integer - """ - - def __init__(self, msg=None, **data): - super().__init__() - self.msg = msg - self.file = data.get('file', '') - self.line = data.get('line', -1) - - def __str__(self): - line = f'({self.line})' if self.line > -1 else '' - return f'{self.file}{line}: ParseError: {self.msg}' - - -class Attribute(): - def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', values=None, - encrypt=0, has_tag=False): - if datatype not in DATATYPES: - raise ValueError('Invalid data type') - self.name = name - self.code = code - self.type = datatype - self.vendor = vendor - self.encrypt = encrypt - self.has_tag = has_tag - self.values = bidict.BiDict() - self.sub_attributes = {} - self.parent = None - self.is_sub_attribute = is_sub_attribute - if values: - for (key, value) in values.items(): - self.values.add(key, value) - - -class Dictionary(): - """RADIUS dictionary class. - This class stores all information about vendors, attributes and their - values as defined in RADIUS dictionary files. - - :ivar vendors: bidict mapping vendor name to vendor code - :type vendors: bidict - :ivar attrindex: bidict mapping - :type attrindex: bidict - :ivar attributes: bidict mapping attribute name to attribute class - :type attributes: bidict - """ - - def __init__(self, dict=None, *dicts): - """ - :param dict: path of dictionary file or file-like object to read - :type dict: string or file - :param dicts: list of dictionaries - :type dicts: sequence of strings or files - """ - self.vendors = bidict.BiDict() - self.vendors.add('', 0) - self.attrindex = bidict.BiDict() - self.attributes = {} - self.defer_parse = [] - - if dict: - self.read_dictionary(dict) - - for i in dicts: - self.read_dictionary(i) - - def __len__(self): - return len(self.attributes) - - def __getitem__(self, key): - return self.attributes[key] - - def __contains__(self, key): - return key in self.attributes - - has_key = __contains__ - - def __parse_attribute(self, state, tokens): - if not len(tokens) in [4, 5]: - raise ParseError( - 'Incorrect number of tokens for attribute definition', - name=state['file'], - line=state['line']) - - vendor = state['vendor'] - has_tag = False - encrypt = 0 - if len(tokens) >= 5: - def keyval(o): - kv = o.split('=') - if len(kv) == 2: - return (kv[0], kv[1]) - else: - return (kv[0], None) - options = [keyval(o) for o in tokens[4].split(',')] - - for (key, val) in options: - if key == 'has_tag': - has_tag = True - elif key == 'encrypt': - if val not in ['1', '2', '3']: - raise ParseError( - f'Illegal attribute encryption: {val}', - file=state['file'], - line=state['line']) - encrypt = int(val) - - if (not has_tag) and encrypt == 0: - vendor = tokens[4] - if not self.vendors.has_forward(vendor): - if vendor == "concat": - # ignore attributes with concat (freeradius compat.) - return None - else: - raise ParseError('Unknown vendor ' + vendor, - file=state['file'], - line=state['line']) - - (attribute, code, datatype) = tokens[1:4] - - codes = code.split('.') - - # Codes can be sent as hex, or octal or decimal string representations. - tmp = [] - for c in codes: - if c.startswith('0x'): - tmp.append(int(c, 16)) - elif c.startswith('0o'): - tmp.append(int(c, 8)) - else: - tmp.append(int(c, 10)) - codes = tmp - - is_sub_attribute = (len(codes) > 1) - if len(codes) == 2: - code = int(codes[1]) - parent_code = int(codes[0]) - elif len(codes) == 1: - code = int(codes[0]) - parent_code = None - else: - raise ParseError('nested tlvs are not supported') - - datatype = datatype.split("[")[0] - - if datatype not in DATATYPES: - raise ParseError('Illegal type: ' + datatype, - file=state['file'], - line=state['line']) - if vendor: - if is_sub_attribute: - key = (self.vendors.get_forward(vendor), parent_code, code) - else: - key = (self.vendors.get_forward(vendor), code) - else: - if is_sub_attribute: - key = (parent_code, code) - else: - key = code - - self.attrindex.add(attribute, key) - self.attributes[attribute] = Attribute(attribute, code, datatype, is_sub_attribute, - vendor, encrypt=encrypt, has_tag=has_tag) - if datatype == 'tlv': - # save attribute in tlvs - state['tlvs'][code] = self.attributes[attribute] - if is_sub_attribute: - # save sub attribute in parent tlv and update their parent field - state['tlvs'][parent_code].sub_attributes[code] = attribute - self.attributes[attribute].parent = state['tlvs'][parent_code] - - def __parse_value(self, state, tokens, defer): - if len(tokens) != 4: - raise ParseError('Incorrect number of tokens for value definition', - file=state['file'], - line=state['line']) - - (attr, key, value) = tokens[1:] - - try: - adef = self.attributes[attr] - except KeyError: - if defer: - self.defer_parse.append((copy(state), copy(tokens))) - return - raise ParseError('Value defined for unknown attribute ' + attr, - file=state['file'], - line=state['line']) - - if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']: - value = int(value, 0) - value = tools.EncodeAttr(adef.type, value) - self.attributes[attr].values.add(key, value) - - def __parse_vendor(self, state, tokens): - if len(tokens) not in [3, 4]: - raise ParseError( - 'Incorrect number of tokens for vendor definition', - file=state['file'], - line=state['line']) - - # Parse format specification, but do - # nothing about it for now - if len(tokens) == 4: - fmt = tokens[3].split('=') - if fmt[0] != 'format': - raise ParseError( - f"Unknown option '{fmt[0]}' for vendor definition", - file=state['file'], - line=state['line']) - try: - (t, l) = tuple(int(a) for a in fmt[1].split(',')) - if t not in [1, 2, 4] or l not in [0, 1, 2]: - raise ParseError( - f'Unknown vendor format specification {fmt[1]}', - file=state['file'], - line=state['line']) - except ValueError: - raise ParseError( - 'Syntax error in vendor specification', - file=state['file'], - line=state['line']) - - (vendorname, vendor) = tokens[1:3] - self.vendors.add(vendorname, int(vendor, 0)) - - def __parse_begin_vendor(self, state, tokens): - if len(tokens) != 2: - raise ParseError( - 'Incorrect number of tokens for begin-vendor statement', - file=state['file'], - line=state['line']) - - vendor = tokens[1] - - if not self.vendors.has_forward(vendor): - raise ParseError( - f'Unknown vendor {vendor} in begin-vendor statement', - file=state['file'], - line=state['line']) - - state['vendor'] = vendor - - def __parse_end_vendor(self, state, tokens): - if len(tokens) != 2: - raise ParseError( - 'Incorrect number of tokens for end-vendor statement', - file=state['file'], - line=state['line']) - - vendor = tokens[1] - - if state['vendor'] != vendor: - raise ParseError( - 'Ending non-open vendor' + vendor, - file=state['file'], - line=state['line']) - state['vendor'] = '' - - def read_dictionary(self, file): - """Parse a dictionary file. - Reads a RADIUS dictionary file and merges its contents into the - class instance. - - :param file: Name of dictionary file to parse or a file-like object - :type file: string or file-like object - """ - - fil = dictfile.DictFile(file) - - state = {} - state['vendor'] = '' - state['tlvs'] = {} - self.defer_parse = [] - for line in fil: - state['file'] = fil.file() - state['line'] = fil.line() - line = line.split('#', 1)[0].strip() - - tokens = line.split() - if not tokens: - continue - - key = tokens[0].upper() - if key == 'ATTRIBUTE': - self.__parse_attribute(state, tokens) - elif key == 'VALUE': - self.__parse_value(state, tokens, True) - elif key == 'VENDOR': - self.__parse_vendor(state, tokens) - elif key == 'BEGIN-VENDOR': - self.__parse_begin_vendor(state, tokens) - elif key == 'END-VENDOR': - self.__parse_end_vendor(state, tokens) - - for state, tokens in self.defer_parse: - key = tokens[0].upper() - if key == 'VALUE': - self.__parse_value(state, tokens, False) - self.defer_parse = [] diff --git a/pyrad3/host.py b/pyrad3/host.py deleted file mode 100644 index 9e4d49a..0000000 --- a/pyrad3/host.py +++ /dev/null @@ -1,100 +0,0 @@ -# host.py -# -# Copyright 2003,2007 Wichert Akkerman -from pyrad import packet - - -class Host(object): - """Generic RADIUS capable host. - - :ivar dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary - :ivar authport: port to listen on for authentication packets - :type authport: integer - :ivar acctport: port to listen on for accounting packets - :type acctport: integer - """ - def __init__(self, authport=1812, acctport=1813, coaport=3799, dict=None): - """Constructor - - :param authport: port to listen on for authentication packets - :type authport: integer - :param acctport: port to listen on for accounting packets - :type acctport: integer - :param coaport: port to listen on for CoA packets - :type coaport: integer - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary - """ - self.dict = dict - self.authport = authport - self.acctport = acctport - self.coaport = coaport - - def CreatePacket(self, **args): - """Create a new RADIUS packet. - This utility function creates a new RADIUS authentication - packet which can be used to communicate with the RADIUS server - this client talks to. This is initializing the new packet with - the dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.Packet - """ - return packet.Packet(dict=self.dict, **args) - - def CreateAuthPacket(self, **args): - """Create a new authentication RADIUS packet. - This utility function creates a new RADIUS authentication - packet which can be used to communicate with the RADIUS server - this client talks to. This is initializing the new packet with - the dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.AuthPacket - """ - return packet.AuthPacket(dict=self.dict, **args) - - def CreateAcctPacket(self, **args): - """Create a new accounting RADIUS packet. - This utility function creates a new accouting RADIUS packet - which can be used to communicate with the RADIUS server this - client talks to. This is initializing the new packet with the - dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.AcctPacket - """ - return packet.AcctPacket(dict=self.dict, **args) - - def CreateCoAPacket(self, **args): - """Create a new CoA RADIUS packet. - This utility function creates a new CoA RADIUS packet - which can be used to communicate with the RADIUS server this - client talks to. This is initializing the new packet with the - dictionary and secret used for the client. - - :return: a new empty packet instance - :rtype: pyrad.packet.CoAPacket - """ - return packet.CoAPacket(dict=self.dict, **args) - - def SendPacket(self, fd, pkt): - """Send a packet. - - :param fd: socket to send packet with - :type fd: socket class instance - :param pkt: packet to send - :type pkt: Packet class instance - """ - fd.sendto(pkt.Packet(), pkt.source) - - def SendReplyPacket(self, fd, pkt): - """Send a packet. - - :param fd: socket to send packet with - :type fd: socket class instance - :param pkt: packet to send - :type pkt: Packet class instance - """ - fd.sendto(pkt.ReplyPacket(), pkt.source) diff --git a/pyrad3/new_host.py b/pyrad3/new_host.py deleted file mode 100644 index 3ee8563..0000000 --- a/pyrad3/new_host.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2020 Istvan Ruzman -# SPDX-License-Identifier: MIT OR Apache-2.0 -from pyrad3 import new_packet as packet - -class Host: - def __init__(self, secret, radius_dict, - authport=1812, acctport=1813, coaport=3799, - timeout=30, retries=3): - self.secret = secret - self.dictionary = radius_dict - - self.authport = authport - self.acctport = acctport - self.coaport = coaport - - self.timeout = timeout - self.retries = retries - - def create_packet(self, **kwargs): - return packet.Packet(self, **kwargs) - - def create_auth_packet(self, **kwargs): - return packet.AuthPacket(self, **kwargs) - - def create_acct_packet(self, **kwargs): - return packet.AcctPacket(self, **kwargs) - - def create_coa_packet(self, **kwargs): - return packet.CoAPacket(self, **kwargs) diff --git a/pyrad3/new_packet.py b/pyrad3/new_packet.py deleted file mode 100644 index 63e547f..0000000 --- a/pyrad3/new_packet.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2020 Istvan Ruzman -# SPDX-License-Identifier: MIT OR Apache-2.0 - -from collections import OrderedDict - -import hashlib -import secrets -import struct - -# Packet codes -AccessRequest = 1 -AccessAccept = 2 -AccessReject = 3 -AccountingRequest = 4 -AccountingResponse = 5 -AccessChallenge = 11 -StatusServer = 12 -StatusClient = 13 -DisconnectRequest = 40 -DisconnectACK = 41 -DisconnectNAK = 42 -CoARequest = 43 -CoAACK = 44 -CoANAK = 45 - - -class PacketError(Exception): - pass - -class AuthError(Exception): - pass - -RANDOM_GENERATOR = secrets.SystemRandom() - - -class Packet(OrderedDict): - def __init__(self, host, code, radius_id, *, request=None, **attributes): - super().__init__(**attributes) - self.code = code - self.id = radius_id - self.host = host - self.request = request - - def send(self): - self.host._send_packet(self) - - def verify_reply(self, raw_reply): - if self.id != raw_reply[1]: - raise PacketError("Response has a wrong id") - - radius_hash = self.calculate_radius_hash(raw_reply) - - if radius_hash != raw_reply[4:20]: - raise PacketError("Reply Packet has a wrong authenticator") - - return self.parse_raw_packet(raw_reply) - - def calculate_radius_hash(self, data, authenticator=None): - if authenticator is None: - authenticator = self.authenticator - return hashlib.md5(data[0:4] + authenticator + - data[20:] + self.secret).digest() - - def parse_raw_reply(self, raw_packet): - (code, radius_id, length, authenticator) = self._parse_header(raw_packet) - attrs = self.parse_attributes(raw_packet) - return Packet(self.host, code, radius_id, **attrs, request=self) - - def _parse_header(self, raw_packet): - try: - header = struct.unpack('!BBH16s', raw_packet) - except struct.error: - raise PacketError('Packet header is corrupt') - - length = header[3] - if len(raw_packet) != length: - raise PacketError('Packet has invalid length') - if length > 4096: - raise PacketError(f'Packet length is too big ({length})') - return header - - -class AuthPacket(Packet): - def __init__(self, host, radius_id, auth_type, *, - code=AccessRequest, request=None, - **attributes): - super().__init__(host, code, radius_id, - request=request, **attributes) - self.auth_type = auth_type - if code == AccessRequest: - self.authenticator = secrets.token_bytes(16) - - def create_accept(self, **attributes): - return AuthPacket(self.host, self.id, self.auth_type, - request=self, - code=AccessAccept - **attributes) - - def create_reject(self, **attributes): - return AuthPacket(self.host, self.id, self.auth_type, - request=self, - code=AccessReject, - **attributes) - - def create_challange(self, **attributes): - return AuthPacket(self.host, self.id, self.auth_type, - request=self, - code=AccessChallenge, - **attributes) - - -class AcctPacket(Packet): - def __init__(self, host, radius_id, *, - code=AccountingRequest, request=None, - **attributes): - super().__init__(host, code, radius_id, - request=request, **attributes) - - def create_response(self, **attributes): - return AcctPacket(self.host, self.id, - code=AccountingResponse, - request=self, - **attributes) - - -class CoAPacket(Packet): - def __init__(self, host, radius_id, *, - code=CoARequest, request=None, - **attributes): - super().__init__(host, code, radius_id, - request=request, **attributes) - - def create_ack(self, **attributes): - return CoAPacket(self.host, self.id, - code=CoAACK, - request=self, - **attributes) - - def create_nack(self, **attributes): - return CoAPacket(self.host, self.id, - code=CoANACK, - request=self, - **attributes) diff --git a/pyrad3/packet.py b/pyrad3/packet.py deleted file mode 100644 index 26a7623..0000000 --- a/pyrad3/packet.py +++ /dev/null @@ -1,869 +0,0 @@ -# packet.py -# -# Copyright 2002-2005,2007 Wichert Akkerman -# -# A RADIUS packet as defined in RFC 2138 - -from collections import OrderedDict -import hashlib -import hmac -import secrets -import struct -from pyrad import tools - -md5_constructor = hashlib.md5 - -# Packet codes -AccessRequest = 1 -AccessAccept = 2 -AccessReject = 3 -AccountingRequest = 4 -AccountingResponse = 5 -AccessChallenge = 11 -StatusServer = 12 -StatusClient = 13 -DisconnectRequest = 40 -DisconnectACK = 41 -DisconnectNAK = 42 -CoARequest = 43 -CoAACK = 44 -CoANAK = 45 - -# Use cryptographic-safe random generator as provided by the OS. -random_generator = secrets.SystemRandom() - -# Current ID -CurrentID = random_generator.randrange(1, 255) - - -class PacketError(Exception): - pass - - -class Packet(OrderedDict): - """Packet acts like a standard python map to provide simple access - to the RADIUS attributes. Since RADIUS allows for repeated - attributes the value will always be a sequence. pyrad makes sure - to preserve the ordering when encoding and decoding packets. - - There are two ways to use the map intereface: if attribute - names are used pyrad take care of en-/decoding data. If - the attribute type number (or a vendor ID/attribute type - tuple for vendor attributes) is used you work with the - raw data. - - Normally you will not use this class directly, but one of the - :obj:`AuthPacket` or :obj:`AcctPacket` classes. - """ - - def __init__(self, code=0, id=None, secret=b'', authenticator=None, - **attributes): - """Constructor - - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary class - :param secret: secret needed to communicate with a RADIUS server - :type secret: string - :param id: packet identification number - :type id: integer (8 bits) - :param code: packet type code - :type code: integer (8bits) - :param packet: raw packet to decode - :type packet: string - """ - super().__init__() - self.code = code - if id is not None: - self.id = id - else: - self.id = CreateID() - if not isinstance(secret, bytes): - raise TypeError('secret must be a binary string') - self.secret = secret - if authenticator is not None and \ - not isinstance(authenticator, bytes): - raise TypeError('authenticator must be a binary string') - self.authenticator = authenticator - self.message_authenticator = None - - if 'dict' in attributes: - self.dict = attributes['dict'] - - if 'packet' in attributes: - self.raw_packet = attributes['packet'] - self.DecodePacket(attributes['packet']) - - if 'message_authenticator' in attributes: - self.message_authenticator = attributes['message_authenticator'] - - for (key, value) in attributes.items(): - if key in [ - 'dict', 'fd', 'packet', - 'message_authenticator', - ]: - continue - key = key.replace('_', '-') - self.AddAttribute(key, value) - - def add_message_authenticator(self): - self.message_authenticator = True - # Maintain a zero octets content for md5 and hmac calculation. - self['Message-Authenticator'] = 16 * b'\00' - - if self.id is None: - self.id = self.CreateID() - - if self.authenticator is None and self.code == AccessRequest: - self.authenticator = self.CreateAuthenticator() - self._refresh_message_authenticator() - - def get_message_authenticator(self): - self._refresh_message_authenticator() - return self.message_authenticator - - def _refresh_message_authenticator(self): - hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5) - - # Maintain a zero octets content for md5 and hmac calculation. - self['Message-Authenticator'] = 16 * b'\00' - attr = self._pkt_encode_attributes() - - header = struct.pack('!BBH', self.code, self.id, - (20 + len(attr))) - - hmac_constructor.update(header[0:4]) - if self.code in (AccountingRequest, DisconnectRequest, - CoARequest, AccountingResponse): - hmac_constructor.update(16 * b'\00') - else: - # NOTE: self.authenticator on reply packet is initialized - # with request authenticator by design. - # For AccessAccept, AccessReject and AccessChallenge - # it is needed use original Authenticator. - # For AccessAccept, AccessReject and AccessChallenge - # it is needed use original Authenticator. - if self.authenticator is None: - raise Exception('No authenticator found') - hmac_constructor.update(self.authenticator) - - hmac_constructor.update(attr) - self['Message-Authenticator'] = hmac_constructor.digest() - - def verify_message_authenticator(self, - original_authenticator=None, - original_code=None): - """Verify packet Message-Authenticator. - - :return: False if verification failed else True - :rtype: boolean - """ - if self.message_authenticator is None: - raise Exception('No Message-Authenticator AVP present') - - prev_ma = self['Message-Authenticator'] - - self['Message-Authenticator'] = 16 * b'\00' - attr = self._pkt_encode_attributes() - - header = struct.pack('!BBH', self.code, self.id, - (20 + len(attr))) - - hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5) - hmac_constructor.update(header) - if self.code in (AccountingRequest, DisconnectRequest, - CoARequest, AccountingResponse): - if original_code is None or original_code != StatusServer: - # TODO: Handle Status-Server response correctly. - hmac_constructor.update(16 * b'\00') - elif self.code in (AccessAccept, AccessChallenge, - AccessReject): - if original_authenticator is None: - if self.authenticator: - # NOTE: self.authenticator on reply packet is initialized - # with request authenticator by design. - original_authenticator = self.authenticator - else: - raise Exception('Missing original authenticator') - - hmac_constructor.update(original_authenticator) - else: - # On Access-Request and Status-Server use dynamic authenticator - hmac_constructor.update(self.authenticator) - - hmac_constructor.update(attr) - self['Message-Authenticator'] = prev_ma[0] - return prev_ma[0] == hmac_constructor.digest() - - def CreateReply(self, **attributes): - """Create a new packet as a reply to this one. This method - makes sure the authenticator and secret are copied over - to the new instance. - """ - return Packet(id=self.id, secret=self.secret, - authenticator=self.authenticator, dict=self.dict, - **attributes) - - def _decode_value(self, attr, value): - try: - return attr.values.get_backward(value) - except KeyError: - return tools.DecodeAttr(attr.type, value) - - def _encode_value(self, attr, value): - try: - result = attr.values.get_forward(value) - except KeyError: - result = tools.EncodeAttr(attr.type, value) - - if attr.encrypt == 2: - # salt encrypt attribute - result = self.SaltCrypt(result) - - return result - - def _encode_key_values(self, key, values): - if not isinstance(key, str): - return (key, values) - - if not isinstance(values, (list, tuple)): - values = [values] - - key, _, tag = key.partition(":") - attr = self.dict.attributes[key] - key = self._encode_key(key) - if attr.has_tag: - tag = '0' if tag == '' else tag - tag = struct.pack('B', int(tag)) - if attr.type == "integer": - # When a tagged value has the type int only 3 bytes are used - # the first byte is the tag itself, so we need to shorten our int - return (key, [tag + self._encode_value(attr, v)[1:] for v in values]) - else: - return (key, [tag + self._encode_value(attr, v) for v in values]) - else: - return (key, [self._encode_value(attr, v) for v in values]) - - def _encode_key(self, key): - if not isinstance(key, str): - return key - - attr = self.dict.attributes[key] - # sub attribute keys don't need vendor - if attr.vendor and not attr.is_sub_attribute: - return (self.dict.vendors.get_forward(attr.vendor), attr.code) - else: - return attr.code - - def _decode_key(self, key): - """Turn a key into a string if possible""" - try: - return self.dict.attrindex.get_backward(key) - except KeyError: - return key - - def AddAttribute(self, key, value): - """Add an attribute to the packet. - - :param key: attribute name or identification - :type key: string, attribute code or (vendor code, attribute code) - tuple - :param value: value - :type value: depends on type of attribute - """ - attr = self.dict.attributes[key.partition(':')[0]] - - (key, value) = self._encode_key_values(key, value) - - if attr.is_sub_attribute: - tlv = self.setdefault(self._encode_key(attr.parent.name), {}) - encoded = tlv.setdefault(key, []) - else: - encoded = self.setdefault(key, []) - - encoded.extend(value) - - def get(self, key, failobj=None): - try: - res = self.__getitem__(key) - except KeyError: - res = failobj - return res - - def __getitem__(self, key): - if not isinstance(key, str): - return super().__getitem__(key) - - values = super().__getitem__(self._encode_key(key)) - attr = self.dict.attributes[key] - if attr.type == 'tlv': # return map from sub attribute code to its values - res = {} - for (sub_attr_key, sub_attr_val) in values.items(): - sub_attr_name = attr.sub_attributes[sub_attr_key] - sub_attr = self.dict.attributes[sub_attr_name] - for v in sub_attr_val: - res.setdefault(sub_attr_name, []).append(self._decode_value(sub_attr, v)) - return res - else: - res = [] - for v in values: - res.append(self._decode_value(attr, v)) - return res - - def __contains__(self, key): - try: - return super().__contains__(self._encode_key(key)) - except KeyError: - return False - - has_key = __contains__ - - def __delitem__(self, key): - super().__delitem__(self._encode_key(key)) - - def __setitem__(self, key, item): - if isinstance(key, str): - (key, item) = self._encode_key_values(key, item) - super().__setitem__(key, item) - - def keys(self): - return [self._decode_key(key) for key in OrderedDict.keys(self)] - - @staticmethod - def CreateAuthenticator(): - """Create a packet authenticator. All RADIUS packets contain a sixteen - byte authenticator which is used to authenticate replies from the - RADIUS server and in the password hiding algorithm. This function - returns a suitable random string that can be used as an authenticator. - - :return: valid packet authenticator - :rtype: binary string - """ - - return secrets.token_bytes(16) - - def CreateID(self): - """Create a packet ID. All RADIUS requests have a ID which is used to - identify a request. This is used to detect retries and replay attacks. - This function returns a suitable random number that can be used as ID. - - :return: ID number - :rtype: integer - - """ - return int.from_bytes(secrets.token_bytes(1), 'little') - - def ReplyPacket(self): - """Create a ready-to-transmit authentication reply packet. - Returns a RADIUS packet which can be directly transmitted - to a RADIUS server. This differs with Packet() in how - the authenticator is calculated. - - :return: raw packet - :rtype: string - """ - assert self.authenticator - assert self.secret is not None - - if self.message_authenticator: - self._refresh_message_authenticator() - - attr = self._pkt_encode_attributes() - header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) - - authenticator = md5_constructor(header[0:4] + self.authenticator - + attr + self.secret).digest() - - return header + authenticator + attr - - def VerifyReply(self, rawreply): - if int(rawreply[1]) != self.id: - raise PacketError("Reply Packet has wrong id") - - # The Authenticator field in an Accounting-Response packet is called - # the Response Authenticator, and contains a one-way MD5 hash - # calculated over a stream of octets consisting of the Accounting - # Response Code, Identifier, Length, the Request Authenticator field - # from the Accounting-Request packet being replied to, and the - # response attributes if any, followed by the shared secret. The - # resulting 16 octet MD5 hash value is stored in the Authenticator - # field of the Accounting-Response packet. - hash = md5_constructor(rawreply[0:4] + self.authenticator + - rawreply[20:] + self.secret).digest() - - if hash != rawreply[4:20]: - raise PacketError("Reply Packet has a wrong authenticator") - return self.CreateReply(packet=rawreply) - - def _pkt_encode_attribute(self, key, value): - if isinstance(key, tuple): - value = struct.pack('!L', key[0]) + \ - self._pkt_encode_attribute(key[1], value) - key = 26 - - return struct.pack('!BB', key, (len(value) + 2)) + value - - def _pkt_encode_tlv(self, tlv_key, tlv_value): - tlv_attr = self.dict.attributes[self._decode_key(tlv_key)] - curr_avp = b'' - avps = [] - max_sub_attribute_len = max(map(lambda item: len(item[1]), tlv_value.items())) - for i in range(max_sub_attribute_len): - sub_attr_encoding = b'' - for (code, datalst) in tlv_value.items(): - if i < len(datalst): - sub_attr_encoding += self._pkt_encode_attribute(code, datalst[i]) - # split above 255. assuming len of one instance of all sub tlvs is lower than 255 - if (len(sub_attr_encoding) + len(curr_avp)) < 245: - curr_avp += sub_attr_encoding - else: - avps.append(curr_avp) - curr_avp = sub_attr_encoding - avps.append(curr_avp) - tlv_avps = [] - for avp in avps: - value = struct.pack('!BB', tlv_attr.code, (len(avp) + 2)) + avp - tlv_avps.append(value) - if tlv_attr.vendor: - vendor_avps = b'' - for avp in tlv_avps: - vendor_avps += struct.pack( - '!BBL', 26, (len(avp) + 6), - self.dict.vendors.get_forward(tlv_attr.vendor) - ) + avp - return vendor_avps - else: - return b''.join(tlv_avps) - - def _pkt_encode_attributes(self): - result = b'' - for (code, datalst) in self.items(): - attribute = self.dict.attributes.get(self._decode_key(code)) - if attribute and attribute.type == 'tlv': - result += self._pkt_encode_tlv(code, datalst) - else: - for data in datalst: - result += self._pkt_encode_attribute(code, data) - return result - - def _pkt_decode_vendor_attribute(self, data): - # Check if this packet is long enough to be in the - # RFC2865 recommended form - if len(data) < 6: - return [(26, data)] - - (vendor, atype, length) = struct.unpack('!LBB', data[:6])[0:3] - attribute = self.dict.attributes.get(self._decode_key((vendor, atype))) - try: - if attribute and attribute.type == 'tlv': - self._pkt_decode_tlv_attribute((vendor, atype), data[6:length + 4]) - tlvs = [] # tlv is added to the packet inside _pkt_decode_tlv_attribute - else: - tlvs = [((vendor, atype), data[6:length + 4])] - except: - return [(26, data)] - - sumlength = 4 + length - while len(data) > sumlength: - try: - atype, length = struct.unpack('!BB', data[sumlength:sumlength+2])[0:2] - except: - return [(26, data)] - tlvs.append(((vendor, atype), data[sumlength+2:sumlength+length])) - sumlength += length - return tlvs - - def _pkt_decode_tlv_attribute(self, code, data): - sub_attributes = self.setdefault(code, {}) - loc = 0 - - while loc < len(data): - atype, length = struct.unpack('!BB', data[loc:loc+2])[0:2] - sub_attributes.setdefault(atype, []).append(data[loc+2:loc+length]) - loc += length - - def DecodePacket(self, packet): - """Initialize the object from raw packet data. Decode a packet as - received from the network and decode it. - - :param packet: raw packet - :type packet: string""" - - try: - (self.code, self.id, length, self.authenticator) = \ - struct.unpack('!BBH16s', packet[0:20]) - - except struct.error: - raise PacketError('Packet header is corrupt') - if len(packet) != length: - raise PacketError('Packet has invalid length') - if length > 4096: - raise PacketError(f'Packet length is too long ({length})') - - self.clear() - - packet = packet[20:] - while packet: - try: - (key, attrlen) = struct.unpack('!BB', packet[0:2]) - except struct.error: - raise PacketError('Attribute header is corrupt') - - if attrlen < 2: - raise PacketError(f'Attribute length is too small (attrlen)') - - value = packet[2:attrlen] - attribute = self.dict.attributes.get(self._decode_key(key)) - if key == 26: - for (key, value) in self._pkt_decode_vendor_attribute(value): - self.setdefault(key, []).append(value) - elif key == 80: - # POST: Message Authenticator AVP is present. - self.message_authenticator = True - self.setdefault(key, []).append(value) - elif attribute and attribute.type == 'tlv': - self._pkt_decode_tlv_attribute(key, value) - else: - self.setdefault(key, []).append(value) - - packet = packet[attrlen:] - - def SaltCrypt(self, value): - """Salt Encryption - - :param value: plaintext value - :type password: unicode string - :return: obfuscated version of the value - :rtype: binary string - """ - - if isinstance(value, str): - value = value.encode('utf-8') - - if self.authenticator is None: - # self.authenticator = self.CreateAuthenticator() - self.authenticator = 16 * b'\x00' - - random_value = 32768 + random_generator.randrange(0, 32767) - result = struct.pack('!H', random_value) - - length = struct.pack("B", len(value)) - buf = length + value - if len(buf) % 16 != 0: - buf += b'\x00' * (16 - (len(buf) % 16)) - - last = self.authenticator + result - while buf: - cur_hash = md5_constructor(self.secret + last).digest() - for b, h in zip(buf, cur_hash): - result += bytes([b ^ h]) - last = result[-16:] - buf = buf[16:] - - return result - - -class AuthPacket(Packet): - def __init__(self, code=AccessRequest, id=None, secret=b'', - authenticator=None, auth_type='pap', **attributes): - """Constructor - - :param code: packet type code - :type code: integer (8bits) - :param id: packet identification number - :type id: integer (8 bits) - :param secret: secret needed to communicate with a RADIUS server - :type secret: string - - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary class - - :param packet: raw packet to decode - :type packet: string - """ - - Packet.__init__(self, code, id, secret, authenticator, **attributes) - self.auth_type = auth_type - - def CreateReply(self, **attributes): - """Create a new packet as a reply to this one. This method - makes sure the authenticator and secret are copied over - to the new instance. - """ - return AuthPacket(AccessAccept, self.id, - self.secret, self.authenticator, dict=self.dict, - auth_type=self.auth_type, **attributes) - - def RequestPacket(self): - """Create a ready-to-transmit authentication request packet. - Return a RADIUS packet which can be directly transmitted - to a RADIUS server. - - :return: raw packet - :rtype: string - """ - if self.authenticator is None: - self.authenticator = self.CreateAuthenticator() - - if self.id is None: - self.id = self.CreateID() - - if self.message_authenticator: - self._refresh_message_authenticator() - - attr = self._pkt_encode_attributes() - if self.auth_type == 'eap-md5': - header = struct.pack( - '!BBH16s', self.code, self.id, (20 + 18 + len(attr)), self.authenticator - ) - msg = header \ - + attr \ - + struct.pack('!BB', 80, struct.calcsize('!BB16s')), - digest = hmac.new(self.secret, msg, digestmod=hashlib.md5).digest() - return msg + digest - - header = struct.pack('!BBH16s', self.code, self.id, - (20 + len(attr)), self.authenticator) - - return header + attr - - def PwDecrypt(self, password): - """Obfuscate a RADIUS password. RADIUS hides passwords in packets by - using an algorithm based on the MD5 hash of the packet authenticator - and RADIUS secret. This function reverses the obfuscation process. - - :param password: obfuscated form of password - :type password: binary string - :return: plaintext password - :rtype: unicode string - """ - pw = self.radius_password_pseudo_hash(password).rstrip(b'\x00') - - return pw.decode('utf-8') - - def PwCrypt(self, password): - """Obfuscate password. - RADIUS hides passwords in packets by using an algorithm - based on the MD5 hash of the packet authenticator and RADIUS - secret. If no authenticator has been set before calling PwCrypt - one is created automatically. Changing the authenticator after - setting a password that has been encrypted using this function - will not work. - - :param password: plaintext password - :type password: unicode string - :return: obfuscated version of the password - :rtype: binary string - """ - if self.authenticator is None: - self.authenticator = self.CreateAuthenticator() - - if isinstance(password, str): - password = password.encode('utf-8') - - buf = password - if len(password) % 16 != 0: - buf += b'\x00' * (16 - (len(password) % 16)) - - return self.radius_password_pseudo_hash(buf) - - def radius_password_pseudo_hash(self, password): - result = b'' - buf = password - last = self.authenticator - - while buf: - cur_hash = md5_constructor(self.secret + last).digest() - for b, h in zip(buf, cur_hash): - result += bytes([b ^ h]) - - (last, buf) = (buf[:16], buf[16:]) - - return result - - def VerifyChapPasswd(self, userpwd): - """ Verify RADIUS ChapPasswd - - :param userpwd: plaintext password - :type userpwd: str - :return: is verify ok - :rtype: bool - """ - - if not self.authenticator: - self.authenticator = self.CreateAuthenticator() - - if isinstance(userpwd, str): - userpwd = userpwd.strip().encode('utf-8') - - chap_password = tools.DecodeOctets(self.get(3)[0]) - if len(chap_password) != 17: - return False - - chapid = chr(chap_password[0]).encode('utf-8') - password = chap_password[1:] - - challenge = self.authenticator - if 'CHAP-Challenge' in self: - challenge = self['CHAP-Challenge'][0] - return password == md5_constructor(chapid + userpwd + challenge).digest() - - def VerifyAuthRequest(self): - """Verify request authenticator. - - :return: True if verification failed else False - :rtype: boolean - """ - assert self.raw_packet - hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' + - self.raw_packet[20:] + self.secret).digest() - return hash == self.authenticator - - -class AcctPacket(Packet): - """RADIUS accounting packets. This class is a specialization - of the generic :obj:`Packet` class for accounting packets. - """ - - def __init__(self, code=AccountingRequest, id=None, secret=b'', - authenticator=None, **attributes): - """Constructor - - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary class - :param secret: secret needed to communicate with a RADIUS server - :type secret: string - :param id: packet identification number - :type id: integer (8 bits) - :param code: packet type code - :type code: integer (8bits) - :param packet: raw packet to decode - :type packet: string - """ - Packet.__init__(self, code, id, secret, authenticator, **attributes) - - def CreateReply(self, **attributes): - """Create a new packet as a reply to this one. This method - makes sure the authenticator and secret are copied over - to the new instance. - """ - return AcctPacket(AccountingResponse, self.id, - self.secret, self.authenticator, dict=self.dict, - **attributes) - - def VerifyAcctRequest(self): - """Verify request authenticator. - - :return: False if verification failed else True - :rtype: boolean - """ - assert self.raw_packet - - hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' + - self.raw_packet[20:] + self.secret).digest() - - return hash == self.authenticator - - def RequestPacket(self): - """Create a ready-to-transmit authentication request packet. - Return a RADIUS packet which can be directly transmitted - to a RADIUS server. - - :return: raw packet - :rtype: string - """ - - if self.id is None: - self.id = self.CreateID() - - if self.message_authenticator: - self._refresh_message_authenticator() - - attr = self._pkt_encode_attributes() - header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) - self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' + - attr + self.secret).digest() - - ans = header + self.authenticator + attr - - return ans - - -class CoAPacket(Packet): - """RADIUS CoA packets. This class is a specialization - of the generic :obj:`Packet` class for CoA packets. - """ - - def __init__(self, code=CoARequest, id=None, secret=b'', - authenticator=None, **attributes): - """Constructor - - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary class - :param secret: secret needed to communicate with a RADIUS server - :type secret: string - :param id: packet identification number - :type id: integer (8 bits) - :param code: packet type code - :type code: integer (8bits) - :param packet: raw packet to decode - :type packet: string - """ - Packet.__init__(self, code, id, secret, authenticator, **attributes) - - def CreateReply(self, **attributes): - """Create a new packet as a reply to this one. This method - makes sure the authenticator and secret are copied over - to the new instance. - """ - return CoAPacket(CoAACK, self.id, - self.secret, self.authenticator, dict=self.dict, - **attributes) - - def VerifyCoARequest(self): - """Verify request authenticator. - - :return: False if verification failed else True - :rtype: boolean - """ - assert self.raw_packet - hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' + - self.raw_packet[20:] + self.secret).digest() - return hash == self.authenticator - - def RequestPacket(self): - """Create a ready-to-transmit CoA request packet. - Return a RADIUS packet which can be directly transmitted - to a RADIUS server. - - :return: raw packet - :rtype: string - """ - - attr = self._pkt_encode_attributes() - - if self.id is None: - self.id = self.CreateID() - - header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) - self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' + - attr + self.secret).digest() - - if self.message_authenticator: - self._refresh_message_authenticator() - attr = self._pkt_encode_attributes() - self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' + - attr + self.secret).digest() - - return header + self.authenticator + attr - - -def CreateID(): - """Generate a packet ID. - - :return: packet ID - :rtype: 8 bit integer - """ - global CurrentID - - CurrentID = (CurrentID + 1) % 256 - return CurrentID diff --git a/pyrad3/tests/__init__.py b/pyrad3/tests/__init__.py deleted file mode 100644 index 0a99242..0000000 --- a/pyrad3/tests/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import pyrad -import sys - -pyrad # keep pyflakes happy -home = sys.modules["pyrad"].__path__[0] diff --git a/pyrad3/tests/data/chap b/pyrad3/tests/data/chap deleted file mode 100644 index a94de18..0000000 --- a/pyrad3/tests/data/chap +++ /dev/null @@ -1,6 +0,0 @@ -# A simple dictionary - -ATTRIBUTE User-Name 1 string -ATTRIBUTE User-Password 2 string encrypt=1 -ATTRIBUTE CHAP-Password 3 octets -ATTRIBUTE CHAP-Challenge 60 octets diff --git a/pyrad3/tests/data/dictfiletest b/pyrad3/tests/data/dictfiletest deleted file mode 100644 index 6d908e9..0000000 --- a/pyrad3/tests/data/dictfiletest +++ /dev/null @@ -1,5 +0,0 @@ -# A failing dictionary - -VALUE Not-Defined Undefined-Value 1 - - diff --git a/pyrad3/tests/data/full b/pyrad3/tests/data/full deleted file mode 100644 index a4162c6..0000000 --- a/pyrad3/tests/data/full +++ /dev/null @@ -1,34 +0,0 @@ -# A simple dictionary - -ATTRIBUTE Test-String 1 string -ATTRIBUTE Test-Octets 2 octets -ATTRIBUTE Test-Integer 3 integer - -VALUE Test-Integer Zero 0 -VALUE Test-Integer One 1 -VALUE Test-Integer Two 2 -VALUE Test-Integer Three 3 -VALUE Test-Integer Four 4 - -ATTRIBUTE Test-Tlv 4 tlv -ATTRIBUTE Test-Tlv-Str 4.1 string -ATTRIBUTE Test-Tlv-Int 4.2 integer - -VENDOR Simplon 16 - - -BEGIN-VENDOR Simplon -ATTRIBUTE Simplon-Number 1 integer -ATTRIBUTE Simplon-String 2 string - -VALUE Simplon-Number Zero 0 -VALUE Simplon-Number One 1 -VALUE Simplon-Number Two 2 -VALUE Simplon-Number Three 3 -VALUE Simplon-Number Four 4 - -ATTRIBUTE Simplon-Tlv 3 tlv -ATTRIBUTE Simplon-Tlv-Str 3.1 string -ATTRIBUTE Simplon-Tlv-Int 3.2 integer - -END-VENDOR Simplon diff --git a/pyrad3/tests/data/simple b/pyrad3/tests/data/simple deleted file mode 100644 index 3d555b3..0000000 --- a/pyrad3/tests/data/simple +++ /dev/null @@ -1,16 +0,0 @@ -# A simple dictionary - -ATTRIBUTE Test-String 1 string -ATTRIBUTE Test-Octets 2 octets -ATTRIBUTE Test-Integer 3 integer -ATTRIBUTE Test-Ip-Address 4 ipaddr -ATTRIBUTE Test-Ipv6-Address 5 ipv6addr -ATTRIBUTE Test-If-Id 6 ifid -ATTRIBUTE Test-Date 7 date -ATTRIBUTE Test-Abinary 8 abinary -ATTRIBUTE Test-Tlv 9 tlv -ATTRIBUTE Test-Tlv-Str 9.1 string -ATTRIBUTE Test-Tlv-Int 9.2 integer -ATTRIBUTE Test-Integer64 10 integer64 -ATTRIBUTE Test-Integer64-Hex 0x0a integer64 -ATTRIBUTE Test-Integer64-Oct 0o12 integer64 diff --git a/pyrad3/tests/data/tunnelPassword b/pyrad3/tests/data/tunnelPassword deleted file mode 100644 index 09ec00f..0000000 --- a/pyrad3/tests/data/tunnelPassword +++ /dev/null @@ -1,3 +0,0 @@ -# A simple dictionary - -ATTRIBUTE Tunnel-Password 2 string encrypt=2 diff --git a/pyrad3/tests/mock.py b/pyrad3/tests/mock.py deleted file mode 100644 index ee71fb3..0000000 --- a/pyrad3/tests/mock.py +++ /dev/null @@ -1,141 +0,0 @@ -import fcntl -import os -from pyrad.packet import PacketError - - -class MockPacket: - reply = object() - - def __init__(self, code, verify=False, error=False): - self.code = code - self.data = {} - self.verify = verify - self.error = error - - def CreateReply(self, packet=None): - if self.error: - raise PacketError - return self.reply - - def VerifyReply(self, reply, rawreply): - return self.verify - - def RequestPacket(self): - return "request packet" - - def __contains__(self, key): - return key in self.data - has_key = __contains__ - - def __setitem__(self, key, value): - self.data[key] = [value] - - def __getitem__(self, key): - return self.data[key] - - -class MockSocket: - def __init__(self, domain, type, data=None): - self.domain = domain - self.type = type - self.closed = False - self.options = [] - self.address = None - self.output = [] - - if data is not None: - (self.read_end, self.write_end) = os.pipe() - fcntl.fcntl(self.write_end, fcntl.F_SETFL, os.O_NONBLOCK) - os.write(self.write_end, data) - self.data = data - else: - self.read_end = 1 - self.write_end = None - - def fileno(self): - return self.read_end - - def bind(self, address): - self.address = address - - def recv(self, buffer): - return self.data[:buffer] - - def sendto(self, data, target): - self.output.append((data, target)) - - def setsockopt(self, level, opt, value): - self.options.append((level, opt, value)) - - def close(self): - self.closed = True - - -class MockFinished(Exception): - pass - - -class MockPoll: - results = [] - - def __init__(self): - self.registry = {} - - def register(self, fd, options): - self.registry[fd] = options - - def unregister(self, fd): - try: - del self.registry[fd] - except KeyError: - pass - - def poll(self, timeout=None): - for result in self.results: - yield result - raise MockFinished - - -def origkey(klass): - return "_originals_" + klass.__name__ - - -def MockClassMethod(klass, name, myfunc=None): - def func(self, *args, **kwargs): - if not hasattr(self, "called"): - self.called = [] - self.called.append((name, args, kwargs)) - - key = origkey(klass) - if not hasattr(klass, key): - setattr(klass, key, {}) - getattr(klass, key)[name] = getattr(klass, name) - if myfunc is None: - setattr(klass, name, func) - else: - setattr(klass, name, myfunc) - - -def UnmockClassMethods(klass): - key = origkey(klass) - if not hasattr(klass, key): - return - for (name, func) in getattr(klass, key).items(): - setattr(klass, name, func) - - delattr(klass, key) - - -class MockFd: - data = object() - source = object() - - def __init__(self, fd=0): - self.fd = fd - - def fileno(self): - return self.fd - - def recvfrom(self, size): - self.size = size - return (self.data, self.source) diff --git a/pyrad3/tests/testBidict.py b/pyrad3/tests/testBidict.py deleted file mode 100644 index e67fc0f..0000000 --- a/pyrad3/tests/testBidict.py +++ /dev/null @@ -1,56 +0,0 @@ -import operator -import unittest -from pyrad.bidict import BiDict - - -class BiDictTests(unittest.TestCase): - def setUp(self): - self.bidict = BiDict() - - def testStartEmpty(self): - self.assertEqual(len(self.bidict), 0) - self.assertEqual(len(self.bidict.forward), 0) - self.assertEqual(len(self.bidict.backward), 0) - - def testLength(self): - self.assertEqual(len(self.bidict), 0) - self.bidict.add("from", "to") - self.assertEqual(len(self.bidict), 1) - del self.bidict["from"] - self.assertEqual(len(self.bidict), 0) - - def testDeletion(self): - self.assertRaises(KeyError, operator.delitem, self.bidict, "missing") - self.bidict.add("missing", "present") - del self.bidict["missing"] - - def testBackwardDeletion(self): - self.assertRaises(KeyError, operator.delitem, self.bidict, "missing") - self.bidict.add("missing", "present") - del self.bidict["present"] - self.assertEqual(self.bidict.has_forward("missing"), False) - - def testForwardAccess(self): - self.bidict.add("shake", "vanilla") - self.bidict.add("pie", "custard") - self.assertEqual(self.bidict.has_forward("shake"), True) - self.assertEqual(self.bidict.get_forward("shake"), "vanilla") - self.assertEqual(self.bidict.has_forward("pie"), True) - self.assertEqual(self.bidict.get_forward("pie"), "custard") - self.assertEqual(self.bidict.has_forward("missing"), False) - self.assertRaises(KeyError, self.bidict.get_forward, "missing") - - def testBackwardAccess(self): - self.bidict.add("shake", "vanilla") - self.bidict.add("pie", "custard") - self.assertEqual(self.bidict.has_backward("vanilla"), True) - self.assertEqual(self.bidict.get_backward("vanilla"), "shake") - self.assertEqual(self.bidict.has_backward("missing"), False) - self.assertRaises(KeyError, self.bidict.get_backward, "missing") - - def testItemAccessor(self): - self.bidict.add("shake", "vanilla") - self.bidict.add("pie", "custard") - self.assertRaises(KeyError, operator.getitem, self.bidict, "missing") - self.assertEquals(self.bidict["shake"], "vanilla") - self.assertEquals(self.bidict["pie"], "custard") diff --git a/pyrad3/tests/testClient.py b/pyrad3/tests/testClient.py deleted file mode 100644 index 79ab6ba..0000000 --- a/pyrad3/tests/testClient.py +++ /dev/null @@ -1,184 +0,0 @@ -import select -import socket -import unittest -from pyrad.client import Client -from pyrad.client import Timeout -from pyrad.packet import AuthPacket -from pyrad.packet import AcctPacket -from pyrad.packet import AccessRequest -from pyrad.packet import AccountingRequest -from pyrad.tests.mock import MockPacket -from pyrad.tests.mock import MockPoll -from pyrad.tests.mock import MockSocket - -BIND_IP = "127.0.0.1" -BIND_PORT = 53535 - - -class ConstructionTests(unittest.TestCase): - def setUp(self): - self.server = object() - - def testSimpleConstruction(self): - client = Client(self.server) - self.failUnless(client.server is self.server) - self.assertEqual(client.authport, 1812) - self.assertEqual(client.acctport, 1813) - self.assertEqual(client.secret, b'') - self.assertEqual(client.retries, 3) - self.assertEqual(client.timeout, 5) - self.failUnless(client.dict is None) - - def testParameterOrder(self): - marker = object() - client = Client(self.server, 123, 456, 789, "secret", marker) - self.failUnless(client.server is self.server) - self.assertEqual(client.authport, 123) - self.assertEqual(client.acctport, 456) - self.assertEqual(client.coaport, 789) - self.assertEqual(client.secret, "secret") - self.failUnless(client.dict is marker) - - def testNamedParameters(self): - marker = object() - client = Client(server=self.server, authport=123, acctport=456, - secret="secret", dict=marker) - self.failUnless(client.server is self.server) - self.assertEqual(client.authport, 123) - self.assertEqual(client.acctport, 456) - self.assertEqual(client.secret, "secret") - self.failUnless(client.dict is marker) - - -class SocketTests(unittest.TestCase): - def setUp(self): - self.server = object() - self.client = Client(self.server) - self.orgsocket = socket.socket - socket.socket = MockSocket - - def tearDown(self): - socket.socket = self.orgsocket - - def testReopen(self): - self.client._socket_open() - sock = self.client._socket - self.client._socket_open() - self.failUnless(sock is self.client._socket) - - def testBind(self): - self.client.bind((BIND_IP, BIND_PORT)) - self.assertEqual(self.client._socket.address, (BIND_IP, BIND_PORT)) - self.assertEqual(self.client._socket.options, - [(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)]) - - def testBindClosesSocket(self): - s = MockSocket(socket.AF_INET, socket.SOCK_DGRAM) - self.client._socket = s - self.client._poll = MockPoll() - self.client.bind((BIND_IP, BIND_PORT)) - self.assertEqual(s.closed, True) - - def testSendPacket(self): - def MockSend(self, pkt, port): - self._mock_pkt = pkt - self._mock_port = port - - _send_packet = Client._send_packet - Client._send_packet = MockSend - - self.client.SendPacket(AuthPacket()) - self.assertEqual(self.client._mock_port, self.client.authport) - - self.client.SendPacket(AcctPacket()) - self.assertEqual(self.client._mock_port, self.client.acctport) - - Client._send_packet = _send_packet - - def testNoRetries(self): - self.client.retries = 0 - self.assertRaises(Timeout, self.client._send_packet, None, None) - - def testSingleRetry(self): - self.client.retries = 1 - self.client.timeout = 0 - packet = MockPacket(AccessRequest) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - self.assertEqual(self.client._socket.output, - [("request packet", (self.server, 432))]) - - def testDoubleRetry(self): - self.client.retries = 2 - self.client.timeout = 0 - packet = MockPacket(AccessRequest) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - self.assertEqual( - self.client._socket.output, - [("request packet", (self.server, 432)), - ("request packet", (self.server, 432))]) - - def testAuthDelay(self): - self.client.retries = 2 - self.client.timeout = 1 - packet = MockPacket(AccessRequest) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - self.failIf("Acct-Delay-Time" in packet) - - def testSingleAccountDelay(self): - self.client.retries = 2 - self.client.timeout = 1 - packet = MockPacket(AccountingRequest) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - self.assertEqual(packet["Acct-Delay-Time"], [1]) - - def testDoubleAccountDelay(self): - self.client.retries = 3 - self.client.timeout = 1 - packet = MockPacket(AccountingRequest) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - self.assertEqual(packet["Acct-Delay-Time"], [2]) - - def testIgnorePacketError(self): - self.client.retries = 1 - self.client.timeout = 1 - self.client._socket = MockSocket(1, 2, b"valid reply") - packet = MockPacket(AccountingRequest, verify=True, error=True) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - - def testValidReply(self): - self.client.retries = 1 - self.client.timeout = 1 - self.client._socket = MockSocket(1, 2, b"valid reply") - self.client._poll = MockPoll() - MockPoll.results = [(1, select.POLLIN)] - packet = MockPacket(AccountingRequest, verify=True) - reply = self.client._send_packet(packet, 432) - self.failUnless(reply is packet.reply) - - def testInvalidReply(self): - self.client.retries = 1 - self.client.timeout = 1 - self.client._socket = MockSocket(1, 2, b"invalid reply") - MockPoll.results = [(1, select.POLLIN)] - packet = MockPacket(AccountingRequest, verify=False) - self.assertRaises(Timeout, self.client._send_packet, packet, 432) - - -class OtherTests(unittest.TestCase): - def setUp(self): - self.server = object() - self.client = Client(self.server, secret=b'zeer geheim') - - def testCreateAuthPacket(self): - packet = self.client.CreateAuthPacket(id=15) - self.failUnless(isinstance(packet, AuthPacket)) - self.failUnless(packet.dict is self.client.dict) - self.assertEqual(packet.id, 15) - self.assertEqual(packet.secret, b'zeer geheim') - - def testCreateAcctPacket(self): - packet = self.client.CreateAcctPacket(id=15) - self.failUnless(isinstance(packet, AcctPacket)) - self.failUnless(packet.dict is self.client.dict) - self.assertEqual(packet.id, 15) - self.assertEqual(packet.secret, b'zeer geheim') diff --git a/pyrad3/tests/testDictionary.py b/pyrad3/tests/testDictionary.py deleted file mode 100644 index 1929798..0000000 --- a/pyrad3/tests/testDictionary.py +++ /dev/null @@ -1,332 +0,0 @@ -import unittest -import operator -import os -from io import StringIO -from pyrad.tests import home -from pyrad.dictionary import Attribute -from pyrad.dictionary import Dictionary -from pyrad.dictionary import ParseError -from pyrad.tools import DecodeAttr -from pyrad.dictfile import DictFile - - -class AttributeTests(unittest.TestCase): - def testInvalidDataType(self): - self.assertRaises(ValueError, Attribute, 'name', 'code', 'datatype') - - def testConstructionParameters(self): - attr = Attribute('name', 'code', 'integer', False, 'vendor') - self.assertEqual(attr.name, 'name') - self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') - self.assertEqual(attr.is_sub_attribute, False) - self.assertEqual(attr.vendor, 'vendor') - self.assertEqual(len(attr.values), 0) - self.assertEqual(len(attr.sub_attributes), 0) - - def testNamedConstructionParameters(self): - attr = Attribute(name='name', code='code', datatype='integer', - vendor='vendor') - self.assertEqual(attr.name, 'name') - self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') - self.assertEqual(attr.vendor, 'vendor') - self.assertEqual(len(attr.values), 0) - - def testValues(self): - attr = Attribute('name', 'code', 'integer', False, 'vendor', - dict(pie='custard', shake='vanilla')) - self.assertEqual(len(attr.values), 2) - self.assertEqual(attr.values['shake'], 'vanilla') - - -class DictionaryInterfaceTests(unittest.TestCase): - def testEmptyDictionary(self): - dict = Dictionary() - self.assertEqual(len(dict), 0) - - def testContainment(self): - dict = Dictionary() - self.assertEqual('test' in dict, False) - dict.attributes['test'] = 'dummy' - self.assertEqual('test' in dict, True) - - def testReadonlyContainer(self): - dict = Dictionary() - self.assertRaises(TypeError, - operator.setitem, dict, 'test', 'dummy') - self.assertRaises(AttributeError, - operator.attrgetter('clear'), dict) - self.assertRaises(AttributeError, - operator.attrgetter('update'), dict) - - -class DictionaryParsingTests(unittest.TestCase): - - simple_dict_values = [ - ('Test-String', 1, 'string'), - ('Test-Octets', 2, 'octets'), - ('Test-Integer', 0x03, 'integer'), - ('Test-Ip-Address', 4, 'ipaddr'), - ('Test-Ipv6-Address', 5, 'ipv6addr'), - ('Test-If-Id', 6, 'ifid'), - ('Test-Date', 7, 'date'), - ('Test-Abinary', 8, 'abinary'), - ('Test-Tlv', 9, 'tlv'), - ('Test-Tlv-Str', 1, 'string'), - ('Test-Tlv-Int', 2, 'integer'), - ('Test-Integer64', 10, 'integer64'), - ('Test-Integer64-Hex', 10, 'integer64'), - ('Test-Integer64-Oct', 10, 'integer64'), - ] - - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'simple')) - - def testParseEmptyDictionary(self): - dict = Dictionary(StringIO('')) - self.assertEqual(len(dict), 0) - - def testParseMultipleDictionaries(self): - dict = Dictionary(StringIO('')) - self.assertEqual(len(dict), 0) - one = StringIO('ATTRIBUTE Test-First 1 string') - two = StringIO('ATTRIBUTE Test-Second 2 string') - dict = Dictionary(StringIO(''), one, two) - self.assertEqual(len(dict), 2) - - def testParseSimpleDictionary(self): - self.assertEqual(len(self.dict), len(self.simple_dict_values)) - for (attr, code, type) in self.simple_dict_values: - attr = self.dict[attr] - self.assertEqual(attr.code, code) - self.assertEqual(attr.type, type) - - def testAttributeTooFewColumnsError(self): - try: - self.dict.read_dictionary( - StringIO('ATTRIBUTE Oops-Too-Few-Columns')) - except ParseError as e: - self.assertEqual('attribute' in str(e), True) - else: - self.fail() - - def testAttributeUnknownTypeError(self): - try: - self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 dummy')) - except ParseError as e: - self.assertEqual('dummy' in str(e), True) - else: - self.fail() - - def testAttributeUnknownVendorError(self): - try: - self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 Simplon')) - except ParseError as e: - self.assertEqual('Simplon' in str(e), True) - else: - self.fail() - - def testAttributeOptions(self): - self.dict.read_dictionary(StringIO( - 'ATTRIBUTE Option-Type 1 string has_tag,encrypt=1')) - self.assertEqual(self.dict['Option-Type'].has_tag, True) - self.assertEqual(self.dict['Option-Type'].encrypt, 1) - - def testAttributeEncryptionError(self): - try: - self.dict.read_dictionary(StringIO( - 'ATTRIBUTE Test-Type 1 string encrypt=4')) - except ParseError as e: - self.assertEqual('encrypt' in str(e), True) - else: - self.fail() - - def testValueTooFewColumnsError(self): - try: - self.dict.read_dictionary(StringIO('VALUE Oops-Too-Few-Columns')) - except ParseError as e: - self.assertEqual('value' in str(e), True) - else: - self.fail() - - def testValueForUnknownAttributeError(self): - try: - self.dict.read_dictionary(StringIO( - 'VALUE Test-Attribute Test-Text 1')) - except ParseError as e: - self.assertEqual('unknown attribute' in str(e), True) - else: - self.fail() - - def testIntegerValueParsing(self): - self.assertEqual(len(self.dict['Test-Integer'].values), 0) - self.dict.read_dictionary(StringIO('VALUE Test-Integer Value-Six 5')) - self.assertEqual(len(self.dict['Test-Integer'].values), 1) - self.assertEqual( - DecodeAttr('integer', self.dict['Test-Integer'].values['Value-Six']), - 5) - - def testInteger64ValueParsing(self): - self.assertEqual(len(self.dict['Test-Integer64'].values), 0) - self.dict.read_dictionary(StringIO('VALUE Test-Integer64 Value-Six 5')) - self.assertEqual(len(self.dict['Test-Integer64'].values), 1) - self.assertEqual( - DecodeAttr('integer64', self.dict['Test-Integer64'].values['Value-Six']), - 5) - - def testStringValueParsing(self): - self.assertEqual(len(self.dict['Test-String'].values), 0) - self.dict.read_dictionary(StringIO('VALUE Test-String Value-Custard custardpie')) - self.assertEqual(len(self.dict['Test-String'].values), 1) - self.assertEqual( - DecodeAttr('string', self.dict['Test-String'].values['Value-Custard']), - 'custardpie') - - def testTlvParsing(self): - self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2) - self.assertEqual(self.dict['Test-Tlv'].sub_attributes, - {1: 'Test-Tlv-Str', 2: 'Test-Tlv-Int'}) - - def testSubTlvParsing(self): - for (attr, _, _) in self.simple_dict_values: - if attr.startswith('Test-Tlv-'): - self.assertEqual(self.dict[attr].is_sub_attribute, True) - self.assertEqual(self.dict[attr].parent, self.dict['Test-Tlv']) - else: - self.assertEqual(self.dict[attr].is_sub_attribute, False) - self.assertEqual(self.dict[attr].parent, None) - - # tlv with vendor - full_dict = Dictionary(os.path.join(self.path, 'full')) - self.assertEqual(full_dict['Simplon-Tlv-Str'].is_sub_attribute, True) - self.assertEqual(full_dict['Simplon-Tlv-Str'].parent, full_dict['Simplon-Tlv']) - self.assertEqual(full_dict['Simplon-Tlv-Int'].is_sub_attribute, True) - self.assertEqual(full_dict['Simplon-Tlv-Int'].parent, full_dict['Simplon-Tlv']) - - def testVenderTooFewColumnsError(self): - try: - self.dict.read_dictionary(StringIO('VENDOR Simplon')) - except ParseError as e: - self.assertEqual('vendor' in str(e), True) - else: - self.fail() - - def testVendorParsing(self): - self.assertRaises(ParseError, self.dict.read_dictionary, - StringIO('ATTRIBUTE Test-Type 1 integer Simplon')) - self.dict.read_dictionary(StringIO('VENDOR Simplon 42')) - self.assertEqual(self.dict.vendors['Simplon'], 42) - self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 integer Simplon')) - self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1)) - - def testVendorOptionError(self): - self.assertRaises(ParseError, self.dict.read_dictionary, - StringIO('ATTRIBUTE Test-Type 1 integer Simplon')) - try: - self.dict.read_dictionary(StringIO('VENDOR Simplon 42 badoption')) - except ParseError as e: - self.assertEqual('option' in str(e), True) - else: - self.fail() - - def testVendorFormatError(self): - self.assertRaises(ParseError, self.dict.read_dictionary, - StringIO('ATTRIBUTE Test-Type 1 integer Simplon')) - try: - self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=5,4')) - except ParseError as e: - self.assertEqual('format' in str(e), True) - else: - self.fail() - - def testVendorFormatSyntaxError(self): - self.assertRaises(ParseError, self.dict.read_dictionary, - StringIO('ATTRIBUTE Test-Type 1 integer Simplon')) - try: - self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=a,1')) - except ParseError as e: - self.assertEqual('Syntax' in str(e), True) - else: - self.fail() - - def testBeginVendorTooFewColumns(self): - try: - self.dict.read_dictionary(StringIO('BEGIN-VENDOR')) - except ParseError as e: - self.assertEqual('begin-vendor' in str(e), True) - else: - self.fail() - - def testBeginVendorUnknownVendor(self): - try: - self.dict.read_dictionary(StringIO('BEGIN-VENDOR Simplon')) - except ParseError as e: - self.assertEqual('Simplon' in str(e), True) - else: - self.fail() - - def testBeginVendorParsing(self): - self.dict.read_dictionary(StringIO( - 'VENDOR Simplon 42\n' - 'BEGIN-VENDOR Simplon\n' - 'ATTRIBUTE Test-Type 1 integer')) - self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1)) - - def testEndVendorUnknownVendor(self): - try: - self.dict.read_dictionary(StringIO('END-VENDOR')) - except ParseError as e: - self.assertEqual('end-vendor' in str(e), True) - else: - self.fail() - - def testEndVendorUnbalanced(self): - try: - self.dict.read_dictionary(StringIO( - 'VENDOR Simplon 42\n' - 'BEGIN-VENDOR Simplon\n' - 'END-VENDOR Oops\n')) - except ParseError as e: - self.assertEqual('Oops' in str(e), True) - else: - self.fail() - - def testEndVendorParsing(self): - self.dict.read_dictionary(StringIO( - 'VENDOR Simplon 42\n' - 'BEGIN-VENDOR Simplon\n' - 'END-VENDOR Simplon\n' - 'ATTRIBUTE Test-Type 1 integer')) - self.assertEqual(self.dict.attrindex['Test-Type'], 1) - - def testInclude(self): - try: - self.dict.read_dictionary(StringIO( - '$INCLUDE this_file_does_not_exist\n' - 'VENDOR Simplon 42\n' - 'BEGIN-VENDOR Simplon\n' - 'END-VENDOR Simplon\n' - 'ATTRIBUTE Test-Type 1 integer')) - except IOError as e: - self.assertEqual('this_file_does_not_exist' in str(e), True) - else: - self.fail() - - def testDictFilePostParse(self): - f = DictFile(StringIO( - 'VENDOR Simplon 42\n')) - for _ in f: - pass - self.assertEqual(f.file(), '') - self.assertEqual(f.line(), -1) - - def testDictFileParseError(self): - tmpdict = Dictionary() - try: - tmpdict.read_dictionary(os.path.join(self.path, 'dictfiletest')) - except ParseError as e: - self.assertEqual('dictfiletest' in str(e), True) - else: - self.fail() diff --git a/pyrad3/tests/testHost.py b/pyrad3/tests/testHost.py deleted file mode 100644 index ec51deb..0000000 --- a/pyrad3/tests/testHost.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest -from pyrad.host import Host -from pyrad.packet import Packet -from pyrad.packet import AuthPacket -from pyrad.packet import AcctPacket - - -class ConstructionTests(unittest.TestCase): - def testSimpleConstruction(self): - host = Host() - self.assertEqual(host.authport, 1812) - self.assertEqual(host.acctport, 1813) - - def testParameterOrder(self): - host = Host(123, 456, 789, 101) - self.assertEqual(host.authport, 123) - self.assertEqual(host.acctport, 456) - self.assertEqual(host.coaport, 789) - self.assertEqual(host.dict, 101) - - def testNamedParameters(self): - host = Host(authport=123, acctport=456, coaport=789, dict=101) - self.assertEqual(host.authport, 123) - self.assertEqual(host.acctport, 456) - self.assertEqual(host.coaport, 789) - self.assertEqual(host.dict, 101) - - -class PacketCreationTests(unittest.TestCase): - def setUp(self): - self.host = Host() - - def testCreatePacket(self): - packet = self.host.CreatePacket(id=15) - self.failUnless(isinstance(packet, Packet)) - self.failUnless(packet.dict is self.host.dict) - self.assertEqual(packet.id, 15) - - def testCreateAuthPacket(self): - packet = self.host.CreateAuthPacket(id=15) - self.failUnless(isinstance(packet, AuthPacket)) - self.failUnless(packet.dict is self.host.dict) - self.assertEqual(packet.id, 15) - - def testCreateAcctPacket(self): - packet = self.host.CreateAcctPacket(id=15) - self.failUnless(isinstance(packet, AcctPacket)) - self.failUnless(packet.dict is self.host.dict) - self.assertEqual(packet.id, 15) - - -class MockPacket: - packet = object() - replypacket = object() - source = object() - - def Packet(self): - return self.packet - - def ReplyPacket(self): - return self.replypacket - - -class MockFd: - data = None - target = None - - def sendto(self, data, target): - self.data = data - self.target = target - - -class PacketSendTest(unittest.TestCase): - def setUp(self): - self.host = Host() - self.fd = MockFd() - self.packet = MockPacket() - - def testSendPacket(self): - self.host.SendPacket(self.fd, self.packet) - self.failUnless(self.fd.data is self.packet.packet) - self.failUnless(self.fd.target is self.packet.source) - - def testSendReplyPacket(self): - self.host.SendReplyPacket(self.fd, self.packet) - self.failUnless(self.fd.data is self.packet.replypacket) - self.failUnless(self.fd.target is self.packet.source) diff --git a/pyrad3/tests/testPacket.py b/pyrad3/tests/testPacket.py deleted file mode 100644 index 630175c..0000000 --- a/pyrad3/tests/testPacket.py +++ /dev/null @@ -1,535 +0,0 @@ -import os -import unittest - -from collections import OrderedDict -from pyrad import packet -from pyrad.client import Client -from pyrad.tests import home -from pyrad.dictionary import Dictionary - -import hashlib -md5_constructor = hashlib.md5 - - -class UtilityTests(unittest.TestCase): - def testGenerateID(self): - id = packet.CreateID() - self.assertTrue(isinstance(id, int)) - newid = packet.CreateID() - self.assertNotEqual(id, newid) - - -class PacketConstructionTests(unittest.TestCase): - klass = packet.Packet - - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'simple')) - - def testBasicConstructor(self): - pkt = self.klass() - self.assertTrue(isinstance(pkt.code, int)) - self.assertTrue(isinstance(pkt.id, int)) - self.assertTrue(isinstance(pkt.secret, bytes)) - - def testNamedConstructor(self): - pkt = self.klass(code=26, id=38, secret=b'secret', - authenticator=b'authenticator', - dict='fakedict') - self.assertEqual(pkt.code, 26) - self.assertEqual(pkt.id, 38) - self.assertEqual(pkt.secret, b'secret') - self.assertEqual(pkt.authenticator, b'authenticator') - self.assertEqual(pkt.dict, 'fakedict') - - def testConstructWithDictionary(self): - pkt = self.klass(dict=self.dict) - self.assertTrue(pkt.dict is self.dict) - - def testConstructorIgnoredParameters(self): - marker = [] - pkt = self.klass(fd=marker) - self.assertFalse(getattr(pkt, 'fd', None) is marker) - - def testSecretMustBeBytestring(self): - self.assertRaises(TypeError, self.klass, secret='secret') - - def testConstructorWithAttributes(self): - pkt = self.klass(**{'Test-String': 'this works', 'dict': self.dict}) - self.assertEqual(pkt['Test-String'], ['this works']) - - def testConstructorWithTlvAttribute(self): - pkt = self.klass(**{ - 'Test-Tlv-Str': 'this works', - 'Test-Tlv-Int': 10, - 'dict': self.dict - }) - self.assertEqual( - pkt['Test-Tlv'], - {'Test-Tlv-Str': ['this works'], 'Test-Tlv-Int': [10]} - ) - - -class PacketTests(unittest.TestCase): - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'full')) - self.packet = packet.Packet( - id=0, secret=b'secret', - authenticator=b'01234567890ABCDEF', dict=self.dict) - - def testCreateReply(self): - reply = self.packet.CreateReply(**{'Test-Integer': 10}) - self.assertEqual(reply.id, self.packet.id) - self.assertEqual(reply.secret, self.packet.secret) - self.assertEqual(reply.authenticator, self.packet.authenticator) - self.assertEqual(reply['Test-Integer'], [10]) - - def testAttributeAccess(self): - self.packet['Test-Integer'] = 10 - self.assertEqual(self.packet['Test-Integer'], [10]) - self.assertEqual(self.packet[3], [b'\x00\x00\x00\x0a']) - - self.packet['Test-String'] = 'dummy' - self.assertEqual(self.packet['Test-String'], ['dummy']) - self.assertEqual(self.packet[1], [b'dummy']) - - def testAttributeValueAccess(self): - self.packet['Test-Integer'] = 'Three' - self.assertEqual(self.packet['Test-Integer'], ['Three']) - self.assertEqual(self.packet[3], [b'\x00\x00\x00\x03']) - - def testVendorAttributeAccess(self): - self.packet['Simplon-Number'] = 10 - self.assertEqual(self.packet['Simplon-Number'], [10]) - self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x0a']) - - self.packet['Simplon-Number'] = 'Four' - self.assertEqual(self.packet['Simplon-Number'], ['Four']) - self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x04']) - - def testRawAttributeAccess(self): - marker = [b''] - self.packet[1] = marker - self.assertTrue(self.packet[1] is marker) - self.packet[(16, 1)] = marker - self.assertTrue(self.packet[(16, 1)] is marker) - - def testHasKey(self): - self.assertEqual('Test-String' in self.packet, False) - self.packet['Test-String'] = 'dummy' - self.assertEqual('Test-String' in self.packet, True) - self.assertEqual(1 in self.packet, True) - - def testHasKeyWithUnknownKey(self): - self.assertEqual('Unknown-Attribute' in self.packet, False) - - def testDelItem(self): - self.packet['Test-String'] = 'dummy' - del self.packet['Test-String'] - self.assertEqual('Test-String' in self.packet, False) - self.packet['Test-String'] = 'dummy' - del self.packet[1] - self.assertEqual('Test-String' in self.packet, False) - - def testKeys(self): - self.assertEqual(self.packet.keys(), []) - self.packet['Test-String'] = 'dummy' - self.assertEqual(self.packet.keys(), ['Test-String']) - self.packet['Test-Integer'] = 10 - self.assertEqual(self.packet.keys(), ['Test-String', 'Test-Integer']) - OrderedDict.__setitem__(self.packet, 12345, None) - self.assertEqual(self.packet.keys(), - ['Test-String', 'Test-Integer', 12345]) - - def testCreateAuthenticator(self): - a = packet.Packet.CreateAuthenticator() - self.assertTrue(isinstance(a, bytes)) - self.assertEqual(len(a), 16) - - b = packet.Packet.CreateAuthenticator() - self.assertNotEqual(a, b) - - def testGenerateID(self): - id = self.packet.CreateID() - self.assertTrue(isinstance(id, int)) - newid = self.packet.CreateID() - self.assertNotEqual(id, newid) - - def testReplyPacket(self): - reply = self.packet.ReplyPacket() - self.assertEqual( - reply, - (b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c' - b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7')) - - def testVerifyReply(self): - reply = self.packet.CreateReply() - reply.id += 1 - with self.assertRaises(packet.PacketError): - self.packet.VerifyReply(reply.ReplyPacket()) - reply.id = self.packet.id - - reply.secret = b'different' - with self.assertRaises(packet.PacketError): - self.packet.VerifyReply(reply.ReplyPacket()) - reply.secret = self.packet.secret - - reply.authenticator = b'X' * 16 - with self.assertRaises(packet.PacketError): - self.packet.VerifyReply(reply.ReplyPacket()) - - def testPktEncodeAttribute(self): - encode = self.packet._pkt_encode_attribute - - # Encode a normal attribute - self.assertEqual( - encode(1, b'value'), - b'\x01\x07value') - # Encode a vendor attribute - self.assertEqual( - encode((1, 2), b'value'), - b'\x1a\x0d\x00\x00\x00\x01\x02\x07value') - - def testPktEncodeTlvAttribute(self): - encode = self.packet._pkt_encode_tlv - - # Encode a normal tlv attribute - self.assertEqual( - encode(4, {1: [b'value'], 2: [b'\x00\x00\x00\x02']}), - b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02') - - # Encode a normal tlv attribute with several sub attribute instances - self.assertEqual( - encode(4, {1: [b'value', b'other'], 2: [b'\x00\x00\x00\x02']}), - b'\x04\x16\x01\x07value\x02\x06\x00\x00\x00\x02\x01\x07other') - # Encode a vendor tlv attribute - self.assertEqual( - encode((16, 3), {1: [b'value'], 2: [b'\x00\x00\x00\x02']}), - b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02') - - def testPktEncodeLongTlvAttribute(self): - encode = self.packet._pkt_encode_tlv - - long_str = b'a' * 245 - # Encode a long tlv attribute - check it is split between AVPs - self.assertEqual( - encode(4, {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}), - b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02\x04\xf9\x01\xf7' + long_str) - - # Encode a long vendor tlv attribute - first_avp = b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02' - second_avp = b'\x1a\xff\x00\x00\x00\x10\x03\xf9\x01\xf7' + long_str - self.assertEqual( - encode((16, 3), {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}), - first_avp + second_avp) - - def testpkt_encode_attributes(self): - self.packet[1] = [b'value'] - self.assertEqual(self.packet._pkt_encode_attributes(), - b'\x01\x07value') - - self.packet.clear() - self.packet[(16, 2)] = [b'value'] - self.assertEqual(self.packet._pkt_encode_attributes(), - b'\x1a\x0d\x00\x00\x00\x10\x02\x07value') - - self.packet.clear() - self.packet[1] = [b'one', b'two', b'three'] - self.assertEqual(self.packet._pkt_encode_attributes(), - b'\x01\x05one\x01\x05two\x01\x07three') - - self.packet.clear() - self.packet[1] = [b'value'] - self.packet[(16, 2)] = [b'value'] - self.assertEqual(self.packet._pkt_encode_attributes(), - b'\x01\x07value\x1a\x0d\x00\x00\x00\x10\x02\x07value') - - def testPktDecodeVendorAttribute(self): - decode = self.packet._pkt_decode_vendor_attribute - - # Non-RFC2865 recommended form - self.assertEqual(decode(b''), [(26, b'')]) - self.assertEqual(decode(b'12345'), [(26, b'12345')]) - - # Almost RFC2865 recommended form: bad length value - self.assertEqual(decode(b'\x00\x00\x00\x01\x02\x06value'), - [(26, b'\x00\x00\x00\x01\x02\x06value')]) - - # Proper RFC2865 recommended form - self.assertEqual(decode(b'\x00\x00\x00\x10\x02\x07value'), - [((16, 2), b'value')]) - - def testPktDecodeTlvAttribute(self): - decode = self.packet._pkt_decode_tlv_attribute - - decode(4, b'\x01\x07value') - self.assertEqual(self.packet[4], {1: [b'value']}) - - # add another instance of the same sub attribute - decode(4, b'\x01\x07other') - self.assertEqual(self.packet[4], {1: [b'value', b'other']}) - - # add a different sub attribute - decode(4, b'\x02\x07\x00\x00\x00\x01') - self.assertEqual(self.packet[4], { - 1: [b'value', b'other'], - 2: [b'\x00\x00\x00\x01'] - }) - - def testDecodePacketWithEmptyPacket(self): - try: - self.packet.DecodePacket(b'') - except packet.PacketError as e: - self.assertTrue('header is corrupt' in str(e)) - else: - self.fail() - - def testDecodePacketWithInvalidLength(self): - try: - self.packet.DecodePacket(b'\x00\x00\x00\x001234567890123456') - except packet.PacketError as e: - self.assertTrue('invalid length' in str(e)) - else: - self.fail() - - def testDecodePacketWithTooBigPacket(self): - try: - self.packet.DecodePacket(b'\x00\x00\x24\x00' + (0x2400 - 4) * b'X') - except packet.PacketError as e: - self.assertTrue('too long' in str(e)) - else: - self.fail() - - def testDecodePacketWithPartialAttributes(self): - try: - self.packet.DecodePacket( - b'\x01\x02\x00\x151234567890123456\x00') - except packet.PacketError as e: - self.assertTrue('header is corrupt' in str(e)) - else: - self.fail() - - def testDecodePacketWithoutAttributes(self): - self.packet.DecodePacket(b'\x01\x02\x00\x141234567890123456') - self.assertEqual(self.packet.code, 1) - self.assertEqual(self.packet.id, 2) - self.assertEqual(self.packet.authenticator, b'1234567890123456') - self.assertEqual(self.packet.keys(), []) - - def testDecodePacketWithBadAttribute(self): - try: - self.packet.DecodePacket( - b'\x01\x02\x00\x161234567890123456\x00\x01') - except packet.PacketError as e: - self.assertTrue('too small' in str(e)) - else: - self.fail() - - def testDecodePacketWithEmptyAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x161234567890123456\x01\x02') - self.assertEqual(self.packet[1], [b'']) - - def testDecodePacketWithAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x1b1234567890123456\x01\x07value') - self.assertEqual(self.packet[1], [b'value']) - - def testDecodePacketWithTlvAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x1d1234567890123456\x04\x09\x01\x07value') - self.assertEqual(self.packet[4], {1: [b'value']}) - - def testDecodePacketWithVendorTlvAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x231234567890123456\x1a\x0f\x00\x00\x00\x10\x03\x09\x01\x07value') - self.assertEqual(self.packet[(16, 3)], {1: [b'value']}) - - def testDecodePacketWithTlvAttributeWith2SubAttributes(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x231234567890123456\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x09') - self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']}) - - def testDecodePacketWithSplitTlvAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x09\x02\x06\x00\x00\x00\x09') - self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']}) - - def testDecodePacketWithMultiValuedAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two') - self.assertEqual(self.packet[1], [b'one', b'two']) - - def testDecodePacketWithTwoAttributes(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two') - self.assertEqual(self.packet[1], [b'one', b'two']) - - def testDecodePacketWithVendorAttribute(self): - self.packet.DecodePacket( - b'\x01\x02\x00\x1b1234567890123456\x1a\x07value') - self.assertEqual(self.packet[26], [b'value']) - - def testEncodeKeyValues(self): - self.assertEqual(self.packet._encode_key_values(1, '1234'), (1, '1234')) - - def testEncodeKey(self): - self.assertEqual(self.packet._encode_key(1), 1) - - def testAddAttribute(self): - self.packet.AddAttribute('Test-String', '1') - self.assertEqual(self.packet['Test-String'], ['1']) - self.packet.AddAttribute('Test-String', '1') - self.assertEqual(self.packet['Test-String'], ['1', '1']) - self.packet.AddAttribute('Test-String', ['2', '3']) - self.assertEqual(self.packet['Test-String'], ['1', '1', '2', '3']) - - -class AuthPacketConstructionTests(PacketConstructionTests): - klass = packet.AuthPacket - - def testConstructorDefaults(self): - pkt = self.klass() - self.assertEqual(pkt.code, packet.AccessRequest) - - -class AuthPacketTests(unittest.TestCase): - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'full')) - self.packet = packet.AuthPacket( - id=0, secret=b'secret', - authenticator=b'01234567890ABCDEF', dict=self.dict) - - def testCreateReply(self): - reply = self.packet.CreateReply(**{'Test-Integer': 10}) - self.assertEqual(reply.code, packet.AccessAccept) - self.assertEqual(reply.id, self.packet.id) - self.assertEqual(reply.secret, self.packet.secret) - self.assertEqual(reply.authenticator, self.packet.authenticator) - self.assertEqual(reply['Test-Integer'], [10]) - - def testRequestPacket(self): - self.assertEqual(self.packet.RequestPacket(), - b'\x01\x00\x00\x1401234567890ABCDE') - - def testRequestPacketCreatesAuthenticator(self): - self.packet.authenticator = None - self.packet.RequestPacket() - self.assertTrue(self.packet.authenticator is not None) - - def testRequestPacketCreatesID(self): - self.packet.id = None - self.packet.RequestPacket() - self.assertTrue(self.packet.id is not None) - - def testPwCryptEmptyPassword(self): - self.assertEqual(self.packet.PwCrypt(''), b'') - - def testPwCryptPassword(self): - self.assertEqual(self.packet.PwCrypt('Simplon'), - b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01') - - def testPwCryptSetsAuthenticator(self): - self.packet.authenticator = None - self.packet.PwCrypt('') - self.assertTrue(self.packet.authenticator is not None) - - def testPwDecryptEmptyPassword(self): - self.assertEqual(self.packet.PwDecrypt(b''), '') - - def testPwDecryptPassword(self): - self.assertEqual( - self.packet.PwDecrypt(b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01'), - 'Simplon') - - -class AuthPacketChapTests(unittest.TestCase): - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'chap')) - self.client = Client(server='localhost', secret=b'secret', - dict=self.dict) - - def testVerifyChapPasswd(self): - chap_id = b'9' - chap_challenge = b'987654321' - chap_password = chap_id + md5_constructor( - chap_id + b'test_password' + chap_challenge).digest() - pkt = self.client.CreateAuthPacket( - code=packet.AccessChallenge, - authenticator=b'ABCDEFG', - User_Name='test_name', - CHAP_Challenge=chap_challenge, - CHAP_Password=chap_password - ) - self.assertEqual(pkt['CHAP-Challenge'][0], chap_challenge) - self.assertEqual(pkt['CHAP-Password'][0], chap_password) - self.assertEqual(pkt.VerifyChapPasswd('test_password'), True) - - -class AuthPacketSaltTests(unittest.TestCase): - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'tunnelPassword')) - self.packet = packet.Packet(id=0, secret=b'secret', - dict=self.dict) - - def testSaltCrypt(self): - self.packet['Tunnel-Password:1'] = 'test' - # TODO: need to get a correct reference values - # self.assertEqual(self.packet['Tunnel-Password'], b'') - - -class AcctPacketConstructionTests(PacketConstructionTests): - klass = packet.AcctPacket - - def testConstructorDefaults(self): - pkt = self.klass() - self.assertEqual(pkt.code, packet.AccountingRequest) - - def testConstructorRawPacket(self): - raw = (b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c' - b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7') - pkt = self.klass(packet=raw) - self.assertEqual(pkt.raw_packet, raw) - - -class AcctPacketTests(unittest.TestCase): - def setUp(self): - self.path = os.path.join(home, 'tests', 'data') - self.dict = Dictionary(os.path.join(self.path, 'full')) - self.packet = packet.AcctPacket( - id=0, secret=b'secret', - authenticator=b'01234567890ABCDEF', dict=self.dict) - - def testCreateReply(self): - reply = self.packet.CreateReply(**{'Test-Integer': 10}) - self.assertEqual(reply.code, packet.AccountingResponse) - self.assertEqual(reply.id, self.packet.id) - self.assertEqual(reply.secret, self.packet.secret) - self.assertEqual(reply.authenticator, self.packet.authenticator) - self.assertEqual(reply['Test-Integer'], [10]) - - def testVerifyAcctRequest(self): - rawpacket = self.packet.RequestPacket() - pkt = packet.AcctPacket(secret=b'secret', packet=rawpacket) - self.assertEqual(pkt.VerifyAcctRequest(), True) - - pkt.secret = b'different' - self.assertEqual(pkt.VerifyAcctRequest(), False) - pkt.secret = b'secret' - - pkt.raw_packet = b'X' + pkt.raw_packet[1:] - self.assertEqual(pkt.VerifyAcctRequest(), False) - - def testRequestPacket(self): - self.assertEqual( - self.packet.RequestPacket(), - b'\x04\x00\x00\x14\x95\xdf\x90\xccbn\xfb\x15G!\x13\xea\xfa>6\x0f') - - def testRequestPacketSetsId(self): - self.packet.id = None - self.packet.RequestPacket() - self.assertTrue(self.packet.id is not None) diff --git a/pyrad3/tests/testProxy.py b/pyrad3/tests/testProxy.py deleted file mode 100644 index ff63ee0..0000000 --- a/pyrad3/tests/testProxy.py +++ /dev/null @@ -1,99 +0,0 @@ -import select -import socket -import unittest -from pyrad.proxy import Proxy -from pyrad.packet import AccessAccept -from pyrad.packet import AccessRequest -from pyrad.server import ServerPacketError -from pyrad.server import Server -from pyrad.tests.mock import MockFd -from pyrad.tests.mock import MockPoll -from pyrad.tests.mock import MockSocket -from pyrad.tests.mock import MockClassMethod -from pyrad.tests.mock import UnmockClassMethods - - -class TrivialObject: - """dummy object""" - - -class SocketTests(unittest.TestCase): - def setUp(self): - self.orgsocket = socket.socket - socket.socket = MockSocket - self.proxy = Proxy() - self.proxy._fdmap = {} - - def tearDown(self): - socket.socket = self.orgsocket - - def testProxyFd(self): - self.proxy._poll = MockPoll() - self.proxy._prepare_sockets() - self.failUnless(isinstance(self.proxy._proxyfd, MockSocket)) - self.assertEqual(list(self.proxy._fdmap.keys()), [1]) - self.assertEqual( - self.proxy._poll.registry, - {1: select.POLLIN | select.POLLPRI | select.POLLERR}) - - -class ProxyPacketHandlingTests(unittest.TestCase): - def setUp(self): - self.proxy = Proxy() - self.proxy.hosts['host'] = TrivialObject() - self.proxy.hosts['host'].secret = 'supersecret' - self.packet = TrivialObject() - self.packet.code = AccessAccept - self.packet.source = ('host', 'port') - - def testHandleProxyPacketUnknownHost(self): - self.packet.source = ('stranger', 'port') - try: - self.proxy._handle_proxy_packet(self.packet) - except ServerPacketError as e: - self.failUnless('unknown host' in str(e)) - else: - self.fail() - - def testHandleProxyPacketSetsSecret(self): - self.proxy._handle_proxy_packet(self.packet) - self.assertEqual(self.packet.secret, 'supersecret') - - def testHandleProxyPacketHandlesWrongPacket(self): - self.packet.code = AccessRequest - try: - self.proxy._handle_proxy_packet(self.packet) - except ServerPacketError as e: - self.failUnless('non-response' in str(e)) - else: - self.fail() - - -class OtherTests(unittest.TestCase): - def setUp(self): - self.proxy = Proxy() - self.proxy._proxyfd = MockFd() - - def tearDown(self): - UnmockClassMethods(Proxy) - UnmockClassMethods(Server) - - def testProcessInputNonProxyPort(self): - fd = MockFd(fd=111) - MockClassMethod(Server, '_process_input') - self.proxy._process_input(fd) - self.assertEqual( - self.proxy.called, - [('_process_input', (fd,), {})]) - - def testProcessInput(self): - MockClassMethod(Proxy, '_grab_packet') - MockClassMethod(Proxy, '_handle_proxy_packet') - self.proxy._process_input(self.proxy._proxyfd) - self.assertEqual( - [x[0] for x in self.proxy.called], - ['_grab_packet', '_handle_proxy_packet']) - - -if not hasattr(select, 'poll'): - del SocketTests diff --git a/pyrad3/tests/testServer.py b/pyrad3/tests/testServer.py deleted file mode 100644 index 065b705..0000000 --- a/pyrad3/tests/testServer.py +++ /dev/null @@ -1,329 +0,0 @@ -import select -import socket -import unittest -from pyrad.packet import PacketError -from pyrad.server import RemoteHost -from pyrad.server import Server -from pyrad.server import ServerPacketError -from pyrad.tests.mock import MockFinished -from pyrad.tests.mock import MockFd -from pyrad.tests.mock import MockPoll -from pyrad.tests.mock import MockSocket -from pyrad.tests.mock import MockClassMethod -from pyrad.tests.mock import UnmockClassMethods -from pyrad.packet import AccessRequest -from pyrad.packet import AccountingRequest - - -class TrivialObject: - """dummy objec""" - - -class RemoteHostTests(unittest.TestCase): - def testSimpleConstruction(self): - host = RemoteHost('address', 'secret', 'name', 'authport', 'acctport', 'coaport') - self.assertEqual(host.address, 'address') - self.assertEqual(host.secret, 'secret') - self.assertEqual(host.name, 'name') - self.assertEqual(host.authport, 'authport') - self.assertEqual(host.acctport, 'acctport') - self.assertEqual(host.coaport, 'coaport') - - def testNamedConstruction(self): - host = RemoteHost( - address='address', secret='secret', name='name', - authport='authport', acctport='acctport', coaport='coaport') - self.assertEqual(host.address, 'address') - self.assertEqual(host.secret, 'secret') - self.assertEqual(host.name, 'name') - self.assertEqual(host.authport, 'authport') - self.assertEqual(host.acctport, 'acctport') - self.assertEqual(host.coaport, 'coaport') - - -class ServerConstructiontests(unittest.TestCase): - def testSimpleConstruction(self): - server = Server() - self.assertEqual(server.authfds, []) - self.assertEqual(server.acctfds, []) - self.assertEqual(server.authport, 1812) - self.assertEqual(server.acctport, 1813) - self.assertEqual(server.coaport, 3799) - self.assertEqual(server.hosts, {}) - - def testParameterOrder(self): - server = Server([], 'authport', 'acctport', 'coaport', 'hosts', 'dict') - self.assertEqual(server.authfds, []) - self.assertEqual(server.acctfds, []) - self.assertEqual(server.authport, 'authport') - self.assertEqual(server.acctport, 'acctport') - self.assertEqual(server.coaport, 'coaport') - self.assertEqual(server.dict, 'dict') - - def testBindDuringConstruction(self): - def BindToAddress(self, addr): - self.bound.append(addr) - bta = Server.BindToAddress - Server.BindToAddress = BindToAddress - - Server.bound = [] - server = Server(['one', 'two', 'three']) - self.assertEqual(server.bound, ['one', 'two', 'three']) - del Server.bound - - Server.BindToAddress = bta - - -class SocketTests(unittest.TestCase): - def setUp(self): - self.orgsocket = socket.socket - socket.socket = MockSocket - self.server = Server() - - def tearDown(self): - socket.socket = self.orgsocket - - def testBind(self): - self.server.BindToAddress('192.168.13.13') - self.assertEqual(len(self.server.authfds), 1) - self.assertEqual(self.server.authfds[0].address, - ('192.168.13.13', 1812)) - - self.assertEqual(len(self.server.acctfds), 1) - self.assertEqual(self.server.acctfds[0].address, - ('192.168.13.13', 1813)) - - def testBindv6(self): - self.server.BindToAddress('2001:db8:123::1') - self.assertEqual(len(self.server.authfds), 1) - self.assertEqual(self.server.authfds[0].address, - ('2001:db8:123::1', 1812)) - - self.assertEqual(len(self.server.acctfds), 1) - self.assertEqual(self.server.acctfds[0].address, - ('2001:db8:123::1', 1813)) - - def testGrabPacket(self): - def gen(data): - res = TrivialObject() - res.data = data - return res - - fd = MockFd() - fd.source = object() - pkt = self.server._grab_packet(gen, fd) - self.failUnless(isinstance(pkt, TrivialObject)) - self.failUnless(pkt.fd is fd) - self.failUnless(pkt.source is fd.source) - self.failUnless(pkt.data is fd.data) - - def testPrepareSocketNoFds(self): - self.server._poll = MockPoll() - self.server._prepare_sockets() - - self.assertEqual(self.server._poll.registry, {}) - self.assertEqual(self.server._realauthfds, []) - self.assertEqual(self.server._realacctfds, []) - - def testPrepareSocketAuthFds(self): - self.server._poll = MockPoll() - self.server._fdmap = {} - self.server.authfds = [MockFd(12), MockFd(14)] - self.server._prepare_sockets() - - self.assertEqual(list(self.server._fdmap.keys()), [12, 14]) - self.assertEqual( - self.server._poll.registry, - {12: select.POLLIN | select.POLLPRI | select.POLLERR, - 14: select.POLLIN | select.POLLPRI | select.POLLERR}) - - def testPrepareSocketAcctFds(self): - self.server._poll = MockPoll() - self.server._fdmap = {} - self.server.acctfds = [MockFd(12), MockFd(14)] - self.server._prepare_sockets() - - self.assertEqual(list(self.server._fdmap.keys()), [12, 14]) - self.assertEqual( - self.server._poll.registry, - {12: select.POLLIN | select.POLLPRI | select.POLLERR, - 14: select.POLLIN | select.POLLPRI | select.POLLERR}) - - -class AuthPacketHandlingTests(unittest.TestCase): - def setUp(self): - self.server = Server() - self.server.hosts['host'] = TrivialObject() - self.server.hosts['host'].secret = 'supersecret' - self.packet = TrivialObject() - self.packet.code = AccessRequest - self.packet.source = ('host', 'port') - - def testHandleAuthPacketUnknownHost(self): - self.packet.source = ('stranger', 'port') - try: - self.server._handle_auth_packet(self.packet) - except ServerPacketError as e: - self.failUnless('unknown host' in str(e)) - else: - self.fail() - - def testHandleAuthPacketWrongPort(self): - self.packet.code = AccountingRequest - try: - self.server._handle_auth_packet(self.packet) - except ServerPacketError as e: - self.failUnless('port' in str(e)) - else: - self.fail() - - def testHandleAuthPacket(self): - def HandleAuthPacket(self, pkt): - self.handled = pkt - hap = Server.HandleAuthPacket - Server.HandleAuthPacket = HandleAuthPacket - - self.server._handle_auth_packet(self.packet) - self.failUnless(self.server.handled is self.packet) - - Server.HandleAuthPacket = hap - - -class AcctPacketHandlingTests(unittest.TestCase): - def setUp(self): - self.server = Server() - self.server.hosts['host'] = TrivialObject() - self.server.hosts['host'].secret = 'supersecret' - self.packet = TrivialObject() - self.packet.code = AccountingRequest - self.packet.source = ('host', 'port') - - def testHandleAcctPacketUnknownHost(self): - self.packet.source = ('stranger', 'port') - try: - self.server._handle_acct_packet(self.packet) - except ServerPacketError as e: - self.failUnless('unknown host' in str(e)) - else: - self.fail() - - def testHandleAcctPacketWrongPort(self): - self.packet.code = AccessRequest - try: - self.server._handle_acct_packet(self.packet) - except ServerPacketError as e: - self.failUnless('port' in str(e)) - else: - self.fail() - - def testHandleAcctPacket(self): - def HandleAcctPacket(self, pkt): - self.handled = pkt - hap = Server.HandleAcctPacket - Server.HandleAcctPacket = HandleAcctPacket - - self.server._handle_acct_packet(self.packet) - self.failUnless(self.server.handled is self.packet) - - Server.HandleAcctPacket = hap - - -class OtherTests(unittest.TestCase): - def setUp(self): - self.server = Server() - - def tearDown(self): - UnmockClassMethods(Server) - - def testCreateReplyPacket(self): - class TrivialPacket: - source = object() - - def CreateReply(self, **kw): - reply = TrivialObject() - reply.kw = kw - return reply - - reply = self.server.CreateReplyPacket( - TrivialPacket(), - one='one', two='two') - self.failUnless(isinstance(reply, TrivialObject)) - self.failUnless(reply.source is TrivialPacket.source) - self.assertEqual(reply.kw, dict(one='one', two='two')) - - def testAuthProcessInput(self): - fd = MockFd(1) - self.server._realauthfds = [1] - MockClassMethod(Server, '_grab_packet') - MockClassMethod(Server, '_handle_auth_packet') - - self.server._process_input(fd) - self.assertEqual( - [x[0] for x in self.server.called], - ['_grab_packet', '_handle_auth_packet']) - self.assertEqual(self.server.called[0][1][1], fd) - - def testAcctProcessInput(self): - fd = MockFd(1) - self.server._realauthfds = [] - self.server._realacctfds = [1] - MockClassMethod(Server, '_grab_packet') - MockClassMethod(Server, '_handle_acct_packet') - - self.server._process_input(fd) - self.assertEqual( - [x[0] for x in self.server.called], - ['_grab_packet', '_handle_acct_packet']) - self.assertEqual(self.server.called[0][1][1], fd) - - -class ServerRunTests(unittest.TestCase): - def setUp(self): - self.server = Server() - self.origpoll = select.poll - select.poll = MockPoll - - def tearDown(self): - MockPoll.results = [] - select.poll = self.origpoll - UnmockClassMethods(Server) - - def testRunInitializes(self): - MockClassMethod(Server, '_prepare_sockets') - self.assertRaises(MockFinished, self.server.Run) - self.assertEqual(self.server.called, [('_prepare_sockets', (), {})]) - self.failUnless(isinstance(self.server._fdmap, dict)) - self.failUnless(isinstance(self.server._poll, MockPoll)) - - def testRunIgnoresPollErrors(self): - self.server.authfds = [MockFd()] - MockPoll.results = [(0, select.POLLERR)] - self.assertRaises(MockFinished, self.server.Run) - - def testRunIgnoresServerPacketErrors(self): - def RaisePacketError(self, fd): - raise ServerPacketError - MockClassMethod(Server, '_process_input', RaisePacketError) - self.server.authfds = [MockFd()] - MockPoll.results = [(0, select.POLLIN)] - self.assertRaises(MockFinished, self.server.Run) - - def testRunIgnoresPacketErrors(self): - def RaisePacketError(self, fd): - raise PacketError - MockClassMethod(Server, '_process_input', RaisePacketError) - self.server.authfds = [MockFd()] - MockPoll.results = [(0, select.POLLIN)] - self.assertRaises(MockFinished, self.server.Run) - - def testRunRunsProcessInput(self): - MockClassMethod(Server, '_process_input') - self.server.authfds = fd = [MockFd()] - MockPoll.results = [(0, select.POLLIN)] - self.assertRaises(MockFinished, self.server.Run) - self.assertEqual(self.server.called, [('_process_input', (fd[0],), {})]) - - -if not hasattr(select, 'poll'): - del SocketTests - del ServerRunTests diff --git a/pyrad3/tests/testTools.py b/pyrad3/tests/testTools.py deleted file mode 100644 index 3914bd3..0000000 --- a/pyrad3/tests/testTools.py +++ /dev/null @@ -1,119 +0,0 @@ -from ipaddress import AddressValueError -from pyrad import tools -import unittest - - -class EncodingTests(unittest.TestCase): - def testStringEncoding(self): - self.assertRaises(ValueError, tools.EncodeString, 'x' * 254) - self.assertEqual( - tools.EncodeString('1234567890'), - b'1234567890') - - def testInvalidStringEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeString, 1) - - def testAddressEncoding(self): - self.assertRaises(AddressValueError, tools.EncodeAddress, 'TEST123') - self.assertEqual( - tools.EncodeAddress('192.168.0.255'), - b'\xc0\xa8\x00\xff') - - def testInvalidAddressEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeAddress, 1) - - def testIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0x01020304), b'\x01\x02\x03\x04') - - def testInteger64Encoding(self): - self.assertEqual( - tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), b'\xff' * 8 - ) - - def testUnsignedIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), b'\xff\xff\xff\xff') - - def testInvalidIntegerEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeInteger, 'ONE') - - def testDateEncoding(self): - self.assertEqual(tools.EncodeDate(0x01020304), b'\x01\x02\x03\x04') - - def testInvalidDataEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeDate, '1') - - def testEncodeAscendBinary(self): - self.assertEqual( - tools.EncodeAscendBinary('family=ipv4 action=discard direction=in dst=10.10.255.254/32'), - b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') - - def testStringDecoding(self): - self.assertEqual( - tools.DecodeString(b'1234567890'), - '1234567890') - - def testAddressDecoding(self): - self.assertEqual( - tools.DecodeAddress(b'\xc0\xa8\x00\xff'), - '192.168.0.255') - - def testIntegerDecoding(self): - self.assertEqual( - tools.DecodeInteger(b'\x01\x02\x03\x04'), - 0x01020304) - - def testInteger64Decoding(self): - self.assertEqual( - tools.DecodeInteger64(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF - ) - - def testDateDecoding(self): - self.assertEqual( - tools.DecodeDate(b'\x01\x02\x03\x04'), - 0x01020304) - - def testUnknownTypeEncoding(self): - self.assertRaises(ValueError, tools.EncodeAttr, 'unknown', None) - - def testUnknownTypeDecoding(self): - self.assertRaises(ValueError, tools.DecodeAttr, 'unknown', None) - - def testEncodeFunction(self): - self.assertEqual( - tools.EncodeAttr('string', 'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('ipaddr', '192.168.0.255'), - b'\xc0\xa8\x00\xff') - self.assertEqual( - tools.EncodeAttr('integer', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('date', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF), - b'\xff'*8) - - def testDecodeFunction(self): - self.assertEqual( - tools.DecodeAttr('string', b'string'), - 'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.DecodeAttr('ipaddr', b'\xc0\xa8\x00\xff'), - '192.168.0.255') - self.assertEqual( - tools.DecodeAttr('integer', b'\x01\x02\x03\x04'), - 0x01020304) - self.assertEqual( - tools.DecodeAttr('integer64', b'\xff'*8), - 0xFFFFFFFFFFFFFFFF) - self.assertEqual( - tools.DecodeAttr('date', b'\x01\x02\x03\x04'), - 0x01020304) diff --git a/pyrad3/tools.py b/pyrad3/tools.py deleted file mode 100644 index 77a7cb9..0000000 --- a/pyrad3/tools.py +++ /dev/null @@ -1,226 +0,0 @@ -# tools.py -# -# Utility functions -import binascii -import ipaddress -import struct - - -def EncodeString(string): - if len(string) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - if isinstance(string, str): - return string.encode('utf-8') - return string - - -def EncodeOctets(string): - if len(string) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - - if string.startswith(b'0x'): - hexstring = string.split(b'0x')[1] - return binascii.unhexlify(hexstring) - else: - return string - - -def EncodeAddress(addr): - if not isinstance(addr, str): - raise TypeError('Address has to be a string') - return ipaddress.IPv4Address(addr).packed - - -def EncodeIPv6Prefix(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Prefix has to be a string') - ip = ipaddress.IPv6Network(addr) - return struct.pack('2B', *[0, ip.prefixlen]) + ip.network_address.packed - - -def EncodeIPv6Address(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Address has to be a string') - return ipaddress.IPv6Address(addr).packed - - -def EncodeAscendBinary(string): - """ - Format: List of type=value pairs sperated by spaces. - - Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32' - - Type: - family ipv4(default) or ipv6 - action discard(default) or accept - direction in(default) or out - src source prefix (default ignore) - dst destination prefix (default ignore) - proto protocol number / next-header number (default ignore) - sport source port (default ignore) - dport destination port (default ignore) - sportq source port qualifier (default 0) - dportq destination port qualifier (default 0) - - Source/Destination Port Qualifier: - 0 no compare - 1 less than - 2 equal to - 3 greater than - 4 not equal to - """ - - terms = { - 'family': b'\x01', - 'action': b'\x00', - 'direction': b'\x01', - 'src': b'\x00\x00\x00\x00', - 'dst': b'\x00\x00\x00\x00', - 'srcl': b'\x00', - 'dstl': b'\x00', - 'proto': b'\x00', - 'sport': b'\x00\x00', - 'dport': b'\x00\x00', - 'sportq': b'\x00', - 'dportq': b'\x00' - } - - for t in string.split(' '): - key, value = t.split('=') - if key == 'family' and value == 'ipv6': - terms[key] = b'\x03' - if terms['src'] == b'\x00\x00\x00\x00': - terms['src'] = 16 * b'\x00' - if terms['dst'] == b'\x00\x00\x00\x00': - terms['dst'] = 16 * b'\x00' - elif key == 'action' and value == 'accept': - terms[key] = b'\x01' - elif key == 'direction' and value == 'out': - terms[key] = b'\x00' - elif key in ('src', 'dst'): - ip = ipaddress.ip_network(value) - terms[key] = ip.network_address.packed - terms[key+'l'] = struct.pack('B', ip.prefixlen) - elif key in ('sport', 'dport'): - terms[key] = struct.pack('!H', int(value)) - elif key in ('sportq', 'dportq', 'proto'): - terms[key] = struct.pack('B', int(value)) - - trailer = 8 * b'\x00' - - result = b''.join(( - terms['family'], terms['action'], terms['direction'], b'\x00', - terms['src'], terms['dst'], terms['srcl'], terms['dstl'], terms['proto'], b'\x00', - terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], b'\x00\x00', trailer)) - return result - - -def EncodeInteger(num, format='!I'): - try: - num = int(num) - except ValueError: - raise TypeError('Can not encode non-integer as integer') - return struct.pack(format, num) - - -def EncodeInteger64(num, format='!Q'): - try: - num = int(num) - except ValueError: - raise TypeError('Can not encode non-integer as integer64') - return struct.pack(format, num) - - -def EncodeDate(num): - if not isinstance(num, int): - raise TypeError('Can not encode non-integer as date') - return struct.pack('!I', num) - - -def DecodeString(string): - try: - return string.decode('utf-8') - except: - return string - - -def DecodeOctets(string): - return string - - -def DecodeAddress(addr): - return '.'.join((str(a) for a in struct.unpack('BBBB', addr))) - - -def DecodeIPv6Prefix(addr): - addr = addr + b'\x00' * (18-len(addr)) - prefix = addr[:2] - addr = addr[2:] - return str(ipaddress.ip_network((prefix, addr))) - - -def DecodeIPv6Address(addr): - addr = addr + b'\x00' * (16-len(addr)) - return str(ipaddress.IPv6Address(addr)) - - -def DecodeAscendBinary(string): - return string - - -def DecodeInteger(num, format='!I'): - return (struct.unpack(format, num))[0] - - -def DecodeInteger64(num, format='!Q'): - return (struct.unpack(format, num))[0] - - -def DecodeDate(num): - return (struct.unpack('!I', num))[0] - - -ENCODE_MAP = { - 'string': EncodeString, - 'octets': EncodeOctets, - 'integer': EncodeInteger, - 'ipaddr': EncodeAddress, - 'ipv6prefix': EncodeIPv6Prefix, - 'ipv6addr': EncodeIPv6Address, - 'abinary': EncodeAscendBinary, - 'signed': lambda value: EncodeInteger(value, '!i'), - 'short': lambda value: EncodeInteger(value, '!H'), - 'byte': lambda value: EncodeInteger(value, '!B'), - 'date': EncodeDate, - 'integer64': EncodeInteger64, -} - - -def EncodeAttr(datatype, value): - try: - return ENCODE_MAP[datatype](value) - except KeyError: - raise ValueError(f'Unknown attribute type {datatype}') - - -DECODE_MAP = { - 'string': DecodeString, - 'octets': DecodeOctets, - 'integer': DecodeInteger, - 'ipaddr': DecodeAddress, - 'ipv6prefix': DecodeIPv6Prefix, - 'ipv6addr': DecodeIPv6Address, - 'abinary': DecodeAscendBinary, - 'signed': lambda value: DecodeInteger(value, '!i'), - 'short': lambda value: DecodeInteger(value, '!H'), - 'byte': lambda value: DecodeInteger(value, '!B'), - 'date': DecodeDate, - 'integer64': DecodeInteger64, -} - - -def DecodeAttr(datatype, value): - try: - return DECODE_MAP[datatype](value) - except KeyError: - raise ValueError(f'Unknown attribute type {datatype}') diff --git a/pyrad3/utils.py b/pyrad3/utils.py deleted file mode 100644 index bb14f06..0000000 --- a/pyrad3/utils.py +++ /dev/null @@ -1,57 +0,0 @@ -from pyrad.packet import Packet, random_generator - - -def salt_encrypt(packet: Packet, value: bytes) -> bytes: - length = struct.pack('B', len(value)) - buf = length + value - buf += b'\x00' * (16 - (len(buf) % 16)) - - # First bit if the random value must be 1 - random_value = 32768 + random_generator.randrange(0, 32767) - result = struct.pack('!H', random_value) - - last = packet.authenticator + result - while buf: - cur_hash = md5_constructor(packet.secret + last).digest() - for b, h in zip(buf, cur_hash): - result += bytes([b ^ h]) - last = result[-16:] - buf = buf[16:] - - return result - - -def validate_chap_password(packet: Packet, password: bytes) -> bool: - # TODO: - challange = packet.get('CHAP-Challenge', packet.authenticator) - return chap_password == hashlib.md5(chapid + password + challenge).digest() - - -def validate_pap_password(packet: Packet, password: bytes) -> bool: - obf_pass = password_encode(packet, password) - return packet['Password'] == obf_pass - - -def password_encode(packet: Packet, password: bytes) -> bytes: - password += b'\x00' * (16 - (len(password) % 16)) - return obfuscation_algorithm(packet, password) - - -def password_decode(packet: Packet, password: bytes) -> bytes: - decoded = obfuscation_algorithm(packet, password) - return decoded.rstrip(b'\x00') - - -def obfuscation_algorithm(packet: Packet, password: bytes) -> bytes: - result = b'' - buf = password - last = packet.authenticator - - while buf: - cur_hash = md5_constructor(packet.secret + last) - for b, h in zip(buf, cur_hash): - result += bytes([b ^ h]) - (last, buf) = (buf[:16], buf[16:]) - - return result - diff --git a/shell.nix b/shell.nix index 64d3be1..3379b60 100644 --- a/shell.nix +++ b/shell.nix @@ -1,5 +1,20 @@ let pkgs = import {}; - python = pkgs.python36; in - import ./default.nix { inherit pkgs python; } +pkgs.mkShell { + buildInputs = with pkgs; [ + git + nixfmt + pkgs.python3 + pkgs.python3Packages.black + pkgs.python3Packages.pytest + pkgs.python3Packages.pytest-black + pkgs.python3Packages.pytest-flake8 + pkgs.python3Packages.pytest-mypy + pkgs.python3Packages.pytest-pylint + ]; + + shellHook = '' + export PYTHONPATH=${./src}:$PYTHONPATH + ''; +} diff --git a/pyrad3/__init__.py b/src/pyrad3/__init__.py similarity index 84% rename from pyrad3/__init__.py rename to src/pyrad3/__init__.py index 3aa65fc..f771b8b 100644 --- a/pyrad3/__init__.py +++ b/src/pyrad3/__init__.py @@ -38,9 +38,9 @@ This package contains four modules: __docformat__ = 'epytext en' -__author__ = 'Christian Giese ' +__author__ = 'Istvan Ruzman ' __url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest' -__copyright__ = 'Copyright 2002-2020 Wichert Akkerman and Christian Giese. All rights reserved.' -__version__ = '2.3' +__copyright__ = 'Copyright 2020 Istvan Ruzman' +__version__ = '0.1.0' -__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile', 'new_client', 'new_host', 'new_packet'] +__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils'] diff --git a/pyrad3/bidict.py b/src/pyrad3/bidict.py similarity index 98% rename from pyrad3/bidict.py rename to src/pyrad3/bidict.py index f246672..15aaa21 100644 --- a/pyrad3/bidict.py +++ b/src/pyrad3/bidict.py @@ -3,7 +3,7 @@ # Bidirectional map -class BiDict(): +class BiDict: def __init__(self): self.forward = {} self.backward = {} diff --git a/pyrad3/new_client.py b/src/pyrad3/client.py similarity index 55% rename from pyrad3/new_client.py rename to src/pyrad3/client.py index 9491cfb..ae924f4 100644 --- a/pyrad3/new_client.py +++ b/src/pyrad3/client.py @@ -1,63 +1,94 @@ # Copyright 2020 Istvan Ruzman # SPDX-License-Identifier: MIT OR Apache-2.0 +"""Implementation of a simple but extensible RADIUS Client""" + +from typing import Optional, Union +from ipaddress import IPv4Address, IPv6Address + import select import socket import time -from pyrad3 import new_packet as P -from pyrad3 import new_host +import pyrad3.packet as P +from pyrad3 import host SUPPORTED_SEND_TYPES = [ - P.AccessRequest, - P.AccountingRequest, - P.CoARequest, + P.Code.AccessRequest, + P.Code.AccountingRequest, + P.Code.CoARequest, ] PACKET_TYPE_PORT_MAPPING = { - P.AccessRequest: 'authport', - P.AccountingRequest: 'acctport', - P.CoARequest: 'coaport', + P.Code.AccessRequest: "authport", + P.Code.AccountingRequest: "acctport", + P.Code.CoARequest: "coaport", } + class Timeout(Exception): - pass + """Exception for wait timeouts""" + class UnsupportedPacketType(Exception): - pass + """Exception for received packets""" -class Client(new_host.Host): - def __init__(self, server, secret, radius_dictionary, **kwargs): +class Client(host.Host): + """A simple and extensible RADIUS Client.""" + + def __init__( + self, + server: Union[str, IPv4Address, IPv6Address], + secret: bytes, + radius_dictionary: dict, + interface: Optional[str], + **kwargs, + ): super().__init__(secret, radius_dictionary, **kwargs) self.server = server - self._socket = None - self._poll = None + self.interface = interface + self._socket: Optional[socket.socket] = None + self._poll: Optional[select.poll] = None def bind(self, addr): + """Bind the Address to some socket""" self._socket_close() self._socket_open() self._socket.bind(addr) def _socket_open(self): + """Open a client socket""" if self._socket is not None: return try: - family = socket.getaddrinfo(self.server, 'www')[0][0] + family = socket.getaddrinfo(self.server, "www")[0][0] except socket.gaierror: family = socket.AF_INET self._socket = socket.socket(family, socket.SOCK_DGRAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.interface is not None: + # This will fail on non-Linux systems - that's ok, because we don't + # have any implementation yet for non-Linux systems. + # Better to fail loudly than to do something unexpected silently. + self._socket.setsockopt( + socket.SOL_SOCKET, + socket.SO_BINDTODEVICE, + self.interface, + len(self.interface), + ) self._poll = select.poll() self._poll.register(self._socket, select.POLLIN) def _socket_close(self): + """Close the Client socket""" if self._socket is not None: self._poll.unregister(self._socket) self._socket.close() self._socket = None - def _select_port(self, packet): + def _select_port(self, packet: P.Packet): + """Select the RADIUS Port depending on the RADIUS Packet type""" try: port_type = PACKET_TYPE_PORT_MAPPING[packet.code] return getattr(self, port_type) @@ -65,7 +96,11 @@ class Client(new_host.Host): pass raise UnsupportedPacketType(f"The packet type {packet.code} by Client") - def _send_packet(self, packet): + def _send_packet(self, packet: P.Packet): + """Send a RADIUS Packet and wait for the reply""" + assert self._socket is not None + assert self._poll is not None + port = self._select_port(packet) raw_packet = packet.serialize() diff --git a/pyrad3/client_async.py b/src/pyrad3/client_async.py similarity index 68% rename from pyrad3/client_async.py rename to src/pyrad3/client_async.py index 31b606c..783d691 100644 --- a/pyrad3/client_async.py +++ b/src/pyrad3/client_async.py @@ -13,9 +13,7 @@ from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket class DatagramProtocolClient(asyncio.Protocol): - - def __init__(self, server, port, logger, - client, retries=3, timeout=30): + def __init__(self, server, port, logger, client, retries=3, timeout=30): self.transport = None self.port = port self.server = server @@ -42,22 +40,31 @@ class DatagramProtocolClient(asyncio.Protocol): # noinspection PyShadowingBuiltins for id, req in self.pending_requests.items(): - secs = (req['send_date'] - now).seconds + secs = (req["send_date"] - now).seconds if secs > self.timeout: - if req['retries'] == self.retries: - self.logger.debug('[%s:%d] For request %d execute all retries', - self.server, self.port, id) - req['future'].set_exception( - TimeoutError('Timeout on Reply') + if req["retries"] == self.retries: + self.logger.debug( + "[%s:%d] For request %d execute all retries", + self.server, + self.port, + id, + ) + req["future"].set_exception( + TimeoutError("Timeout on Reply") ) req2delete.append(id) else: # Send again packet - req['send_date'] = now - req['retries'] += 1 - self.logger.debug('[%s:%d] For request %d execute retry %d', - self.server, self.port, id, req['retries']) - self.transport.sendto(req['packet'].RequestPacket()) + req["send_date"] = now + req["retries"] += 1 + self.logger.debug( + "[%s:%d] For request %d execute retry %d", + self.server, + self.port, + id, + req["retries"], + ) + self.transport.sendto(req["packet"].RequestPacket()) elif next_weak_up > secs: next_weak_up = secs @@ -72,15 +79,15 @@ class DatagramProtocolClient(asyncio.Protocol): def send_packet(self, packet, future): if packet.id in self.pending_requests: - raise Exception(f'Packet with id {packet.id} already present') + raise Exception(f"Packet with id {packet.id} already present") # Store packet on pending requests map self.pending_requests[packet.id] = { - 'packet': packet, - 'creation_date': datetime.now(), - 'retries': 0, - 'future': future, - 'send_date': datetime.now() + "packet": packet, + "creation_date": datetime.now(), + "retries": 0, + "future": future, + "send_date": datetime.now(), } # In queue packet raw on socket buffer @@ -88,10 +95,11 @@ class DatagramProtocolClient(asyncio.Protocol): def connection_made(self, transport): self.transport = transport - socket = transport.get_extra_info('socket') + socket = transport.get_extra_info("socket") self.logger.info( - '[%s:%d] Transport created with binding in %s:%d', - self.server, self.port, + "[%s:%d] Transport created with binding in %s:%d", + self.server, + self.port, socket.getsockname()[0], socket.getsockname()[1], ) @@ -99,38 +107,47 @@ class DatagramProtocolClient(asyncio.Protocol): pre_loop = asyncio.get_event_loop() asyncio.set_event_loop(loop=self.client.loop) # Start asynchronous timer handler - self.timeout_future = asyncio.ensure_future( - self.__timeout_handler__() - ) + self.timeout_future = asyncio.ensure_future(self.__timeout_handler__()) asyncio.set_event_loop(loop=pre_loop) def error_received(self, exc): - self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc) + self.logger.error( + "[%s:%d] Error received: %s", self.server, self.port, exc + ) def connection_lost(self, exc): if exc: - self.logger.warn('[%s:%d] Connection lost: %s', self.server, self.port, str(exc)) + self.logger.warn( + "[%s:%d] Connection lost: %s", self.server, self.port, str(exc) + ) else: - self.logger.info('[%s:%d] Transport closed', self.server, self.port) + self.logger.info("[%s:%d] Transport closed", self.server, self.port) # noinspection PyUnusedLocal def datagram_received(self, data, addr): try: req = self.pending_requests[data[0]] reply = req.VerifyPacket(data) - req['future'].set_result(reply) + req["future"].set_result(reply) # Remove request for map del self.pending_requests[reply.id] except KeyError: - self.logger.warn('[%s:%d] Ignore invalid reply: %s', - self.server, self.port, data) + self.logger.warn( + "[%s:%d] Ignore invalid reply: %s", self.server, self.port, data + ) except PacketError as exc: - self.logger.error('[%s:%d] Error on decode or verify packet: %s', - self.server, self.port, exc) + self.logger.error( + "[%s:%d] Error on decode or verify packet: %s", + self.server, + self.port, + exc, + ) async def close_transport(self): if self.transport: - self.logger.debug('[%s:%d] Closing transport...', self.server, self.port) + self.logger.debug( + "[%s:%d] Closing transport...", self.server, self.port + ) self.transport.close() self.transport = None if self.timeout_future: @@ -143,7 +160,9 @@ class DatagramProtocolClient(asyncio.Protocol): return self.packet_id def __str__(self): - return f'DatagramProtocolClient(server?={self.server}, port={self.port})' + return ( + f"DatagramProtocolClient(server?={self.server}, port={self.port})" + ) # Used as protocol_factory def __call__(self): @@ -161,11 +180,21 @@ class ClientAsync: :ivar timeout: number of seconds to wait for an answer :type timeout: integer """ + # noinspection PyShadowingBuiltins - def __init__(self, server, auth_port=1812, acct_port=1813, - coa_port=3799, secret=b'', dict=None, - loop=None, retries=3, timeout=30, - logger_name='pyrad'): + def __init__( + self, + server, + auth_port=1812, + acct_port=1813, + coa_port=3799, + secret=b"", + dict=None, + loop=None, + retries=3, + timeout=30, + logger_name="pyrad", + ): """Constructor. @@ -205,23 +234,30 @@ class ClientAsync: self.protocol_coa = None self.coa_port = coa_port - async def initialize_transports(self, enable_acct=False, - enable_auth=False, enable_coa=False, - local_addr=None, local_auth_port=None, - local_acct_port=None, local_coa_port=None): + async def initialize_transports( + self, + enable_acct=False, + enable_auth=False, + enable_coa=False, + local_addr=None, + local_auth_port=None, + local_acct_port=None, + local_coa_port=None, + ): task_list = [] if not enable_acct and not enable_auth and not enable_coa: - raise Exception('No transports selected') + raise Exception("No transports selected") if enable_acct and not self.protocol_acct: self.protocol_acct = DatagramProtocolClient( self.server, self.acct_port, - self.logger, self, + self.logger, + self, retries=self.retries, - timeout=self.timeout + timeout=self.timeout, ) bind_addr = None if local_addr and local_acct_port: @@ -231,7 +267,7 @@ class ClientAsync: self.protocol_acct, reuse_port=True, remote_addr=(self.server, self.acct_port), - local_addr=bind_addr + local_addr=bind_addr, ) task_list.append(acct_connect) @@ -239,9 +275,10 @@ class ClientAsync: self.protocol_auth = DatagramProtocolClient( self.server, self.auth_port, - self.logger, self, + self.logger, + self, retries=self.retries, - timeout=self.timeout + timeout=self.timeout, ) bind_addr = None if local_addr and local_auth_port: @@ -251,7 +288,7 @@ class ClientAsync: self.protocol_auth, reuse_port=True, remote_addr=(self.server, self.auth_port), - local_addr=bind_addr + local_addr=bind_addr, ) task_list.append(auth_connect) @@ -259,9 +296,10 @@ class ClientAsync: self.protocol_coa = DatagramProtocolClient( self.server, self.coa_port, - self.logger, self, + self.logger, + self, retries=self.retries, - timeout=self.timeout + timeout=self.timeout, ) bind_addr = None if local_addr and local_coa_port: @@ -271,22 +309,18 @@ class ClientAsync: self.protocol_coa, reuse_port=True, remote_addr=(self.server, self.coa_port), - local_addr=bind_addr + local_addr=bind_addr, ) task_list.append(coa_connect) await asyncio.ensure_future( - asyncio.gather( - *task_list, - return_exceptions=False, - ), - loop=self.loop + asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop ) # noinspection SpellCheckingInspection - async def deinitialize_transports(self, deinit_coa=True, - deinit_auth=True, - deinit_acct=True): + async def deinitialize_transports( + self, deinit_coa=True, deinit_auth=True, deinit_acct=True + ): if self.protocol_coa and deinit_coa: await self.protocol_coa.close_transport() del self.protocol_coa @@ -312,11 +346,14 @@ class ClientAsync: :rtype: pyrad.packet.Packet """ if not self.protocol_auth: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") - return AuthPacket(dict=self.dict, - id=self.protocol_auth.create_id(), - secret=self.secret, **args) + return AuthPacket( + dict=self.dict, + id=self.protocol_auth.create_id(), + secret=self.secret, + **args, + ) # noinspection PyPep8Naming def CreateAcctPacket(self, **args): @@ -330,11 +367,14 @@ class ClientAsync: :rtype: pyrad.packet.Packet """ if not self.protocol_acct: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") - return AcctPacket(id=self.protocol_acct.create_id(), - dict=self.dict, - secret=self.secret, **args) + return AcctPacket( + id=self.protocol_acct.create_id(), + dict=self.dict, + secret=self.secret, + **args, + ) # noinspection PyPep8Naming def CreateCoAPacket(self, **args): @@ -349,20 +389,22 @@ class ClientAsync: """ if not self.protocol_acct: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") - return CoAPacket(id=self.protocol_coa.create_id(), - dict=self.dict, - secret=self.secret, **args) + return CoAPacket( + id=self.protocol_coa.create_id(), + dict=self.dict, + secret=self.secret, + **args, + ) # noinspection PyPep8Naming # noinspection PyShadowingBuiltins def CreatePacket(self, id, **args): if not id: - raise Exception('Missing mandatory packet id') + raise Exception("Missing mandatory packet id") - return Packet(id=id, dict=self.dict, - secret=self.secret, **args) + return Packet(id=id, dict=self.dict, secret=self.secret, **args) # noinspection PyPep8Naming def SendPacket(self, pkt): @@ -378,23 +420,23 @@ class ClientAsync: if isinstance(pkt, AuthPacket): if not self.protocol_auth: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") self.protocol_auth.send_packet(pkt, ans) elif isinstance(pkt, AcctPacket): if not self.protocol_acct: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") self.protocol_acct.send_packet(pkt, ans) elif isinstance(pkt, CoAPacket): if not self.protocol_coa: - raise Exception('Transport not initialized') + raise Exception("Transport not initialized") self.protocol_coa.send_packet(pkt, ans) else: - raise Exception('Unsupported packet') + raise Exception("Unsupported packet") return ans diff --git a/src/pyrad3/dictionary.py b/src/pyrad3/dictionary.py new file mode 100644 index 0000000..54a960d --- /dev/null +++ b/src/pyrad3/dictionary.py @@ -0,0 +1,483 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""RADIUS Dictionary. + +Classes and Types to parse and represent a RADIUS dictionary. +""" + +from enum import IntEnum, Enum, auto +from dataclasses import dataclass +from os.path import dirname, isabs, join, normpath +from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union + +import logging + +LOG = logging.getLogger(__name__) + + +INTEGER_TYPES = { + "byte": (0, 255), + "short": (0, 2 ** 16 - 1), + "signed": (-(2 ** 31), 2 ** 31 - 1), + "integer": (0, 2 ** 32 - 1), + "integer64": (0, 2 ** 64 - 1), +} + + +class Datatype(Enum): + """Possible Datatypes for ATTRIBUTES""" + + string = auto() + octets = auto() + date = auto() + abinary = auto() + byte = auto() + short = auto() + integer = auto() + signed = auto() + integer64 = auto() + ipaddr = auto() + ipv4prefix = auto() + ipv6addr = auto() + ipv6prefix = auto() + comboip = auto() + ifid = auto() + ether = auto() + concat = auto() + tlv = auto() + extended = auto() + longextended = auto() + evs = auto() + + +class ParseError(Exception): + """RADIUS Dictionary Parser Error""" + + def __init__( + self, + filename: str, + msg: str = None, + line: Optional[int] = None, + **data, + ): + super().__init__() + self.msg = msg + self.file = filename + self.line = line + self.data = data + + def __str__(self): + line = f"({self.line}" if self.line is not None else "" + return f"{self.file}{line}: ParseError: {self.msg}" + + +class Encrypt(IntEnum): + """Enum for different RADIUS Encryption types.""" + + NoEncrpytion = 0 + RadiusCrypt = 1 + SaltCrypt = 2 + AscendCrypt = 3 + + +@dataclass +class Attribute: # pylint: disable=too-many-instance-attributes + """RADIUS Attribute definition""" + + name: str + code: int + datatype: Datatype + has_tag: bool = False + encrypt: Encrypt = Encrypt(0) + is_sub_attr: bool = False + # vendor = Dictionary + values: Dict[Union[int, str], Union[int, str]] = None + + +@dataclass +class Vendor: + """Representation of a vendor""" + + name: str + code: int + tlength: int + llength: int + continuation: bool + attrs: Dict[Union[int, Tuple[int, ...]], Attribute] + + +def dict_parser( + filename: str, rad_dict: IO +) -> Generator[Tuple[int, List[str]], None, None]: + """Tokenstream of RADIUS Dictionary files + + Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED" + and "FILE_CLOSED" tokens will be emitted. + """ + yield (-1, ["FILE_OPENED", filename]) + for line_num, line in enumerate(rad_dict.readlines()): + tokens = line.split("#", 1)[0].strip().split() + if tokens: + first_tok = tokens[0] = tokens[0].upper() + if first_tok == "$INCLUDE": + try: + inner_filename = tokens[1] + except IndexError: + raise ParseError( + filename, "$INCLUDE is missing a filename", line_num, + ) + if not isabs(tokens[1]): + path = dirname(filename) + inner_filename = normpath(join(path, inner_filename)) + yield from dict_loader(inner_filename) + yield (line_num, tokens) + yield (-1, ["FILE_CLOSED"]) + + +def dict_loader(filename: str) -> Generator[Tuple[int, List[str]], None, None]: + """Tokenstream of RADIUS Dictionary files + + Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED" + and "FILE_CLOSED" tokens will be emitted. + """ + with open(filename, "r") as rad_dict: + yield from dict_parser(filename, rad_dict) + + +def _parse_number(num: str) -> int: + """Parse a number from (Free)RADIUS dictionaries + + Numbers can be either decimal, octal, or hexadecimal. + """ + if num.startswith("0x"): + return int(num, 16) + if num.startswith("0o"): + return int(num, 8) + return int(num) + + +def _parse_attribute_code(attr_code: str) -> List[int]: + """Parse attribute codes from (Free)RADIUS dictionaries + + Codes can be either decimal, octal, or hexadecimal. + TLV typed can + """ + codes = [] + for code in attr_code.split("."): + codes.append(_parse_number(code)) + return codes + + +class Dictionary: + """(Free)RADIUS Dictionary. + + #TODO: Better documentation + This dictionary can "contain" multiple dictionaries. + """ + + # there must be some nicer way to unittest this... + def __init__(self, dictionary: str, __dictio: Optional[IO] = None): + self.vendor: Dict[int, Vendor] = {} + self.vendor_lookup_id_by_name: Dict[str, int] = {} + self.attrindex: Dict[Union[int, str], Attribute] = {} + self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {}) + self.cur_vendor = self.rfc_vendor + if __dictio is not None: + loader = dict_parser(dictionary, __dictio) + else: + loader = dict_loader(dictionary) + self.read_dictionary(loader) + + def read_dictionary( + self, reader: Generator[Tuple[int, List[str]], None, None] + ): + """Read and parse a (Free)RADIUS dictionary.""" + self.filestack = [] + for line_num, tokens in reader: + key = tokens[0] + if key == "ATTRIBUTE": + self._parse_attribute(tokens, line_num) + elif key == "VALUE": + self._parse_value(tokens, line_num) + elif key == "FILE_OPENED": + LOG.info("Parsing file: %s", tokens[1]) + if tokens[1] in self.filestack: + raise ParseError( + self.filestack[-1], "Include recursion detected" + ) + self.filestack.append(tokens[1]) + elif key == "FILE_CLOSED": + filename = self.filestack.pop() + LOG.info("Finished parsing file: %s", filename) + elif key == "VENDOR": + self._parse_vendor(tokens, line_num) + elif key == "BEGIN-VENDOR": + self._parse_begin_vendor(tokens, line_num) + elif key == "END-VENDOR": + self._parse_end_vendor(tokens, line_num) + elif key == "BEGIN-TLV": + raise NotImplementedError( + "BEGIN-TLV is deprecated and not supported by pyrad3" + ) + elif key == "END-TLV": + raise NotImplementedError( + "END-TLV is deprecated and not supported by pyrad3" + ) + else: + raise ParseError( + self.filestack[-1], f"Invalid Token key {key}", line_num + ) + + def _parse_vendor(self, tokens: Sequence[str], line_num: int): + """Parse the vendor definition""" + filename = self.filestack[-1] + if len(tokens) not in {3, 4}: + raise ParseError( + filename, "Incorrect number of tokens for vendor statement" + ) + vendor_name = tokens[1] + vendor_id = int(tokens[2], 0) + continuation = False + + # Parse optional vendor specification + try: + vendor_format = tokens[3].split("=") + if vendor_format[0] != "format": + raise ParseError( + filename, + f"Unknown option {vendor_format[0]} for vendor definition", + line_num, + ) + try: + vendor_format = vendor_format[1].split(",") + t_len, l_len = (int(a) for a in vendor_format[:2]) + if t_len not in {1, 2, 4}: + raise ParseError( + filename, + f'Invalid type length definition "{t_len}" for vendor {vendor_name}', + line_num, + ) + if l_len not in {0, 1, 2}: + raise ParseError( + filename, + f'Invalid length definition "{l_len}" for vendor {vendor_name}', + line_num, + ) + try: + if vendor_format[2] == "c": + if not vendor_name.upper() == "WIMAX": + # Not sure why, but FreeRADIUS has this limit, + # so we just do the same cause they know better than me + raise ParseError( + filename, + "continuation-bit is only supported for WiMAX", + line_num, + ) + continuation = True + except IndexError: + pass + except ValueError: + raise ParseError( + filename, + f"Syntax error in specification for vendor {vendor_name}", + line_num, + ) + except IndexError: + # no format definition + t_len, l_len = 1, 1 + + vendor = Vendor(vendor_name, vendor_id, t_len, l_len, continuation, {}) + self.vendor_lookup_id_by_name[vendor_name] = vendor_id + self.vendor[vendor_id] = vendor + + def _parse_begin_vendor(self, tokens: Sequence[str], line_num: int): + """Parse the BEGIN-VENDOR line of (Free)RADIUS dictionaries.""" + filename = self.filestack[-1] + if self.cur_vendor != self.rfc_vendor: + raise ParseError( + filename, + "vendor-begin sections are not allowed to be nested", + line_num, + ) + if len(tokens) != 2: + raise ParseError( + filename, + "Incorrect number of tokens for begin-vendor statement", + line_num, + ) + try: + vendor_id = self.vendor_lookup_id_by_name[tokens[1]] + self.cur_vendor = self.vendor[vendor_id] + except KeyError: + raise ParseError( + filename, + f"Unknown vendor {tokens[1]} in begin-vendor statement", + line_num, + ) + + def _parse_end_vendor(self, tokens: Sequence[str], line_num: int): + """Parse the END-VENDOR line of (Free)RADIUS dictionaries.""" + filename = self.filestack[-1] + if len(tokens) != 2: + raise ParseError( + filename, + "Incorrect number of tokens for end-vendor statement", + line_num, + ) + if self.cur_vendor.name != tokens[1]: + raise ParseError( + filename, + f"Closing non-opened vendor {tokens[1]} in end-vendor statement", + line_num, + ) + self.cur_vendor = self.rfc_vendor + + def _parse_attribute_flags( + self, tokens: Sequence[str], line_num: int + ) -> Tuple[bool, Encrypt]: + """Parse Attribute flags of (Free)RADIUS dictionaries.""" + filename = self.filestack[-1] + has_tag = False + encrypt = Encrypt.NoEncrpytion + + try: + flags = [flag.split("=") for flag in tokens[4].split(",")] + except IndexError: + return False, Encrypt.NoEncrpytion + + for flag in flags: + flag_len = len(flag) + if flag == 1: + value = None + elif flag_len == 2: + value = flag[1] + else: + raise ParseError( + filename, f"Incorrect attribute flag {flag}", line_num + ) + key = flag[0] + + if key == "has_tag": + has_tag = True + elif key == "encrypt": + try: + encrypt = Encrypt(int(value)) # type: ignore + except (ValueError, TypeError): + raise ParseError( + filename, + f"Illegal attribute encryption {value}", + line_num, + ) + else: + raise ParseError( + filename, "Unknown attribute flag {key}", line_num + ) + + return has_tag, encrypt + + def _parse_attribute(self, tokens: Sequence[str], line_num: int): + """Parse an ATTRIBUTE line of (Free)RADIUS dictionaries.""" + filename = self.filestack[-1] + if not len(tokens) in {4, 5}: + raise ParseError( + filename, + "Incorrect number of tokens for attribute definition", + line_num, + ) + + has_tag, encrypt = self._parse_attribute_flags(tokens, line_num) + name, code, datatype = tokens[1:4] + + if datatype == "concat" and self.cur_vendor != self.rfc_vendor: + raise ParseError( + filename, + 'vendor attributes are not allowed to have the datatype "concat"', + line_num, + ) + + try: + codes = _parse_attribute_code(code) + except ValueError: + raise ParseError( + filename, f'invalid attribute code {code}""', line_num + ) + + for code in codes: + tlength = self.cur_vendor.tlength + if 2 ** (8 * tlength) <= code: + raise ParseError( + filename, + f"attribute code is too big, must be smaller than 2**{tlength}", + line_num, + ) + if code < 0: + raise ParseError( + filename, + "negative attribute codes are not allowed", + line_num, + ) + + # TODO: Do we some explicit handling of tlvs? + # if len(codes) > 1: + # self._parse_attribute_tlv(codes, line_num) + # else: + # pass + + base_datatype = datatype.split("[")[0].replace("-", "") + try: + attribute_type = Datatype[base_datatype] + except KeyError: + raise ParseError(filename, f"Illegal type: {datatype}", line_num) + + attribute = Attribute( + name, + codes[-1], + attribute_type, + has_tag, + encrypt, + len(codes) > 1, + {}, + ) + + attrcode = codes[0] if len(codes) == 1 else tuple(codes) + self.cur_vendor.attrs[attrcode] = attribute + + if self.cur_vendor != self.rfc_vendor: + codes = tuple([26] + codes) + attrcode = codes[0] if len(codes) == 1 else tuple(codes) + self.attrindex[attrcode] = attribute + self.attrindex[name] = attribute + + def _parse_value(self, tokens: Sequence[str], line_num: int): + """Parse an ATTRIBUTE line of (Free)RADIUS dictionaries.""" + filename = self.filestack[-1] + if len(tokens) != 4: + raise ParseError( + filename, + "Incorrect number of tokens for VALUE definition", + line_num, + ) + + (attr_name, key, value) = tokens[1:] + value = _parse_number(value) + + attribute = self.attrindex[attr_name] + try: + datatype = str(attribute.datatype).split(".")[1] + lmin, lmax = INTEGER_TYPES[datatype] + if value < lmin or value > lmax: + raise ParseError( + filename, + f"VALUE {key}({value}) is not in the limit of type {datatype}", + line_num, + ) + except KeyError: + raise ParseError( + filename, + f"only attributes with integer typed datatypes can have" + f"value definitions {attribute.datatype}", + line_num, + ) + attribute.values[value] = key + attribute.values[key] = value diff --git a/src/pyrad3/host.py b/src/pyrad3/host.py new file mode 100644 index 0000000..e7d7c6e --- /dev/null +++ b/src/pyrad3/host.py @@ -0,0 +1,46 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Interface Class for RADIUS Clients and Servers""" + +from pyrad3.dictionary import Dictionary +from pyrad3 import packet + + +class Host: # pylint: disable=too-many-arguments + """Interface Class for RADIUS Clients and Servers""" + def __init__( + self, + secret: bytes, + radius_dict: Dictionary, + authport: int = 1812, + acctport: int = 1813, + coaport: int = 3799, + timeout: float = 30, + retries: int = 3, + ): + self.secret = secret + self.dictionary = radius_dict + + self.authport = authport + self.acctport = acctport + self.coaport = coaport + + self.timeout = timeout + self.retries = retries + + def create_packet(self, **kwargs): + """Create a generic RADIUS Packet""" + return packet.Packet(self, **kwargs) + + def create_auth_packet(self, **kwargs): + """Create an Authentictaion packet (request per default)""" + return packet.AuthPacket(self, **kwargs) + + def create_acct_packet(self, **kwargs): + """Create an Accounting packet (request per default)""" + return packet.AcctPacket(self, **kwargs) + + def create_coa_packet(self, **kwargs): + """Create an CoA packet (requset per default)""" + return packet.CoAPacket(self, **kwargs) diff --git a/src/pyrad3/packet.py b/src/pyrad3/packet.py new file mode 100644 index 0000000..76d3764 --- /dev/null +++ b/src/pyrad3/packet.py @@ -0,0 +1,305 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +from collections import OrderedDict +from enum import IntEnum +from secrets import token_bytes +from typing import Any, Dict, Optional, Sequence + +import hashlib +import hmac + +from pyrad3.host import Host +from pyrad3.utils import ( + PacketError, + Attribute, + parse_header, + parse_attributes, + calculate_authenticator, + validate_pap_password, + validate_chap_password, +) + +HMAC = hmac.new + + +# Packet codes +class Code(IntEnum): + AccessRequest = 1 + AccessAccept = 2 + AccessReject = 3 + AccountingRequest = 4 + AccountingResponse = 5 + AccessChallenge = 11 + StatusServer = 12 + StatusClient = 13 + DisconnectRequest = 40 + DisconnectACK = 41 + DisconnectNAK = 42 + CoARequest = 43 + CoAACK = 44 + CoANAK = 45 + + +class AuthError(Exception): + pass + + +class Packet(OrderedDict): + def __init__( + self, + host: Host, + code: Code, + radius_id: int, + *, + request: "Packet" = None, + **attributes + ): + super().__init__(**attributes) + self.code = code + self.id = radius_id + self.host = host + self.request = request + self.ordered_attributes: Sequence[Attribute] = [] + self.raw_packet: Optional[bytearray] = None + self.authenticator: Optional[bytes] = None + + @staticmethod + def from_raw(host: Host, raw_packet: bytearray) -> "Packet": + (code, radius_id, _length, authenticator) = parse_header(raw_packet) + + ordered_attrs = parse_attributes(host.dictionary, raw_packet) + + # Can we do better than Any with type hinting? + attrs: Dict[str, Any] = {} + for attr in ordered_attrs: + try: + attrs[attr.name].append(attr.value) + except KeyError: + attrs[attr.name] = [attr.value] + + parsed_packet = Packet(host, code, radius_id, **attrs) + parsed_packet.authenticator = authenticator + parsed_packet.raw_packet = raw_packet + parsed_packet.ordered_attributes = ordered_attrs + return parsed_packet + + def from_raw_reply(self, raw_packet: bytearray) -> "Packet": + self.verify_reply(raw_packet) + reply = Packet.from_raw(self.host, raw_packet) + reply.request = self + try: + if not reply.validate_message_authenticator(): + raise PacketError("Packet has a wrong message authenticator") + except KeyError: + if "EAP-Message" in reply: + raise PacketError("Packet is missing a message authenticator") + return reply + + def send(self): + """Send the packet to the Client/Server. + """ + self.host._send_packet(self) + + def verify_reply(self, raw_reply: bytes): + """Verify the reply to this packet. + """ + if self.id != raw_reply[1]: + raise PacketError("Response has a wrong id") + + # self.authenticator MUST be set, this packet got send so by definitation + # self.authenticator will not be non, but bytes + radius_hash = calculate_authenticator( + self.host.secret, + self.authenticator, # type: ignore + raw_reply, + ) + + if radius_hash != raw_reply[4:20]: + raise PacketError("Reply Packet has a wrong authenticator") + + def validate_message_authenticator(self): + message_authenticator = self["Message-Authenticator"] + if isinstance(list, message_authenticator): + # There are multiple Message Authenticators, but a packet MUST NOT have + # more than one + return False + ma_attribute = self.find_first_attribute("Message-Authenticator") + generated = self._generate_message_authenticator(ma_attribute) + return message_authenticator == generated + + def _generate_message_authenticator(self, ma_attr: Attribute): + assert self.authenticator is not None + assert self.request is not None + assert self.request.authenticator is not None + assert self.raw_packet is not None + + # The message authenticator must be treated as 16 * \00 + start_pos = ma_attr.pos + 2 + end_pos = start_pos + 16 + original_ma: bytes = ma_attr.value + self.raw_packet[start_pos:end_pos] = 16 * b"\00" + + hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5) + hmac_builder.update(self.raw_packet) + + if self.code in (Code.AccessRequest, Code.StatusServer): + hmac_builder.update(self.authenticator) + elif self.code in ( + Code.AccessAccept, + Code.AccessChallenge, + Code.AccessReject, + ): + hmac_builder.update(self.request.authenticator) + else: + hmac_builder.update(16 * b"\00") + + hmac_builder.update(self.raw_packet[20:]) + self.raw_packet[start_pos:end_pos] = original_ma + return hmac_builder.digest() + + def add_message_authenticator(self): + self._encode_packet() + self._generate_message_authenticator(self) + try: + # quick lookup before we iterate over the whole packet + _ = self["Message-Authenticator"] + attr = self.find_first_attribute("Message-Authenticator") + except KeyError: + self["Message-Authenticator"] = 16 * b"\00" + attr = self.ordered_attributes[-1] + generated = self._generate_message_authenticator(attr) + self[attr.pos + 2 :] = generated + + def refresh_message_authenticator(self): + self.add_message_authenticator() + + def find_first_attribute(self, attr_type_name: str) -> Attribute: + for attr in self.ordered_attributes: + if attr.type == attr_type_name: + return attr.type + raise KeyError + + def _encode_packet(self): + self.raw_packet = None + + +class AuthPacket(Packet): + def __init__( + self, + host: Host, + radius_id: int, + auth_type, + *, + code: Code = Code.AccessRequest, + request: Optional[Packet] = None, + **attributes + ): + super().__init__(host, code, radius_id, request=request, **attributes) + self.auth_type = auth_type + if code == Code.AccessRequest: + self.authenticator = token_bytes(16) + + def create_accept(self, **attributes): + return AuthPacket( + self.host, + self.id, + self.auth_type, + request=self, + code=Code.AccessAccept, + **attributes + ) + + def create_reject(self, **attributes): + return AuthPacket( + self.host, + self.id, + self.auth_type, + request=self, + code=Code.AccessReject, + **attributes + ) + + def create_challange(self, **attributes): + return AuthPacket( + self.host, + self.id, + self.auth_type, + request=self, + code=Code.AccessChallenge, + **attributes + ) + + def validate_password(self, password: bytes) -> bool: + try: + return self.validate_pap(password) + except KeyError: + pass + # Will throw KeyError if no chap password exists + return self.validate_chap(password) + + def validate_pap(self, password: bytes) -> bool: + packet_password = self["User-Password"] + return validate_pap_password( + self.host.secret, + self.authenticator, # type: ignore + packet_password, + password, + ) + + def validate_chap(self, password: bytes) -> bool: + packet_password = self["Chap-Password"] + chap_id = packet_password[:1] + chap_password = packet_password[1:] + try: + challenge = self["Chap-Challenge"] + except KeyError: + challenge = self.authenticator + return validate_chap_password( + chap_id, challenge, chap_password, password, # type: ignore + ) + + +class AcctPacket(Packet): + def __init__( + self, + host: Host, + radius_id: int, + *, + code: Code = Code.AccountingRequest, + request: Optional[Packet] = None, + **attributes + ): + super().__init__(host, code, radius_id, request=request, **attributes) + + def create_response(self, **attributes): + return AcctPacket( + self.host, + self.id, + code=Code.AccountingResponse, + request=self, + **attributes + ) + + +class CoAPacket(Packet): + def __init__( + self, + host: Host, + radius_id: int, + *, + code: Code = Code.CoARequest, + request: Optional[Packet] = None, + **attributes + ): + super().__init__(host, code, radius_id, request=request, **attributes) + + def create_ack(self, **attributes): + return CoAPacket( + self.host, self.id, code=Code.CoAACK, request=self, **attributes + ) + + def create_nack(self, **attributes): + return CoAPacket( + self.host, self.id, code=Code.CoANAK, request=self, **attributes + ) diff --git a/pyrad3/proxy.py b/src/pyrad3/proxy.py similarity index 84% rename from pyrad3/proxy.py rename to src/pyrad3/proxy.py index 7a7ee3e..20afb26 100644 --- a/pyrad3/proxy.py +++ b/src/pyrad3/proxy.py @@ -26,7 +26,8 @@ class Proxy(Server): self._fdmap[self._proxyfd.fileno()] = self._proxyfd self._poll.register( self._proxyfd.fileno(), - (select.POLLIN | select.POLLPRI | select.POLLERR)) + (select.POLLIN | select.POLLPRI | select.POLLERR), + ) def _handle_proxy_packet(self, pkt): """Process a packet received on the reply socket. @@ -38,12 +39,15 @@ class Proxy(Server): :type pkt: Packet class instance """ if pkt.source[0] not in self.hosts: - raise ServerPacketError('Received packet from unknown host') + raise ServerPacketError("Received packet from unknown host") pkt.secret = self.hosts[pkt.source[0]].secret - if pkt.code not in [packet.AccessAccept, packet.AccessReject, - packet.AccountingResponse]: - raise ServerPacketError('Received non-response on proxy socket') + if pkt.code not in [ + packet.AccessAccept, + packet.AccessReject, + packet.AccountingResponse, + ]: + raise ServerPacketError("Received non-response on proxy socket") def _process_input(self, fd): """Process available data. @@ -62,7 +66,8 @@ class Proxy(Server): """ if fd.fileno() == self._proxyfd.fileno(): pkt = self._grab_packet( - lambda data, s=self: s.CreatePacket(packet=data), fd) + lambda data, s=self: s.CreatePacket(packet=data), fd + ) self._handle_proxy_packet(pkt) else: Server._process_input(self, fd) diff --git a/pyrad3/server.py b/src/pyrad3/server.py similarity index 86% rename from pyrad3/server.py rename to src/pyrad3/server.py index a966243..f20b06f 100644 --- a/pyrad3/server.py +++ b/src/pyrad3/server.py @@ -9,13 +9,15 @@ from pyrad import host from pyrad import packet -LOGGER = logging.getLogger('pyrad') +LOGGER = logging.getLogger("pyrad") class RemoteHost: """Remote RADIUS capable host we can talk to.""" - def __init__(self, address, secret, name, authport=1812, acctport=1813, coaport=3799): + def __init__( + self, address, secret, name, authport=1812, acctport=1813, coaport=3799 + ): """Constructor. :param address: IP address @@ -62,10 +64,21 @@ class Server(host.Host): :cvar MaxPacketSize: maximum size of a RADIUS packet :type MaxPacketSize: integer """ + MaxPacketSize = 4096 - def __init__(self, addresses=[], authport=1812, acctport=1813, coaport=3799, - hosts=None, dict=None, auth_enabled=True, acct_enabled=True, coa_enabled=False): + def __init__( + self, + addresses=[], + authport=1812, + acctport=1813, + coaport=3799, + hosts=None, + dict=None, + auth_enabled=True, + acct_enabled=True, + coa_enabled=False, + ): """Constructor. :param addresses: IP addresses to listen on @@ -114,7 +127,7 @@ class Server(host.Host): """ results = set() try: - tmp = socket.getaddrinfo(addr, 'www') + tmp = socket.getaddrinfo(addr, "www") except socket.gaierror: return [] @@ -199,10 +212,10 @@ class Server(host.Host): """ if pkt.source[0] in self.hosts: pkt.secret = self.hosts[pkt.source[0]].secret - elif '0.0.0.0' in self.hosts: - pkt.secret = self.hosts['0.0.0.0'].secret + elif "0.0.0.0" in self.hosts: + pkt.secret = self.hosts["0.0.0.0"].secret else: - raise ServerPacketError('Received packet from unknown host') + raise ServerPacketError("Received packet from unknown host") def _handle_auth_packet(self, pkt): """Process a packet received on the authentication port. @@ -216,7 +229,8 @@ class Server(host.Host): self._add_secret(pkt) if pkt.code != packet.AccessRequest: raise ServerPacketError( - 'Received non-authentication packet on authentication port') + "Received non-authentication packet on authentication port" + ) self.HandleAuthPacket(pkt) def _handle_acct_packet(self, pkt): @@ -229,10 +243,13 @@ class Server(host.Host): :type pkt: Packet class instance """ self._add_secret(pkt) - if pkt.code not in [packet.AccountingRequest, - packet.AccountingResponse]: + if pkt.code not in [ + packet.AccountingRequest, + packet.AccountingResponse, + ]: raise ServerPacketError( - 'Received non-accounting packet on accounting port') + "Received non-accounting packet on accounting port" + ) self.HandleAcctPacket(pkt) def _handle_coa_packet(self, pkt): @@ -251,7 +268,7 @@ class Server(host.Host): elif pkt.code == packet.DisconnectRequest: self.HandleDisconnectPacket(pkt) else: - raise ServerPacketError('Received non-coa packet on coa port') + raise ServerPacketError("Received non-coa packet on coa port") def _grab_packet(self, pktgen, fd): """Read a packet from a network connection. @@ -273,7 +290,9 @@ class Server(host.Host): """ for fd in self.authfds + self.acctfds + self.coafds: self._fdmap[fd.fileno()] = fd - self._poll.register(fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR) + self._poll.register( + fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR + ) if self.auth_enabled: self._realauthfds = list(map(lambda x: x.fileno(), self.authfds)) if self.acct_enabled: @@ -307,16 +326,22 @@ class Server(host.Host): :type fd: socket class instance """ if self.auth_enabled and fd.fileno() in self._realauthfds: - pkt = self._grab_packet(lambda data, s=self: s.CreateAuthPacket(packet=data), fd) + pkt = self._grab_packet( + lambda data, s=self: s.CreateAuthPacket(packet=data), fd + ) self._handle_auth_packet(pkt) elif self.acct_enabled and fd.fileno() in self._realacctfds: - pkt = self._grab_packet(lambda data, s=self: s.CreateAcctPacket(packet=data), fd) + pkt = self._grab_packet( + lambda data, s=self: s.CreateAcctPacket(packet=data), fd + ) self._handle_acct_packet(pkt) elif self.coa_enabled: - pkt = self._grab_packet(lambda data, s=self: s.CreateCoAPacket(packet=data), fd) + pkt = self._grab_packet( + lambda data, s=self: s.CreateCoAPacket(packet=data), fd + ) self._handle_coa_packet(pkt) else: - raise ServerPacketError('Received packet for unknown handler') + raise ServerPacketError("Received packet for unknown handler") def Run(self): """Main loop. @@ -335,8 +360,8 @@ class Server(host.Host): fdo = self._fdmap[fd] self._process_input(fdo) except ServerPacketError as err: - LOGGER.info('Dropping packet: %s', err) + LOGGER.info("Dropping packet: %s", err) except packet.PacketError as err: - LOGGER.info('Received a broken packet: %s', err) + LOGGER.info("Received a broken packet: %s", err) else: - LOGGER.error('Unexpected event in server main loop') + LOGGER.error("Unexpected event in server main loop") diff --git a/pyrad3/server_async.py b/src/pyrad3/server_async.py similarity index 53% rename from pyrad3/server_async.py rename to src/pyrad3/server_async.py index fb72da4..2476bde 100644 --- a/pyrad3/server_async.py +++ b/src/pyrad3/server_async.py @@ -9,26 +9,37 @@ from abc import abstractmethod, ABCMeta from enum import Enum from datetime import datetime from pyrad.packet import ( - Packet, AccessAccept, AccessReject, - AccountingRequest, AccountingResponse, - DisconnectACK, DisconnectNAK, DisconnectRequest, CoARequest, - CoAACK, CoANAK, AccessRequest, AuthPacket, AcctPacket, CoAPacket, - PacketError + Packet, + AccessAccept, + AccessReject, + AccountingRequest, + AccountingResponse, + DisconnectACK, + DisconnectNAK, + DisconnectRequest, + CoARequest, + CoAACK, + CoANAK, + AccessRequest, + AuthPacket, + AcctPacket, + CoAPacket, + PacketError, ) from pyrad.server import ServerPacketError class ServerType(Enum): - Auth = 'Authentication' - Acct = 'Accounting' - Coa = 'Coa' + Auth = "Authentication" + Acct = "Accounting" + Coa = "Coa" class DatagramProtocolServer(asyncio.Protocol): - - def __init__(self, ip, port, logger, server, server_type, hosts, - request_callback): + def __init__( + self, ip, port, logger, server, server_type, hosts, request_callback + ): self.transport = None self.ip = ip self.port = port @@ -40,94 +51,149 @@ class DatagramProtocolServer(asyncio.Protocol): def connection_made(self, transport): self.transport = transport - self.logger.info('[%s:%d] Transport created', self.ip, self.port) + self.logger.info("[%s:%d] Transport created", self.ip, self.port) def connection_lost(self, exc): if exc: - self.logger.warn('[%s:%d] Connection lost: %s', self.ip, self.port, str(exc)) + self.logger.warn( + "[%s:%d] Connection lost: %s", self.ip, self.port, str(exc) + ) else: - self.logger.info('[%s:%d] Transport closed', self.ip, self.port) + self.logger.info("[%s:%d] Transport closed", self.ip, self.port) def send_response(self, reply, addr): self.transport.sendto(reply.ReplyPacket(), addr) def datagram_received(self, data, addr): - self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) + self.logger.debug( + "[%s:%d] Received %d bytes from %s", + self.ip, + self.port, + len(data), + addr, + ) receive_date = datetime.utcnow() if addr[0] in self.hosts: remote_host = self.hosts[addr[0]] - elif '0.0.0.0' in self.hosts: - remote_host = self.hosts['0.0.0.0'] + elif "0.0.0.0" in self.hosts: + remote_host = self.hosts["0.0.0.0"] else: - self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr) + self.logger.warn( + "[%s:%d] Drop package from unknown source %s", + self.ip, + self.port, + addr, + ) return try: - self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex()) + self.logger.debug( + "[%s:%d] Received from %s packet: %s", + self.ip, + self.port, + addr, + data.hex(), + ) req = Packet(packet=data, dict=self.server.dict) except Exception as exc: - self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc) + self.logger.error( + "[%s:%d] Error on decode packet: %s", self.ip, self.port, exc + ) return try: - if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK): - raise ServerPacketError(f'Invalid response packet {req.code}') + if req.code in ( + AccountingResponse, + AccessAccept, + AccessReject, + CoANAK, + CoAACK, + DisconnectNAK, + DisconnectACK, + ): + raise ServerPacketError(f"Invalid response packet {req.code}") elif self.server_type == ServerType.Auth: if req.code != AccessRequest: - raise ServerPacketError('Received non-auth packet on auth port') - req = AuthPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) + raise ServerPacketError( + "Received non-auth packet on auth port" + ) + req = AuthPacket( + secret=remote_host.secret, + dict=self.server.dict, + packet=data, + ) if self.server.enable_pkt_verify: if req.VerifyAuthRequest(): - raise PacketError('Packet verification failed') + raise PacketError("Packet verification failed") elif self.server_type == ServerType.Coa: if req.code != DisconnectRequest and req.code != CoARequest: - raise ServerPacketError('Received non-coa packet on coa port') - req = CoAPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) + raise ServerPacketError( + "Received non-coa packet on coa port" + ) + req = CoAPacket( + secret=remote_host.secret, + dict=self.server.dict, + packet=data, + ) if self.server.enable_pkt_verify: if req.VerifyCoARequest(): - raise PacketError('Packet verification failed') + raise PacketError("Packet verification failed") elif self.server_type == ServerType.Acct: if req.code != AccountingRequest: - raise ServerPacketError('Received non-acct packet on acct port') - req = AcctPacket(secret=remote_host.secret, - dict=self.server.dict, - packet=data) + raise ServerPacketError( + "Received non-acct packet on acct port" + ) + req = AcctPacket( + secret=remote_host.secret, + dict=self.server.dict, + packet=data, + ) if self.server.enable_pkt_verify: if req.VerifyAcctRequest(): - raise PacketError('Packet verification failed') + raise PacketError("Packet verification failed") # Call request callback self.request_callback(self, req, addr) except Exception as exc: if self.server.debug: - self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr) + self.logger.exception( + "[%s:%d] Error for packet from %s", self.ip, self.port, addr + ) else: - self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc) + self.logger.error( + "[%s:%d] Error for packet from %s: %s", + self.ip, + self.port, + addr, + exc, + ) process_date = datetime.utcnow() - self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000) + self.logger.debug( + "[%s:%d] Request from %s processed in %d ms", + self.ip, + self.port, + addr, + (process_date - receive_date).microseconds / 1000, + ) def error_received(self, exc): - self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc) + self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc) async def close_transport(self): if self.transport: - self.logger.debug('[%s:%d] Close transport...', self.ip, self.port) + self.logger.debug("[%s:%d] Close transport...", self.ip, self.port) self.transport.close() self.transport = None def __str__(self): - return f'DatagramProtocolServer(ip={self.ip}, port={self.port})' + return f"DatagramProtocolServer(ip={self.ip}, port={self.port})" # Used as protocol_factory def __call__(self): @@ -135,12 +201,18 @@ class DatagramProtocolServer(asyncio.Protocol): class ServerAsync(metaclass=ABCMeta): - - def __init__(self, auth_port=1812, acct_port=1813, - coa_port=3799, hosts=None, dictionary=None, - loop=None, logger_name='pyrad', - enable_pkt_verify=False, - debug=False): + def __init__( + self, + auth_port=1812, + acct_port=1813, + coa_port=3799, + hosts=None, + dictionary=None, + loop=None, + logger_name="pyrad", + enable_pkt_verify=False, + debug=False, + ): if not loop: self.loop = asyncio.get_event_loop() @@ -174,20 +246,35 @@ class ServerAsync(metaclass=ABCMeta): self.handle_acct_packet(protocol, req, addr) elif protocol.server_type == ServerType.Auth: self.handle_auth_packet(protocol, req, addr) - elif protocol.server_type == ServerType.Coa and \ - req.code == CoARequest: + elif ( + protocol.server_type == ServerType.Coa + and req.code == CoARequest + ): self.handle_coa_packet(protocol, req, addr) - elif protocol.server_type == ServerType.Coa and \ - req.code == DisconnectRequest: + elif ( + protocol.server_type == ServerType.Coa + and req.code == DisconnectRequest + ): self.handle_disconnect_packet(protocol, req, addr) else: - self.logger.error('[%s:%s] Unexpected request found', protocol.ip, protocol.port) + self.logger.error( + "[%s:%s] Unexpected request found", + protocol.ip, + protocol.port, + ) except Exception as exc: if self.debug: - self.logger.exception('[%s:%s] Unexpected error', protocol.ip, protocol.port) + self.logger.exception( + "[%s:%s] Unexpected error", protocol.ip, protocol.port + ) else: - self.logger.error('[%s:%s] Unexpected error: %s', protocol.ip, protocol.port, exc) + self.logger.error( + "[%s:%s] Unexpected error: %s", + protocol.ip, + protocol.port, + exc, + ) def __is_present_proto__(self, ip, port): if port == self.auth_port: @@ -217,87 +304,92 @@ class ServerAsync(metaclass=ABCMeta): reply = pkt.CreateReply(**attributes) return reply - async def initialize_transports(self, enable_acct=False, - enable_auth=False, enable_coa=False, - addresses=None): + async def initialize_transports( + self, + enable_acct=False, + enable_auth=False, + enable_coa=False, + addresses=None, + ): task_list = [] if not enable_acct and not enable_auth and not enable_coa: - raise Exception('No transports selected') + raise Exception("No transports selected") if not addresses or len(addresses) == 0: - addresses = ['127.0.0.1'] + addresses = ["127.0.0.1"] # noinspection SpellCheckingInspection for addr in addresses: - if enable_acct and not self.__is_present_proto__(addr, self.acct_port): + if enable_acct and not self.__is_present_proto__( + addr, self.acct_port + ): protocol_acct = DatagramProtocolServer( addr, self.acct_port, - self.logger, self, + self.logger, + self, ServerType.Acct, self.hosts, - self.__request_handler__ + self.__request_handler__, ) bind_addr = (addr, self.acct_port) acct_connect = self.loop.create_datagram_endpoint( - protocol_acct, - reuse_port=True, - local_addr=bind_addr + protocol_acct, reuse_port=True, local_addr=bind_addr ) self.acct_protocols.append(protocol_acct) task_list.append(acct_connect) - if enable_auth and not self.__is_present_proto__(addr, self.auth_port): + if enable_auth and not self.__is_present_proto__( + addr, self.auth_port + ): protocol_auth = DatagramProtocolServer( addr, self.auth_port, - self.logger, self, + self.logger, + self, ServerType.Auth, self.hosts, - self.__request_handler__ + self.__request_handler__, ) bind_addr = (addr, self.auth_port) auth_connect = self.loop.create_datagram_endpoint( - protocol_auth, - reuse_port=True, - local_addr=bind_addr + protocol_auth, reuse_port=True, local_addr=bind_addr ) self.auth_protocols.append(protocol_auth) task_list.append(auth_connect) - if enable_coa and not self.__is_present_proto__(addr, self.coa_port): + if enable_coa and not self.__is_present_proto__( + addr, self.coa_port + ): protocol_coa = DatagramProtocolServer( addr, self.coa_port, - self.logger, self, + self.logger, + self, ServerType.Coa, self.hosts, - self.__request_handler__ + self.__request_handler__, ) bind_addr = (addr, self.coa_port) coa_connect = self.loop.create_datagram_endpoint( - protocol_coa, - reuse_port=True, - local_addr=bind_addr + protocol_coa, reuse_port=True, local_addr=bind_addr ) self.coa_protocols.append(protocol_coa) task_list.append(coa_connect) await asyncio.ensure_future( - asyncio.gather( - *task_list, - return_exceptions=False, - ), - loop=self.loop + asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop ) # noinspection SpellCheckingInspection - async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True): + async def deinitialize_transports( + self, deinit_coa=True, deinit_auth=True, deinit_acct=True + ): if deinit_coa: for proto in self.coa_protocols: diff --git a/src/pyrad3/tools.py b/src/pyrad3/tools.py new file mode 100644 index 0000000..4cd40e4 --- /dev/null +++ b/src/pyrad3/tools.py @@ -0,0 +1,243 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Collections of functions to en- and decode RADIUS Attributes""" + +from typing import Union +from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address + +import struct + + +def encode_string(string: str) -> bytes: + """Encode a RADIUS value of type string""" + if len(string) > 253: + raise ValueError("Can only encode strings of <= 253 characters") + if isinstance(string, str): + return string.encode("utf-8") + return string + + +def encode_octets(string: bytes) -> bytes: + """Encode a RADIUS value of type octet""" + if len(string) > 253: + raise ValueError("Can only encode strings of <= 253 characters") + return string + + +def encode_address(addr: Union[str, IPv4Address]) -> bytes: + """Encode a RADIUS value of type ipaddr""" + return IPv4Address(addr).packed + + +def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes: + """Encode a RADIUS value of type ipv6prefix""" + address = IPv6Network(addr) + return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed + + +def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes: + """Encode a RADIUS value of type ipv6addr""" + return IPv6Address(addr).packed + + +def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes: + """Encode a RADIUS value of type combo-ip""" + return ip_address(addr).packed + + +def encode_ascend_binary(string: str) -> bytes: + """ + struct_format: List of type=value pairs sperated by spaces. + + Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32' + + Type: + family ipv4(default) or ipv6 + action discard(default) or accept + direction in(default) or out + src source prefix (default ignore) + dst destination prefix (default ignore) + proto protocol number / next-header number (default ignore) + sport source port (default ignore) + dport destination port (default ignore) + sportq source port qualifier (default 0) + dportq destination port qualifier (default 0) + + Source/Destination Port Qualifier: + 0 no compare + 1 less than + 2 equal to + 3 greater than + 4 not equal to + """ + + terms = { + "family": b"\x01", + "action": b"\x00", + "direction": b"\x01", + "src": b"\x00\x00\x00\x00", + "dst": b"\x00\x00\x00\x00", + "srcl": b"\x00", + "dstl": b"\x00", + "proto": b"\x00", + "sport": b"\x00\x00", + "dport": b"\x00\x00", + "sportq": b"\x00", + "dportq": b"\x00", + } + + for term in string.split(" "): + key, value = term.split("=") + if key == "family" and value == "ipv6": + terms[key] = b"\x03" + if terms["src"] == b"\x00\x00\x00\x00": + terms["src"] = 16 * b"\x00" + if terms["dst"] == b"\x00\x00\x00\x00": + terms["dst"] = 16 * b"\x00" + elif key == "action" and value == "accept": + terms[key] = b"\x01" + elif key == "direction" and value == "out": + terms[key] = b"\x00" + elif key in ("src", "dst"): + address = ip_network(value) + terms[key] = address.network_address.packed + terms[key + "l"] = struct.pack("B", address.prefixlen) + elif key in ("sport", "dport"): + terms[key] = struct.pack("!H", int(value)) + elif key in ("sportq", "dportq", "proto"): + terms[key] = struct.pack("B", int(value)) + + trailer = 8 * b"\x00" + + result = b"".join( + ( + terms["family"], + terms["action"], + terms["direction"], + b"\x00", + terms["src"], + terms["dst"], + terms["srcl"], + terms["dstl"], + terms["proto"], + b"\x00", + terms["sport"], + terms["dport"], + terms["sportq"], + terms["dportq"], + b"\x00\x00", + trailer, + ) + ) + return result + + +def encode_integer(num: Union[int, str], struct_format="!I") -> bytes: + """Encode a RADIUS value of some type integer""" + return struct.pack(struct_format, int(num)) + + +def encode_date(num: Union[int, str]) -> bytes: + """Encode a RADIUS value of type date""" + return struct.pack("!I", int(num)) + + +def decode_string(string: bytes) -> Union[str, bytes]: + """Decode a RADIUS value of type string""" + try: + return string.decode("utf-8") + except UnicodeDecodeError: + return string + + +def decode_octets(string: bytes) -> bytes: + """Decode a RADIUS value of type octet""" + return string + + +def decode_address(addr: bytes) -> IPv4Address: + """Decode a RADIUS value of type ipaddr""" + return IPv4Address(addr) + + +def decode_ipv6_prefix(addr: bytes) -> IPv6Network: + """Decode a RADIUS value of type ipv6prefix""" + addr = addr + b"\x00" * (18 - len(addr)) + prefix = addr[:2] + addr = addr[2:] + return IPv6Network((prefix, addr)) + + +def decode_ipv6_address(addr: bytes) -> IPv6Address: + """Decode a RADIUS value of type ipv6addr""" + addr = addr + b"\x00" * (16 - len(addr)) + return IPv6Address(addr) + + +def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]: + """Decode a RADIUS value of type combo-ip""" + return ip_address(addr).packed + + +def decode_ascend_binary(string): + """Decode a RADIUS value of type abinary""" + raise NotImplementedError + + +def decode_integer(num: bytes, struct_format="!I") -> int: + """Decode a RADIUS value of some integer type""" + return (struct.unpack(struct_format, num))[0] + + +def decode_date(num): + """Decode a RADIUS value of type date""" + return (struct.unpack("!I", num))[0] + + +ENCODE_MAP = { + "string": encode_string, + "octets": encode_octets, + "integer": encode_integer, + "ipaddr": encode_address, + "ipv6prefix": encode_ipv6_prefix, + "ipv6addr": encode_ipv6_address, + "abinary": encode_ascend_binary, + "signed": lambda value: encode_integer(value, "!i"), + "short": lambda value: encode_integer(value, "!H"), + "byte": lambda value: encode_integer(value, "!B"), + "integer64": lambda value: encode_integer(value, '!Q'), + "date": encode_date, +} + + +def encode_attr(datatype, value): + """Encode a RADIUS attribute""" + try: + return ENCODE_MAP[datatype](value) + except KeyError: + raise ValueError(f"Unknown attribute type {datatype}") + + +DECODE_MAP = { + "string": decode_string, + "octets": decode_octets, + "integer": decode_integer, + "ipaddr": decode_address, + "ipv6prefix": decode_ipv6_prefix, + "ipv6addr": decode_ipv6_address, + "abinary": decode_ascend_binary, + "signed": lambda value: decode_integer(value, "!i"), + "short": lambda value: decode_integer(value, "!H"), + "byte": lambda value: decode_integer(value, "!B"), + "integer64": lambda value: decode_integer(value, "!Q"), + "date": decode_date, +} + + +def decode_attr(datatype, value): + """Decode a RADIUS attribute""" + try: + return DECODE_MAP[datatype](value) + except KeyError: + raise ValueError(f"Unknown attribute type {datatype}") diff --git a/src/pyrad3/utils.py b/src/pyrad3/utils.py new file mode 100644 index 0000000..7ef2395 --- /dev/null +++ b/src/pyrad3/utils.py @@ -0,0 +1,234 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Collection of functions to deal with RADIUS packet en- and decoding.""" + +from collections import namedtuple +from typing import List, Union + +import hashlib +import secrets +import struct + +from pyrad3.dictionary import Dictionary + +RANDOM_GENERATOR = secrets.SystemRandom() +MD5 = hashlib.md5 + + +class PacketError(Exception): + """Exception for Invalid Packets""" + + +Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"]) +Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"]) + + +def parse_header(raw_packet: bytes) -> Header: + """Parse the Header of a RADIUS Packet.""" + try: + header = struct.unpack("!BBH16s", raw_packet) + except struct.error: + raise PacketError("Packet header is corrupt") + + length = header[3] + if len(raw_packet) != length: + raise PacketError( + f"RADIUS Packet ({len(raw_packet)}) has an invalid length ({length})" + ) + if length > 4096: + raise PacketError(f"Packet length is too big ({length})") + return Header(*header) + + +def parse_attributes( + rad_dict: Dictionary, raw_packet: bytes +) -> List[Attribute]: + """Parse the Attributes in a RADIUS Packet. + + This function skips the Header. The Header must be parsed and verified + separately. + """ + attributes = [] + packet = raw_packet[20:] + + while packet: + try: + (key, length) = struct.unpack("!BB", packet[0:2]) + except struct.error: + raise PacketError("Attribute header is corrupt") + if length < 2: + raise PacketError(f"Attribute length ({length}) is too small") + + value = packet[2:length] + offset = len(raw_packet) - len(packet) + length + if key == 26: + try: + attributes.extend( + parse_vendor_attributes(rad_dict, offset, value) + ) + except (PacketError, IndexError): + attributes.append( + Attribute( + name="Unknown-Vendor-Attribute", + pos=offset, + type="octets", + length=int(packet[1]), + value=packet[2:], + ) + ) + else: + key = parse_key(rad_dict, key) + attributes.extend(parse_value(rad_dict, key, offset, value)) + packet = packet[length:] + + return attributes + + +def parse_vendor_attributes( + rad_dict: Dictionary, offset: int, vendor_value: bytes +) -> List[Attribute]: + """Parse A Vendor Attribute""" + if len(vendor_value) < 4: + raise PacketError + vendor_id = int.from_bytes(vendor_value[:4], "big") + vendor_dict = rad_dict.vendor_by_id[vendor_id] + vendor_name = vendor_dict.name + + attributes = [] + vendor_tlv = vendor_value[4:] + while vendor_tlv: + try: + (key, length) = struct.unpack("!BB", vendor_tlv[0:2]) + except struct.error: + attribute = [ + Attribute( + name=f"Unknown-{vendor_name}-Attribute", + pos=offset - len(vendor_value), + type="octets", + length=len(vendor_value) - 4, + value=vendor_value, + ) + ] + else: + offset = offset - len(vendor_tlv) + length + key = parse_key(vendor_dict, key) + attribute = parse_value(vendor_dict, key, offset, vendor_tlv) + attributes.extend(attribute) + vendor_tlv = vendor_tlv[length:] + return attributes + + +def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]: + """Parse the key in the Dictionary Context""" + try: + return rad_dict.attrs[key_id].name + except KeyError: + return key_id + + +def parse_value( + rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes +) -> List[Attribute]: + """Parse the Value in the given Key/Dictionary Context""" + return [] + + +def calculate_authenticator( + secret: bytes, authenticator: bytes, raw_packet: bytes +) -> bytes: + """Calculate the Authenticator for the RADIUS Packet""" + return MD5( + raw_packet[0:4] + authenticator + raw_packet[20:] + secret + ).digest() + + +def validate_pap_password( + secret: bytes, + authenticator: bytes, + obfuscated_password: bytes, + plaintext_password: bytes, +) -> bool: + """Check if the plaintext and the RADIUS passwords match + This function does not "decrypt" the received password. + """ + obf_pass = password_encode(secret, authenticator, plaintext_password) + return obfuscated_password == obf_pass + + +def password_encode( + secret: bytes, authenticator: bytes, password: bytes +) -> bytes: + """Obfuscate the plaintext Password for RADIUS""" + password += b"\x00" * (16 - (len(password) % 16)) + return obfuscation_algorithm(secret, authenticator, password) + + +def password_decode( + secret: bytes, authenticator: bytes, obfuscated_password: bytes +) -> str: + """Reverse the RADIUS obfuscation on a given password + + The password password is padded with \\x00 to a 16 byte boundary. The padding will + be removed by this function. + If the original password had some trailing \\x00 it will get lost. Therefore it is + not recommended to use (trailing) \\x00 in passwords. + """ + deobfuscated = obfuscation_algorithm( + secret, authenticator, obfuscated_password + ) + return deobfuscated.rstrip(b"\x00").decode("utf-8") + + +def obfuscation_algorithm( + secret: bytes, authenticator: bytes, password: bytes +) -> bytes: + """Obfuscate the plaintext password. + + This function does not deal with the padding (which the + RADIUS Protocol requires.) + The User has to pad the password themself, or better use + the `password_encode` or `password_decode` function. + """ + result = b"" + buf = password + last = authenticator + + while buf: + cur_hash = MD5(secret + last).digest() + for cbuf, chash in zip(buf, cur_hash): + result += bytes([cbuf ^ chash]) + (last, buf) = (buf[:16], buf[16:]) + + return result + + +def validate_chap_password( + chap_id: bytes, + challenge: bytes, + chap_password: bytes, + plaintext_password: bytes, +) -> bool: + """Validate the CHAP password against the given plaintext password""" + return ( + chap_password == MD5(chap_id + plaintext_password + challenge).digest() + ) + + +def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes: + """Salt Encrypt the given value""" + # The highest bit MUST be 1 + random_value = RANDOM_GENERATOR.randrange(32768, 65535) + salt = struct.pack("!H", random_value) + + salted_auth = authenticator + salt + + return obfuscation_algorithm(secret, salted_auth, value) + + +def salt_decrypt( + secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes +) -> bytes: + """Decrypt the given value""" + salted_auth = authenticator + salt + return obfuscation_algorithm(secret, salted_auth, encrypted_value) diff --git a/tests/dictionaries/mutual_recursive b/tests/dictionaries/mutual_recursive new file mode 100644 index 0000000..ac1bb6c --- /dev/null +++ b/tests/dictionaries/mutual_recursive @@ -0,0 +1 @@ +$INCLUDE ./other_mutual_recursive diff --git a/tests/dictionaries/other_mutual_recursive b/tests/dictionaries/other_mutual_recursive new file mode 100644 index 0000000..a9c01e2 --- /dev/null +++ b/tests/dictionaries/other_mutual_recursive @@ -0,0 +1 @@ +$INCLUDE mutual_recursive diff --git a/tests/dictionaries/self_recursive b/tests/dictionaries/self_recursive new file mode 100644 index 0000000..949709b --- /dev/null +++ b/tests/dictionaries/self_recursive @@ -0,0 +1 @@ +$INCLUDE ./self_recursive diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py new file mode 100644 index 0000000..ac56ec1 --- /dev/null +++ b/tests/test_dictionary.py @@ -0,0 +1,190 @@ +# Copyright 2020 Istvan Ruzman +# SPDX-License-Identifier: MIT OR Apache-2.0 + +from io import StringIO + +import pytest +from pyrad3.dictionary import Dictionary, ParseError + + +@pytest.mark.parametrize( + "filename", ["dictionaries/self_recursive", "dictionaries/mutual_recursive"] +) +def test_dictionary_recursion(filename): + with pytest.raises(ParseError): + Dictionary("tests/" + filename) + + +@pytest.mark.parametrize( + "line", + [ + "$INCLUDE", + "BEGIN-VENDOR", + "END-VENDOR", + "VENDOR", + "VENDOR NAME", + "ATTRIBUTE", + "ATTRIBUTE NAME", + "VALUE", + "VALUE ATTRNAME", + "VALUE ATTRNAME VALUENAME", + ], +) +def test_lines_missing_tokens(line): + dictionary = StringIO(line) + with pytest.raises(ParseError): + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "vendor", + [ + "VENDOR test 1234", + "VENDOR test 1234 format=1,1", + "VENDOR test 1234 format=2,2", + "VENDOR test 1234 format=1,2", + "VENDOR test 1234 format=4,2", + "VENDOR test 1234 format=4,0", + "VENDOR WiMAX 1234 format=1,1,c", + ], +) +def test_valid_vendor_definitions(vendor): + dictionary = StringIO(vendor) + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "vendor", + [ + "VENDOR test 1234 1,1", + "VENDOR test 1234 format=3,1", + "VENDOR test 1234 format=2", + "VENDOR test 1234 format=1,2,c", + "VENDOR test 1234 format=1,9", + "VENDOR test 1234 format=4,4 suffix", + "VENDOR test 1234 format=a,b suffix", + ], +) +def test_invalid_vendor_definitions(vendor): + dictionary = StringIO(vendor) + with pytest.raises(ParseError): + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "number", + [ + "ATTRIBUTE NAME 0x01 byte", + "ATTRIBUTE NAME 0x0001 byte", + "ATTRIBUTE NAME 0o123 byte", + "ATTRIBUTE NAME 5 byte", + ], +) +def test_valid_attribute_numbers(number): + dictionary = StringIO(number) + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "number", + [ + "ATTRIBUTE NAME 1234 byte", + "ATTRIBUTE NAME ABCD byte", + "ATTRIBUTE NAME -1 byte", + ], +) +def test_invalid_attribute_numbers(number): + dictionary = StringIO(number) + with pytest.raises(ParseError): + Dictionary("", dictionary) + + +@pytest.mark.parametrize("type_length", [1, 2, 4]) +def test_attribute_number_limits(type_length): + too_big = 2 ** (8 * type_length) + max_value = too_big - 1 + dictionary = StringIO( + f"VENDOR TEST 1234 format={type_length},1\n" + "BEGIN-VENDOR TEST\n" + f"ATTRIBUTE TEST {max_value} byte\n" + "END-VENDOR TEST\n" + ) + Dictionary("", dictionary) + dictionary = StringIO( + f"VENDOR TEST 1234 format={type_length},1\n" + "BEGIN-VENDOR TEST\n" + f"ATTRIBUTE TEST {too_big} byte\n" + "END-VENDOR TEST\n" + ) + with pytest.raises(ParseError): + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "value_definition", + [ + "VALUE TEST-ATTRIBUTE TEST-VALUE 1", + "VALUE TEST-ATTRIBUTE TEST-VALUE 0x1", + "VALUE TEST-ATTRIBUTE TEST-VALUE 0o1", + ], +) +def test_value_definition(value_definition): + dictionary = StringIO( + "\n".join(["ATTRIBUTE TEST-ATTRIBUTE 1 byte", value_definition]) + ) + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "value_num, attr_type", + [ + (0, "byte"), + (255, "byte"), + (0, "short"), + (2 ** 16 - 1, "short"), + (0, "integer"), + (2 ** 32 - 1, "integer"), + ((-(2 ** 31)), "signed"), + (2 ** 31 - 1, "signed"), + (0, "integer64"), + (2 ** 64 - 1, "integer64"), + ], +) +def test_value_number_within_limit(value_num, attr_type): + dictionary = StringIO( + "\n".join( + [ + f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}", + f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}", + ] + ) + ) + Dictionary("", dictionary) + + +@pytest.mark.parametrize( + "value_num, attr_type", + [ + (-1, "byte"), + (256, "byte"), + (-1, "short"), + (2 ** 16, "short"), + (-1, "integer"), + (2 ** 32, "integer"), + (2 ** 31, "signed"), + ((-(2 ** 31)) - 1, "signed"), + (-1, "integer64"), + (2 ** 64, "integer64"), + ], +) +def test_value_number_out_of_limit(value_num, attr_type): + dictionary = StringIO( + "\n".join( + [ + f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}", + f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}", + ] + ) + ) + with pytest.raises(ParseError): + Dictionary("", dictionary)