diff --git a/module/gamefree/netpkg.py b/module/gamefree/netpkg.py index feb9ea3b3..faec8e765 100644 --- a/module/gamefree/netpkg.py +++ b/module/gamefree/netpkg.py @@ -1,5 +1,5 @@ import threading -from typing import Self +from typing import Union from google.protobuf import message import module.gamefree.bytearray as ba @@ -9,13 +9,21 @@ class AzurLaneNetworkEndPackage: ... +class AzurLaneNetworkPackageAbort(Exception): + ... + + class AzurLaneNetworkPackage: def __init__(self, id, proto_message: message.Message): self.event = threading.Event() self.proto_message = proto_message self.id = id - self.returned_data: bytes = None + self.returned_data: ba.ByteArray = None + self.abort = False + + def is_aborted(self): + return self.abort def pack(self) -> bytes: buffer = ba.ByteArray() @@ -34,8 +42,13 @@ class AzurLaneNetworkPackage: return buffer.toBytes() - def unpack(self, data: bytes) -> Self: - buffer = ba.ByteArray.fromBytes(data) + def unpack(self, data: Union[bytes, ba.ByteArray]): + if isinstance(data, bytes): + buffer = ba.ByteArray.fromBytes(data) + elif isinstance(data, ba.ByteArray): + buffer = data + else: + raise TypeError(f"Invalid data type: {type(data).__name__}") buffer.readBigEndianUInt16() buffer.readBigEndianUInt8() diff --git a/module/gamefree/network.py b/module/gamefree/network.py index 18199b3bb..c3fdcdd97 100644 --- a/module/gamefree/network.py +++ b/module/gamefree/network.py @@ -12,17 +12,17 @@ class AzurLaneNetworkClient: def __init__(self): self.task_queue = Queue() self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.worker: threading.Thread = None - self.heartbeater: threading.Thread = None - self.heartbeat_thread_event = threading.Event() + self.vm: Union[threading.Thread, None] = None + self.vm_continue = threading.Event() + self.vm_exception: Union[Exception, None] = None def __del__(self): - self.stop_work() + self.vm_stop() def connect(self, addr, port): self.server_socket.connect((addr, port)) - def transfer(self, pkg: AzurLaneNetworkPackage) -> bytes: + def transfer(self, pkg: AzurLaneNetworkPackage) -> ba.ByteArray: self.server_socket.send(pkg.pack()) buffer = ba.ByteArray() while 1: @@ -30,48 +30,77 @@ class AzurLaneNetworkClient: if returned_data == 0: break buffer.writeBytes(returned_data) - return buffer.toBytes() + return buffer - def start_heartbeat(self): - if self.heartbeater is not None: - if self.heartbeater.is_alive(): + def clear_task_queue(self): + while not self.task_queue.empty(): + pkg = self.task_queue.get() + pkg.is_aborted = True + pkg.event.set() + + def vm_clear_task_queue(self): + self.vm_interrupt() + self.clear_task_queue() + self.vm_resume() + + def vm_start(self): + if self.vm is not None: + if self.vm.is_alive(): return - self.heartbeater = threading.Thread(target=self.work_thread) - self.heartbeater.start() + self.vm = threading.Thread(target=self.vm_thread) + self.vm.start() + self.vm_continue.set() - def start_work(self): - if self.worker is not None: - if self.worker.is_alive(): - return - self.worker = threading.Thread(target=self.work_thread) - self.worker.start() + def vm_interrupt(self): + self.vm_continue.clear() - def stop_work(self): - self.heartbeat_thread_event.set() - self.task_queue.queue.clear() - self.task_queue.put(AzurLaneNetworkEndPackage()) + def vm_resume(self): + self.vm_continue.set() - def work_thread(self): + def vm_stop(self): + self.vm_interrupt() + self.clear_task_queue() + self.queue_package(AzurLaneNetworkEndPackage()) + self.vm_resume() + + def vm_wait_until_exception_handled(self): + self.vm_continue.wait() + + def vm_get_exception(self) -> Exception: + return self.vm_exception + + def vm_has_exception(self): + return self.vm_exception is not None + + def vm_set_exception(self, e): + self.vm_exception = e + + def vm_clear_exception(self): + self.vm_exception = None + + def vm_ensure_continue(self): + if not self.vm_continue.is_set(): + self.vm_continue.wait() + + def vm_thread(self): while 1: + self.vm_ensure_continue() pkg = self.task_queue.get() if isinstance(pkg, AzurLaneNetworkEndPackage): break elif isinstance(pkg, AzurLaneNetworkPackage): - data = self.transfer(pkg) - pkg.returned_data = data - pkg.event.set() + try: + data = self.transfer(pkg) + pkg.returned_data = data + pkg.event.set() + except Exception as e: + self.vm_set_exception(e) + self.vm_interrupt() + pkg.event.set() + self.vm_wait_until_exception_handled() else: logger.warning(f"Unknown net package class: {type(pkg).__name__}") - def heartbeat_thread(self): - while 1: - if self.heartbeat_thread_event.is_set(): - break - self.task_queue.put(HeartBeatPackage()) - if self.heartbeat_thread_event.is_set(): - break - time.sleep(1) - def queue_package(self, pkg): self.task_queue.put(pkg) @@ -86,8 +115,12 @@ class AzurLaneNetwork: def __init__(self): self.client = AzurLaneNetworkClient.get_instance() - def send(self, pkg: AzurLaneNetworkPackage) -> bytes: + def send(self, pkg: AzurLaneNetworkPackage) -> ba.ByteArray: self.client.queue_package(pkg) pkg.event.wait() + if self.client.vm_has_exception(): + raise self.client.vm_get_exception() + if pkg.is_aborted(): + raise AzurLaneNetworkPackageAbort() return pkg.returned_data