# Mirror of https://github.com/0O0o0oOoO00/Alas.git
# Synced 2026-05-14 14:49:25 +08:00
# 313 lines, 9.5 KiB, Python
import os
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from typing import Dict
|
||
|
||
|
||
class LabelType:
    """Integer codes for field labels and the keyword each one maps to."""

    OPTIONAL = 1
    REQUIRED = 2
    REPEATED = 3

    # Code -> keyword table backing type_to_name().
    _NAMES = {
        OPTIONAL: "optional",
        REQUIRED: "required",
        REPEATED: "repeated",
    }

    @staticmethod
    def type_to_name(label_type):
        """Return the keyword for a label code.

        Args:
            label_type: One of the integer constants defined on this class.

        Returns:
            "optional", "required" or "repeated".

        Raises:
            ValueError: If ``label_type`` is not one of the known codes.
        """
        try:
            return LabelType._NAMES[label_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which a dict cannot look
            # up; they are reported the same way as any other unknown code.
            raise ValueError(f"Invalid label type: {label_type}") from None
|
||
|
||
|
||
class CppType:
    """Integer codes for cpp types and the lowercase name each one maps to."""

    INT32 = 1
    INT64 = 2
    UINT32 = 3
    UINT64 = 4
    DOUBLE = 5
    FLOAT = 6
    BOOL = 7
    ENUM = 8
    STRING = 9
    MESSAGE = 10

    # Code -> name table backing type_to_name().
    _NAMES = {
        INT32: "int32",
        INT64: "int64",
        UINT32: "uint32",
        UINT64: "uint64",
        DOUBLE: "double",
        FLOAT: "float",
        BOOL: "bool",
        ENUM: "enum",
        STRING: "string",
        MESSAGE: "message",
    }

    @staticmethod
    def type_to_name(cpp_type):
        """Return the lowercase name for a cpp type code.

        Args:
            cpp_type: One of the integer constants defined on this class.

        Returns:
            The matching name, e.g. "int32" for ``CppType.INT32``.

        Raises:
            ValueError: If ``cpp_type`` is not one of the known codes.
        """
        try:
            return CppType._NAMES[cpp_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which a dict cannot look
            # up; they are reported the same way as any other unknown code.
            raise ValueError(f"Invalid cpp type: {cpp_type}") from None
|
||
|
||
|
||
class FieldType:
    """Integer field type codes and the type keyword each one maps to."""

    DOUBLE = 1
    FLOAT = 2
    INT64 = 3
    UINT64 = 4
    INT32 = 5
    FIXED64 = 6
    FIXED32 = 7
    BOOL = 8
    STRING = 9
    GROUP = 10
    MESSAGE = 11
    BYTES = 12
    UINT32 = 13
    ENUM = 14
    SFIXED32 = 15
    SFIXED64 = 16
    SINT32 = 17
    SINT64 = 18

    # Code -> keyword table backing type_to_name().
    _NAMES = {
        DOUBLE: "double",
        FLOAT: "float",
        INT64: "int64",
        UINT64: "uint64",
        INT32: "int32",
        FIXED64: "fixed64",
        FIXED32: "fixed32",
        BOOL: "bool",
        STRING: "string",
        GROUP: "group",
        MESSAGE: "message",
        BYTES: "bytes",
        UINT32: "uint32",
        ENUM: "enum",
        SFIXED32: "sfixed32",
        SFIXED64: "sfixed64",
        SINT32: "sint32",
        SINT64: "sint64",
    }

    @staticmethod
    def type_to_name(field_type):
        """Return the type keyword for a field type code.

        Args:
            field_type: One of the integer constants defined on this class.

        Returns:
            The matching keyword, e.g. "sint64" for ``FieldType.SINT64``.

        Raises:
            ValueError: If ``field_type`` is not one of the known codes.
        """
        try:
            return FieldType._NAMES[field_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which a dict cannot look
            # up; they are reported the same way as any other unknown code.
            raise ValueError(f"Invalid field type: {field_type}") from None
|
||
|
||
|
||
@dataclass
class Field:
    """One field of a message, with every property defaulting to an empty value.

    ``Field()`` yields the same all-defaults instance as before; the generated
    constructor additionally accepts any property as a keyword argument.
    """

    name: str = ""
    full_name: str = ""
    number: int = 0
    index: int = 0
    label: int = 0  # label code, see LabelType
    has_default_value: bool = False
    default_value: str = ""
    tp: int = 0  # type
    cpp_type: int = 0
    message_type: str = ""
|
||
|
||
|
||
@dataclass(init=False)
class Message:
    """A message: its names, its fields keyed by field name, and a flag."""

    name: str
    full_name: str
    fields: Dict[str, Field]
    is_extendable: bool

    def __init__(self):
        """Create an empty message with no fields."""
        self.name = self.full_name = ""
        # A fresh dict per instance, so messages never share field tables.
        self.fields = {}
        self.is_extendable = False
|
||
|
||
|
||
# Maps a protocol name ("p<number>") to the .proto files its generated file
# imports. Every listed protocol imports common.proto; p60 and p62 also
# import guild.proto.
IMPORT_MAP = {
    f"p{number}": (
        ["common.proto", "guild.proto"] if number in (60, 62) else ["common.proto"]
    )
    for number in (
        11, 12, 13, 14, 15, 16, 18,
        20, 21, 22, 24, 25, 26, 28, 29,
        30, 33, 34,
        40,
        50,
        60, 61, 62, 63, 64,
    )
}

# Output files that are written without a "package" statement.
NO_PACKAGE = ["common.proto", "guild.proto"]
|
||
|
||
def _descriptor_lines(lines):
    """Return the lines from the first one containing "name" up to, but not
    including, the first one containing "Message".

    When no "name" line precedes the "Message" line the slice starts at the
    top of the file; when there is no "Message" line at all the result is
    empty.
    """
    start = None
    for index, line in enumerate(lines):
        # None (not 0) marks "not found yet", so a match on the very first
        # line is not overwritten by a later "name" line.
        if start is None and "name" in line:
            start = index
        if "Message" in line:
            return lines[start or 0:index]
    return []


def _parse_lua_descriptors(lines) -> "Dict[str, Message]":
    """Build messages from ``target.property = value`` assignment lines.

    Lines starting with "slot" are read as
    ``slotN.<MSG>_FIELD_LIST.<MSG>_<FIELD>_FIELD.<property> = value`` and
    fill in a field; every other column-0 line is read as
    ``<MSG>.<property> = value`` and fills in the message itself.

    Raises:
        ValueError: If a processed line has no "=" or a message-level target
            does not have exactly two dotted parts.
        IndexError: If a "slot" target has fewer than four dotted parts.
    """
    messages: "Dict[str, Message]" = {}
    for line in lines:
        lead = line[:1]
        # Only column-0 statements are assignments of interest: blank lines,
        # indented table entries and closing braces are skipped.
        if not lead or not lead.isprintable() or lead in (" ", "\t", "}"):
            continue
        if "Message" in line:
            continue

        # Split on the first "=" only, so values that contain "=" survive.
        props, sep, val = (part.strip() for part in line.partition("="))
        if not sep:
            raise ValueError(f"Expected an assignment, got: {line!r}")
        path = props.split(".")
        text = val.replace('"', "")

        if line.startswith("slot"):
            message_key = path[1].replace("_FIELD_LIST", "")
            field_key = path[2].replace("_FIELD", "").replace(f"{message_key}_", "")
            prop = path[3]
            message = messages.setdefault(message_key, Message())
            field = message.fields.setdefault(field_key, Field())
            if prop in ("name", "full_name", "default_value"):
                setattr(field, prop, text)
            elif prop in ("number", "index", "label", "cpp_type"):
                setattr(field, prop, int(val))
            elif prop == "type":
                field.tp = int(val)
            elif prop == "has_default_value":
                field.has_default_value = text != "false"
            elif prop == "message_type":
                # Keep only the last dotted component of a qualified reference.
                field.message_type = text.split(".")[-1]
        else:
            message_key, prop = path
            message = messages.setdefault(message_key, Message())
            if prop in ("name", "full_name"):
                setattr(message, prop, text)
            elif prop == "is_extendable":
                # Compare the literal: bool("false") would be True.
                message.is_extendable = text == "true"
    return messages


def _render_proto(messages, package, import_files):
    """Return the proto2 source text for ``messages``.

    ``package`` is the package name, or None to omit the package statement.

    Raises:
        ValueError: If a field carries an unknown label or type code.
    """
    parts = ['syntax = "proto2";\n\n']
    if package is not None:
        parts.append(f"package {package};\n\n")
    parts.extend(f'import "{import_file}";\n' for import_file in import_files)
    parts.append("\n")

    for key, message in messages.items():
        parts.append(f"message {key} {{\n")
        for field in message.fields.values():
            label = LabelType.type_to_name(field.label)
            type_name = FieldType.type_to_name(field.tp)
            if type_name == "message":
                type_name = field.message_type

            # Only string and bool defaults are rendered; other types are
            # emitted without a default option.
            option = ""
            if field.has_default_value:
                if field.tp == FieldType.STRING:
                    literal = "" if field.default_value == "nil" else field.default_value
                    option = f' [default = "{literal}"]'
                elif field.tp == FieldType.BOOL:
                    option = f" [default = {field.default_value}]"
            parts.append(f" {label} {type_name} {field.name} = {field.number}{option};\n")
        parts.append(f"}} // {message.full_name}\n\n")
    return "".join(parts)


def gen_proto_file(lua_file_path, proto_file_path):
    """Parse a Lua descriptor file and write the matching proto2 file.

    Args:
        lua_file_path: Path (str or Path) of the Lua file to read.
        proto_file_path: Path (str or Path) of the .proto file to write. Its
            stem becomes the package name unless the file name is listed in
            NO_PACKAGE. Missing parent directories are created.

    Raises:
        ValueError: If a processed line is not an assignment, a message-level
            target does not have exactly two dotted parts, or a field has an
            unknown label or type code.
        IndexError: If a "slot" target has fewer than four dotted parts.
        OSError: If the input cannot be read or the output cannot be written.
    """
    print(f"Parse {lua_file_path}, generate {proto_file_path}")
    lua_file_path = Path(lua_file_path)
    proto_file_path = Path(proto_file_path)

    with open(lua_file_path, mode="r", encoding="utf-8") as f:
        lines = f.readlines()
    messages = _parse_lua_descriptors(_descriptor_lines(lines))

    package = None if proto_file_path.name in NO_PACKAGE else proto_file_path.stem
    import_files = IMPORT_MAP.get(lua_file_path.stem.replace("_pb", ""), [])

    # Render before opening the output, so a failure while rendering does
    # not leave a truncated file behind.
    text = _render_proto(messages, package, import_files)
    proto_file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(proto_file_path, mode="w", encoding="utf-8") as f:
        f.write(text)
|
||
|
||
|
||
def main():
    """Generate one .proto file under ./proto for each *.lua file found in
    AZUR_LANE_LUA_SCRIPT_DIR/CN/net/protocol, except the skipped ones."""
    # TODO: fix this, skip this file for now, p70_pb.lua has some error
    skipped_stems = ["p70_pb"]
    protocol_dir = AZUR_LANE_LUA_SCRIPT_DIR / "CN" / "net" / "protocol"
    output_dir = Path("./proto")

    for lua_file in protocol_dir.glob("*.lua"):
        if lua_file.stem not in skipped_stems:
            proto_name = lua_file.stem.replace("_pb", "") + ".proto"
            gen_proto_file(lua_file, output_dir / proto_name)
|
||
|
||
|
||
# Root directory under which main() looks for CN/net/protocol/*.lua.
# Defaults to the original hard-coded location; set the
# AZUR_LANE_LUA_SCRIPT_DIR environment variable to use another directory
# without editing this file. An empty value falls back to the default.
AZUR_LANE_LUA_SCRIPT_DIR = Path(
    os.environ.get("AZUR_LANE_LUA_SCRIPT_DIR") or "E:/blhxlua"
)


if __name__ == "__main__":
    main()
|