From 9fee11b500b06e127630f051610fe09c3c452740 Mon Sep 17 00:00:00 2001 From: Kai Vogelgesang Date: Tue, 20 Oct 2020 02:46:01 +0200 Subject: [PATCH] Refactor --- handler.py | 10 ++-- proto.py | 62 ---------------------- proto/__init__.py | 0 proto/fieldtype.py | 94 +++++++++++++++++++++++++++++++++ proto/packet.py | 67 ++++++++++++++++++++++++ proto/parser.py | 119 ++++++++++++++++++++++++++++++++++++++++++ protoparser.py | 127 --------------------------------------------- 7 files changed, 285 insertions(+), 194 deletions(-) delete mode 100644 proto.py create mode 100644 proto/__init__.py create mode 100644 proto/fieldtype.py create mode 100644 proto/packet.py create mode 100644 proto/parser.py delete mode 100644 protoparser.py diff --git a/handler.py b/handler.py index 9bf1a9d..5d8348c 100644 --- a/handler.py +++ b/handler.py @@ -1,7 +1,7 @@ import binascii import itertools -import proto +import proto.packet from colordiff import Color, Diff d_c = Diff(Color.GREEN, Color.RED) @@ -34,9 +34,9 @@ def handle2(tag: str, data: str): ignore = [ - proto.AckPacket, - proto.PingPacket, - proto.FinPacket, + proto.packet.AckPacket, + proto.packet.PingPacket, + proto.packet.FinPacket, ] @@ -46,7 +46,7 @@ def handle(tag: str, data: str): return d_bytes = binascii.unhexlify(data) - pkt = proto.Parser.parse_packet(d_bytes) + pkt = proto.packet.Parser.parse(d_bytes) for packet_type in ignore: if isinstance(pkt, packet_type): diff --git a/proto.py b/proto.py deleted file mode 100644 index b8feccc..0000000 --- a/proto.py +++ /dev/null @@ -1,62 +0,0 @@ -import enum -import binascii - -import protoparser -from protoparser import Packet, unknown - - -class HazelPacketType(bytes, enum.Enum): - UNRELIABLE = bytes([0]) - RELIABLE = bytes([1]) - - HELLO = bytes([8]) - PING = bytes([12]) - ACK = bytes([10]) - FIN = bytes([9]) - - FRAGMENT = bytes([11]) # not observed yet, maybe unused in among us? - - -def int_big_endian(data: bytes) -> int: - return int.from_bytes(data, "big") - - -Parser = protoparser.Parser() - - -@Parser.register(HazelPacketType.PING, ("nonce", 2, int_big_endian)) -class PingPacket(Packet): - def __init__(self, data, nonce): - self.nonce = nonce - super().__init__(data) - - def __repr__(self): - return f"Ping {self.nonce}" - - -@Parser.register(HazelPacketType.ACK, ("nonce", 2, int_big_endian), b"\xFF") -class AckPacket(Packet): - def __init__(self, data, nonce): - self.nonce = nonce - super().__init__(data) - - def __repr__(self): - return f"Ack {self.nonce}" - - -@Parser.register(HazelPacketType.FIN) -class FinPacket(Packet): - def __repr__(self): - return "Fin" - - -@Parser.register( - HazelPacketType.HELLO, - unknown(7), - ("name_len", 1, int_big_endian), - ("name", "name_len", bytes.decode), -) -class HelloPacket(Packet): - def __init__(self, data, name, **kwargs): - self.name = name - super().__init__(data) diff --git a/proto/__init__.py b/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/proto/fieldtype.py b/proto/fieldtype.py new file mode 100644 index 0000000..e6d0f87 --- /dev/null +++ b/proto/fieldtype.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +import binascii +from typing import * + +__all__ = [ + "Fixed", + "Unknown", + "Bytes", + "Ascii", + "Int", +] + +T = TypeVar("T") + +Extractor = Callable[[bytes], T] +Formatter = Callable[[T], str] + +default_bytes_formatter: Formatter[bytes] = lambda data: binascii.hexlify(data).decode() + + +def assert_extractor(pattern: bytes) -> Extractor[bytes]: + def inner(data: bytes) -> bytes: + assert data == pattern + return pattern + + return inner + + +Backref = str + + +@dataclass +class FieldType(Generic[T]): + """ + Dataclass to store all the information necessary to parse a packet field. + + Fields: + ------- + name : Optional[str] + Field Name. + length : Union[int, Backref] + How many bytes the field has. + Can be either an integer, or the name of a previous field containing one. + extractor : Extractor[T] + Function to parse the bytes of the field. + May throw Exceptions to signal parsing errors. + formatter : Formatter[T] + Function to get a string representation of the field. + + """ + + name: Optional[str] + length: Union[int, Backref] + extractor: Extractor[T] + formatter: Formatter[T] = str + + +def Fixed(pattern: bytes, **kwargs) -> FieldType[bytes]: + return FieldType(None, len(pattern), assert_extractor(pattern), **kwargs) + + +def Unknown( + length: Union[int, Backref], + name: Optional[str] = None, + formatter: Formatter[bytes] = default_bytes_formatter, +) -> FieldType[bytes]: + return FieldType(name, length, lambda data: data, formatter=formatter) + + +def Bytes( + name: str, + length: Union[int, Backref], + extractor: Extractor[bytes] = lambda data: data, + formatter: Formatter[bytes] = default_bytes_formatter, +) -> FieldType[bytes]: + return FieldType(name, length, extractor, formatter=formatter) + + +def Ascii( + name: str, + length: Union[int, Backref], + extractor: Extractor[str] = bytes.decode, + **kwargs +) -> FieldType[str]: + return FieldType(name, length, extractor, **kwargs) + + +def Int( + name: str, + length: Union[int, Backref], + extractor: Extractor[int] = lambda data: int.from_bytes(data, "big"), + **kwargs +) -> FieldType[int]: + return FieldType(name, length, extractor, **kwargs) diff --git a/proto/packet.py b/proto/packet.py new file mode 100644 index 0000000..157a3d0 --- /dev/null +++ b/proto/packet.py @@ -0,0 +1,67 @@ +import enum +from typing import * +from dataclasses import dataclass + +from . import parser +from .fieldtype import * + + +class HazelPacketType(bytes, enum.Enum): + UNRELIABLE = bytes([0]) + RELIABLE = bytes([1]) + + HELLO = bytes([8]) + PING = bytes([12]) + ACK = bytes([10]) + FIN = bytes([9]) + + FRAGMENT = bytes([11]) # not observed yet, maybe unused in among us? + + +Parser = parser.MetaParser() + + +@Parser.register +class PingPacket(parser.Packet): + Fields = [ + Fixed(HazelPacketType.PING), + Int("nonce", 2), + ] + + def __str__(self): + return f"Ping {self.nonce}" + + +@Parser.register +class AckPacket(parser.Packet): + Fields = [ + Fixed(HazelPacketType.ACK), + Int("nonce", 2), + Fixed(b"\xff"), + ] + + def __str__(self): + return f"Ack {self.nonce}" + + +@Parser.register +class FinPacket(parser.Packet): + Fields = [ + Fixed(HazelPacketType.FIN), + ] + + def __str__(self): + return "Fin" + + +@Parser.register +class HelloPacket(parser.Packet): + Fields = [ + Fixed(HazelPacketType.HELLO), + Unknown(7), + Int("name_len", 1), + Ascii("name", "name_len"), + ] + + def __str__(self): + return f"Hello {self.name} ({self.fields[1]} ???)" diff --git a/proto/parser.py b/proto/parser.py new file mode 100644 index 0000000..11b10e4 --- /dev/null +++ b/proto/parser.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from typing import * + +from . import fieldtype + +T = TypeVar("T") + + +@dataclass +class FieldData(Generic[T]): + name: Optional[str] + data: T + formatter: fieldtype.Formatter[T] + + def __str__(self): + return self.formatter(self.data) + + def __repr__(self): + return f"" + + +@dataclass +class Packet: + fields: List[FieldData] + + +class UnknownPacket(Packet): + def __str__(self): + return f"Unknown {''.join(str(f) for f in self.fields)}" + + +@dataclass +class AmbiguousPacket(Packet): + candidates: List[Packet] + + +class Buffer: + def __init__(self, data: bytes): + self.data = data + + def consume(self, n: int) -> bytes: + assert n <= len(self.data) + result, self.data = self.data[:n], self.data[n:] + return result + + +class FieldGetter: + def __init__(self, idx): + self.idx = idx + + def __call__(self, obj): + return obj.fields[self.idx].data + + +class MetaParser: + def __init__(self): + self.registered = list() + + def register(self, cls): + assert issubclass(cls, Packet) + assert hasattr(cls, "Fields") + + for idx, field in enumerate(cls.Fields): + if field.name is None: + continue + + assert not hasattr(cls, field.name) + + setattr(cls, field.name, property(FieldGetter(idx))) + + self.registered.append(cls) + + return cls + + def _try_parse( + self, data: bytes, cls: Type[Packet], fields: List[fieldtype.FieldType] + ) -> Packet: + res: List[FieldData] = [] + buffer = Buffer(data) + backref: Dict[str, int] = dict() + + for field in fields: + n = field.length + if isinstance(n, fieldtype.Backref): + n = backref[n] + + field_data = FieldData( + field.name, + field.extractor(buffer.consume(n)), # type: ignore + field.formatter, # type: ignore + ) + + if field.name is not None and isinstance(field_data.data, int): + backref[field.name] = field_data.data + + res.append(field_data) + + return cls(res) + + def parse(self, data: bytes): + possible_results = list() + for cls in self.registered: + try: + res = self._try_parse(data, cls, cls.Fields) + possible_results.append(res) + except (KeyError, ValueError, AssertionError) as e: + continue + + if len(possible_results) == 0: + return UnknownPacket( + [FieldData(None, data, fieldtype.default_bytes_formatter)] + ) + elif len(possible_results) > 1: + return AmbiguousPacket( + [FieldData(None, data, fieldtype.default_bytes_formatter)], + possible_results, + ) + else: + return possible_results[0] diff --git a/protoparser.py b/protoparser.py deleted file mode 100644 index c9a050d..0000000 --- a/protoparser.py +++ /dev/null @@ -1,127 +0,0 @@ -import binascii -import enum -from typing import Tuple, Dict, List, Union, Callable, Optional, Any, Type - -Extractor = Callable[[bytes], Any] - - -FieldSpec = Union[ - # specific value to be expected - bytes, - # [named] field with fixed length - Tuple[Optional[str], int, Optional[Extractor]], - # [named] field with length backreference - Tuple[Optional[str], str, Optional[Extractor]], - # parse until end - None, -] - -# Spec = Tuple[..., List[FieldSpec]] - - -class Buffer: - def __init__(self, data: bytes): - self.data = data - - def consume(self, n: int) -> bytes: - assert n <= len(self.data) - result, self.data = self.data[:n], self.data[n:] - return result - - -class Packet: - def __init__(self, data: List[Any], **kwargs): - del kwargs # unused - self.data = data - - def __repr__(self): - res = list() - for item in self.data: - if isinstance(item, enum.Enum): - res.append(item.name) - elif isinstance(item, bytes): - res.append(binascii.hexlify(item).decode()) - else: - res.append(str(item)) - return f"{self.__class__.__name__} [{' '.join(res)}]" - - -class AmbiguousPacket(Packet): - pass - - -class UnknownPacket(Packet): - pass - - -class Parser: - def __init__(self): - self.specs = list() - - def register(self, *fields: FieldSpec): - def deco(cls: Type[Packet]): - self.specs.append((cls, fields)) - return cls - - return deco - - def parse_packet(self, data: bytes) -> Packet: - result = None - for (cls, fields) in self.specs: - try: - m = _match_spec(cls, fields, data) - except AssertionError: - continue - if m: - if result: - return AmbiguousPacket([data]) - result = m - - if not result: - return UnknownPacket([data]) - - return result - - -def _match_spec(cls: Type[Packet], fields: List[FieldSpec], data: bytes) -> Packet: - - buffer = Buffer(data) - - backref: Dict[str, Any] = dict() - - res_data: List[Any] = list() - - for fieldspec in fields: - if isinstance(fieldspec, bytes): - assert buffer.consume(len(fieldspec)) == fieldspec - res_data.append(fieldspec) - continue - - if fieldspec is None: - res_data.append(buffer.data) - break # TODO implement unknown blob can also be in the middle - - if isinstance(fieldspec, tuple): - fieldname, fieldlen, extractor = fieldspec - - # backreference - if isinstance(fieldlen, str): - fieldlen = backref[fieldlen] - - assert isinstance(fieldlen, int) - - fielddata = buffer.consume(fieldlen) - - if extractor: - fielddata = extractor(fielddata) - - if fieldname: - backref[fieldname] = fielddata - - res_data.append(fielddata) - - return cls(res_data, **backref) - - -def unknown(n: int, format: Extractor = None) -> FieldSpec: - return (None, n, format)