diff --git a/README.md b/README.md index 2aa2f9b..cbb90ba 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,12 @@ [![github actions](https://github.com/codematrixer/hmdriver2/actions/workflows/release.yml/badge.svg)](https://github.com/codematrixer/hmdriver2/actions) [![pypi version](https://img.shields.io/pypi/v/hmdriver2.svg)](https://pypi.python.org/pypi/hmdriver2) ![python](https://img.shields.io/pypi/pyversions/hmdriver2.svg) -[![downloads](https://pepy.tech/badge/hmdriver2)](https://pepy.tech/project/hmdriver2) +[![downloads](https://pepy.tech/badge/hmdriver2)](https://pepy.tech/project/hmdriver2) -> 写这个项目前github上已有个叫`hmdriver`的项目,但它是侵入式(需要提前在手机端安装一个testRunner app);另外鸿蒙官方提供的hypium自动化框架,使用较为复杂,依赖繁杂。于是决定重写一套。 +> **📢 重要通知:** +> +> **由于原项目的维护模式已变更,本仓库(当前项目)将接替成为核心维护分支。会尽量保持更新(包括功能改进、Bug修复等)。** **hmdriver2** 是一款支持`HarmonyOS NEXT`系统的UI自动化框架,**无侵入式**,提供应用管理,UI操作,元素定位等功能,轻量高效,上手简单,快速实现鸿蒙应用自动化测试需求。 diff --git a/example.py b/example.py index f11eed1..6a1281b 100644 --- a/example.py +++ b/example.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- import time + from hmdriver2.driver import Driver from hmdriver2.proto import DeviceInfo, KeyCode, ComponentData, DisplayRotation - # New driver d = Driver("FMR0223C13000649") @@ -135,3 +135,5 @@ d.xpath('//*[@text="showDialog"]').click_if_exists() d.xpath('//root[1]/Row[1]/Column[1]/Row[1]/Button[3]').click() d.xpath('//*[@text="showDialog"]').input_text("xxx") +d.xpath('//*[@text="showDialog"]').text() +d.xpath('//*[@text="showDialog"]').clickable() diff --git a/hmdriver2/_client.py b/hmdriver2/_client.py index 6dfe01a..8954a4a 100644 --- a/hmdriver2/_client.py +++ b/hmdriver2/_client.py @@ -1,100 +1,241 @@ # -*- coding: utf-8 -*- -import socket + +import hashlib import json -import time import os -import hashlib -import typing -from typing import Optional +import socket +import struct +import time from datetime import datetime from functools import cached_property +from typing import Optional, Union, Dict, List, Any from . import logger +from .exception import InvokeHypiumError, InvokeCaptures from .hdc import HdcWrapper from .proto import HypiumResponse, DriverData -from .exception import InvokeHypiumError, InvokeCaptures +# 连接相关常量 +UITEST_SERVICE_PORT = 8012 # 设备端服务端口 +SOCKET_TIMEOUT = 20 # Socket 超时时间(秒) +LOCAL_HOST = "127.0.0.1" # 本地主机地址 -UITEST_SERVICE_PORT = 8012 -SOCKET_TIMEOUT = 20 +# 消息协议常量 +MSG_HEADER = b'_uitestkit_rpc_message_head_' # 消息头标识 +MSG_TAILER = b'_uitestkit_rpc_message_tail_' # 消息尾标识 +SESSION_ID_LENGTH = 4 # 会话ID长度(字节) + +# API 模块常量 +API_MODULE = "com.ohos.devicetest.hypiumApiHelper" # API 模块名 +API_METHOD_HYPIUM = "callHypiumApi" # Hypium API 调用方法 +API_METHOD_CAPTURES = "Captures" # Captures API 调用方法 +DEFAULT_THIS = "Driver#0" # 默认目标对象 class HmClient: - """harmony uitest client""" + """ + Harmony OS 设备通信客户端 + + 负责与设备建立连接、发送命令和接收响应,是与设备交互的基础类。 + 通过 HDC(Harmony Debug Console)建立端口转发,使用 Socket 进行通信。 + """ + def __init__(self, serial: str): + """ + 初始化客户端 + + Args: + serial: 设备序列号 + """ self.hdc = HdcWrapper(serial) - self.sock = None + self.sock: Optional[socket.socket] = None + self._header_length = len(MSG_HEADER) + self._tailer_length = len(MSG_TAILER) @cached_property - def local_port(self): + def local_port(self) -> int: + """ + 获取本地转发端口 + + Returns: + int: 本地端口号 + """ fports = self.hdc.list_fport() - logger.debug(fports) if fports else None - + if fports: + logger.debug(fports) return self.hdc.forward_port(UITEST_SERVICE_PORT) - def _rm_local_port(self): - logger.debug("rm fport local port") + def _rm_local_port(self) -> None: + """移除本地端口转发""" + logger.debug("移除本地端口转发") self.hdc.rm_forward(self.local_port, UITEST_SERVICE_PORT) - def _connect_sock(self): - """Create socket and connect to the uiTEST server.""" + def _connect_sock(self) -> None: + """创建 Socket 并连接到 UITest 服务器""" self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.settimeout(SOCKET_TIMEOUT) - self.sock.connect((("127.0.0.1", self.local_port))) - - def _send_msg(self, msg: typing.Dict): - """Send an message to the server. - Example: - { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "callHypiumApi", - "params": { - "api": "Driver.create", - "this": null, - "args": [], - "message_type": "hypium" - }, - "request_id": "20240815161352267072", - "client": "127.0.0.1" - } - """ - msg = json.dumps(msg, ensure_ascii=False, separators=(',', ':')) - logger.debug(f"sendMsg: {msg}") - self.sock.sendall(msg.encode('utf-8') + b'\n') - - def _recv_msg(self, buff_size: int = 4096, decode=False, print=True) -> typing.Union[bytearray, str]: - full_msg = bytearray() + self.sock.connect((LOCAL_HOST, self.local_port)) + + def _send_msg(self, msg: Dict[str, Any]) -> None: + """ + 发送消息到服务器 + + Args: + msg: 要发送的消息字典 + + 消息格式示例: + { + "module": "com.ohos.devicetest.hypiumApiHelper", + "method": "callHypiumApi", + "params": { + "api": "Driver.create", + "this": null, + "args": [], + "message_type": "hypium" + }, + "request_id": "20240815161352267072" + } + """ + # 序列化消息 + msg_str = json.dumps(msg, ensure_ascii=False, separators=(',', ':')) + logger.debug(f"发送消息: {msg_str}") + + # 生成会话ID并构建消息头 + msg_bytes = msg_str.encode('utf-8') + session_id = self._generate_session_id(msg_str) + header = ( + MSG_HEADER + + struct.pack('>I', session_id) + + struct.pack('>I', len(msg_bytes)) + ) + + # 发送完整消息(头部 + 消息体 + 尾部) + if self.sock is None: + raise ConnectionError("Socket 未连接") + self.sock.sendall(header + msg_bytes + MSG_TAILER) + + def _generate_session_id(self, message: str) -> int: + """ + 生成会话ID + + 将时间戳、消息内容和随机数据组合生成唯一标识符 + + Args: + message: 消息内容 + + Returns: + int: 生成的会话ID + """ + # 组合时间戳、消息内容和随机数据 + combined = ( + str(int(time.time() * 1000)) + # 毫秒时间戳 + message + + os.urandom(4).hex() # 16字节随机熵 + ) + # 生成哈希并取前8位转为整数 + return int(hashlib.sha256(combined.encode()).hexdigest()[:8], 16) + + def _recv_msg(self, decode: bool = False, print: bool = True) -> Union[bytearray, str]: + """ + 接收并解析消息 + + Args: + decode: 是否解码为字符串 + print: 是否打印接收到的消息 + + Returns: + 解析后的消息内容(字节数组或字符串) + + Raises: + ConnectionError: 连接中断时抛出 + """ try: - # FIXME - relay = self.sock.recv(buff_size) - if decode: - relay = relay.decode() + # 接收消息头 + header_len = self._header_length + SESSION_ID_LENGTH + 4 + header = self._recv_exact(header_len) # 头部 + session_id + length + if not header or header[:self._header_length] != MSG_HEADER: + logger.warning("接收到无效的消息头") + return bytearray() if not decode else "" + + # 解析消息长度(不验证session_id) + msg_length = struct.unpack('>I', header[self._header_length + SESSION_ID_LENGTH:])[0] + + # 接收消息体 + msg_bytes = self._recv_exact(msg_length) + if not msg_bytes: + logger.warning("接收消息体失败") + return bytearray() if not decode else "" + + # 接收消息尾 + tailer = self._recv_exact(self._tailer_length) + if not tailer or tailer != MSG_TAILER: + logger.warning("接收到无效的消息尾") + return bytearray() if not decode else "" + + # 处理消息内容 + if not decode: + logger.debug(f"接收到字节消息 (大小: {len(msg_bytes)})") + return bytearray(msg_bytes) + + # 解码为字符串 + msg_str = msg_bytes.decode('utf-8') if print: - logger.debug(f"recvMsg: {relay}") - full_msg = relay + logger.debug(f"接收到消息: {msg_str}") + return msg_str - except (socket.timeout, UnicodeDecodeError) as e: - logger.warning(e) - if decode: - full_msg = "" + except (socket.timeout, ValueError, json.JSONDecodeError) as e: + logger.warning(f"接收消息时出错: {e}") + return bytearray() if not decode else "" - return full_msg - - def invoke(self, api: str, this: str = "Driver#0", args: typing.List = []) -> HypiumResponse: + def _recv_exact(self, length: int) -> bytes: """ - Hypium invokes given API method with the specified arguments and handles exceptions. - + 精确接收指定长度的数据 + + 使用内存视图优化接收性能,确保接收完整数据 + Args: - api (str): The name of the API method to invoke. - args (List, optional): A list of arguments to pass to the API method. Default is an empty list. - + length: 要接收的数据长度 + Returns: - HypiumResponse: The response from the API call. - + bytes: 接收到的数据 + Raises: - InvokeHypiumError: If the API call returns an exception in the response. + ConnectionError: 连接关闭时抛出 + """ + if self.sock is None: + raise ConnectionError("Socket 未连接") + + buf = bytearray(length) + view = memoryview(buf) + pos = 0 + + while pos < length: + chunk_size = self.sock.recv_into(view[pos:], length - pos) + if not chunk_size: + raise ConnectionError("接收数据时连接已关闭") + pos += chunk_size + + return buf + + def invoke(self, api: str, this: Optional[str] = DEFAULT_THIS, args: Optional[List[Any]] = None) -> HypiumResponse: """ + 调用 Hypium API + + Args: + api: API 名称 + this: 目标对象标识符,默认为 "Driver#0" + args: API 参数列表,默认为空列表 + + Returns: + HypiumResponse: API 调用响应 + + Raises: + InvokeHypiumError: API 调用返回异常时抛出 + """ + if args is None: + args = [] + # 构建请求参数 request_id = datetime.now().strftime("%Y%m%d%H%M%S%f") params = { "api": api, @@ -103,51 +244,96 @@ def invoke(self, api: str, this: str = "Driver#0", args: typing.List = []) -> Hy "message_type": "hypium" } + # 构建完整消息 msg = { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "callHypiumApi", + "module": API_MODULE, + "method": API_METHOD_HYPIUM, "params": params, "request_id": request_id } + # 发送请求并处理响应 self._send_msg(msg) raw_data = self._recv_msg(decode=True) - data = HypiumResponse(**(json.loads(raw_data))) + if not raw_data: + raise InvokeHypiumError("接收响应失败") + + try: + data = HypiumResponse(**(json.loads(raw_data))) + except json.JSONDecodeError as e: + raise InvokeHypiumError(f"解析响应失败: {e}") + + # 处理异常 if data.exception: raise InvokeHypiumError(data.exception) return data - def invoke_captures(self, api: str, args: typing.List = []) -> HypiumResponse: + def invoke_captures(self, api: str, args: Optional[List[Any]] = None) -> HypiumResponse: + """ + 调用 Captures API + + Args: + api: API 名称 + args: API 参数列表,默认为空列表 + + Returns: + HypiumResponse: API 调用响应 + + Raises: + InvokeCaptures: API 调用返回异常时抛出 + """ + if args is None: + args = [] + + # 构建请求参数 request_id = datetime.now().strftime("%Y%m%d%H%M%S%f") params = { "api": api, "args": args } + # 构建完整消息 msg = { - "module": "com.ohos.devicetest.hypiumApiHelper", - "method": "Captures", + "module": API_MODULE, + "method": API_METHOD_CAPTURES, "params": params, "request_id": request_id } + # 发送请求并处理响应 self._send_msg(msg) raw_data = self._recv_msg(decode=True) - data = HypiumResponse(**(json.loads(raw_data))) + if not raw_data: + raise InvokeCaptures("接收响应失败") + + try: + data = HypiumResponse(**(json.loads(raw_data))) + except json.JSONDecodeError as e: + raise InvokeCaptures(f"解析响应失败: {e}") + + # 处理异常 if data.exception: raise InvokeCaptures(data.exception) return data - def start(self): - logger.info("Start HmClient connection") + def start(self) -> None: + """ + 启动客户端连接 + + 初始化 UITest 服务,建立 Socket 连接,创建驱动实例 + """ + logger.info("启动 HmClient 连接") _UITestService(self.hdc).init() - self._connect_sock() - self._create_hdriver() - def release(self): - logger.info(f"Release {self.__class__.__name__} connection") + def release(self) -> None: + """ + 释放客户端资源 + + 关闭 Socket 连接,移除端口转发 + """ + logger.info(f"释放 {self.__class__.__name__} 连接") try: if self.sock: self.sock.close() @@ -156,54 +342,97 @@ def release(self): self._rm_local_port() except Exception as e: - logger.error(f"An error occurred: {e}") + logger.error(f"释放资源时出错: {e}") def _create_hdriver(self) -> DriverData: - logger.debug("Create uitest driver") + """ + 创建 UITest 驱动实例 + + Returns: + DriverData: 驱动数据对象 + """ + logger.debug("创建 UITest 驱动") resp: HypiumResponse = self.invoke("Driver.create") # {"result":"Driver#0"} hdriver: DriverData = DriverData(resp.result) return hdriver class _UITestService: + """ + UITest 服务管理类 + + 负责初始化设备上的 UITest 服务,包括安装必要的库文件和启动服务进程 + """ + def __init__(self, hdc: HdcWrapper): - """Initialize the UITestService class.""" + """ + 初始化 UITest 服务管理类 + + Args: + hdc: HDC 包装器实例 + """ self.hdc = hdc + self._remote_agent_path = "/data/local/tmp/agent.so" - def init(self): + def init(self) -> None: """ - Initialize the UITest service: - 1. Ensure agent.so is set up on the device. - 2. Start the UITest daemon. - - Note: 'hdc shell aa test' will also start a uitest daemon. - $ hdc shell ps -ef |grep uitest - shell 44306 1 25 11:03:37 ? 00:00:16 uitest start-daemon singleness - shell 44416 1 2 11:03:42 ? 00:00:01 uitest start-daemon com.hmtest.uitest@4x9@1" + 初始化 UITest 服务 + + 1. 确保设备上安装了 agent.so + 2. 启动 UITest 守护进程 + + Note: + 'hdc shell aa test' 也会启动 UITest 守护进程 + $ hdc shell ps -ef |grep uitest + shell 44306 1 25 11:03:37 ? 00:00:16 uitest start-daemon singleness + shell 44416 1 2 11:03:42 ? 00:00:01 uitest start-daemon com.hmtest.uitest@4x9@1" """ - - logger.debug("Initializing UITest service") + logger.debug("初始化 UITest 服务") local_path = self._get_local_agent_path() - remote_path = "/data/local/tmp/agent.so" - self._kill_uitest_service() # Stop the service if running - self._setup_device_agent(local_path, remote_path) + # 按顺序执行初始化步骤 + self._kill_uitest_service() # 停止可能运行的服务 + self._setup_device_agent(local_path, self._remote_agent_path) self._start_uitest_daemon() - time.sleep(0.5) + time.sleep(0.5) # 等待服务启动 def _get_local_agent_path(self) -> str: - """Return the local path of the agent file.""" - target_agent = "uitest_agent_v1.1.0.so" + """ + 获取本地 agent.so 文件路径 + + 根据设备 CPU 架构选择对应的库文件 + + Returns: + str: 本地 agent.so 文件路径 + """ + cpu_abi = self.hdc.cpu_abi() + target_agent = os.path.join("so", cpu_abi, "agent.so") return os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", target_agent) def _get_remote_md5sum(self, file_path: str) -> Optional[str]: - """Get the MD5 checksum of a remote file.""" + """ + 获取远程文件的 MD5 校验和 + + Args: + file_path: 远程文件路径 + + Returns: + Optional[str]: MD5 校验和,如果文件不存在则返回 None + """ command = f"md5sum {file_path}" output = self.hdc.shell(command).output.strip() return output.split()[0] if output else None def _get_local_md5sum(self, file_path: str) -> str: - """Get the MD5 checksum of a local file.""" + """ + 获取本地文件的 MD5 校验和 + + Args: + file_path: 本地文件路径 + + Returns: + str: MD5 校验和 + """ hash_md5 = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): @@ -211,27 +440,51 @@ def _get_local_md5sum(self, file_path: str) -> str: return hash_md5.hexdigest() def _is_remote_file_exists(self, file_path: str) -> bool: - """Check if a file exists on the device.""" + """ + 检查远程文件是否存在 + + Args: + file_path: 远程文件路径 + + Returns: + bool: 文件存在返回 True,否则返回 False + """ command = f"[ -f {file_path} ] && echo 'exists' || echo 'not exists'" result = self.hdc.shell(command).output.strip() return "exists" in result - def _setup_device_agent(self, local_path: str, remote_path: str): - """Ensure the remote agent file is correctly set up.""" + def _setup_device_agent(self, local_path: str, remote_path: str) -> None: + """ + 设置设备上的 agent.so 文件 + + 如果远程文件不存在或与本地文件不一致,则上传本地文件 + + Args: + local_path: 本地文件路径 + remote_path: 远程文件路径 + """ + # 检查远程文件是否存在且与本地文件一致 if self._is_remote_file_exists(remote_path): local_md5 = self._get_local_md5sum(local_path) remote_md5 = self._get_remote_md5sum(remote_path) if local_md5 == remote_md5: - logger.debug("Remote agent file is up-to-date") + logger.debug("远程 agent 文件已是最新") self.hdc.shell(f"chmod +x {remote_path}") return self.hdc.shell(f"rm {remote_path}") + # 上传并设置权限 self.hdc.send_file(local_path, remote_path) self.hdc.shell(f"chmod +x {remote_path}") - logger.debug("Updated remote agent file") + logger.debug("已更新远程 agent 文件") - def _get_uitest_pid(self) -> typing.List[str]: + def _get_uitest_pid(self) -> List[str]: + """ + 获取 UITest 守护进程的 PID 列表 + + Returns: + List[str]: PID 列表 + """ proc_pids = [] result = self.hdc.shell("ps -ef").output.strip() lines = result.splitlines() @@ -242,12 +495,13 @@ def _get_uitest_pid(self) -> typing.List[str]: proc_pids.append(line.split()[1]) return proc_pids - def _kill_uitest_service(self): + def _kill_uitest_service(self) -> None: + """终止所有 UITest 守护进程""" for pid in self._get_uitest_pid(): self.hdc.shell(f"kill -9 {pid}") - logger.debug(f"Killed uitest process with PID {pid}") + logger.debug(f"已终止 UITest 进程,PID: {pid}") - def _start_uitest_daemon(self): - """Start the UITest daemon.""" + def _start_uitest_daemon(self) -> None: + """启动 UITest 守护进程""" self.hdc.shell("uitest start-daemon singleness") - logger.debug("Started UITest daemon") \ No newline at end of file + logger.debug("已启动 UITest 守护进程") diff --git a/hmdriver2/_gesture.py b/hmdriver2/_gesture.py index d81d6b5..4af6602 100644 --- a/hmdriver2/_gesture.py +++ b/hmdriver2/_gesture.py @@ -1,26 +1,40 @@ # -*- coding: utf-8 -*- import math -from typing import List, Union +from typing import List, Union, Optional, Tuple, Callable, Any + from . import logger -from .utils import delay from .driver import Driver -from .proto import HypiumResponse, Point from .exception import InjectGestureError +from .proto import HypiumResponse, Point +from .utils import delay + +# 手势采样时间常量(毫秒) +SAMPLE_TIME_MIN = 10 # 最小采样时间 +SAMPLE_TIME_NORMAL = 50 # 正常采样时间 +SAMPLE_TIME_MAX = 100 # 最大采样时间 + +# 手势步骤类型常量 +STEP_TYPE_START = "start" # 开始手势 +STEP_TYPE_MOVE = "move" # 移动手势 +STEP_TYPE_PAUSE = "pause" # 暂停手势 class _Gesture: - SAMPLE_TIME_MIN = 10 - SAMPLE_TIME_NORMAL = 50 - SAMPLE_TIME_MAX = 100 + """ + 手势操作类 + + 提供了创建和执行复杂手势操作的功能,包括点击、滑动、暂停等。 + 通过链式调用可以组合多个手势步骤。 + """ - def __init__(self, d: Driver, sampling_ms=50): + def __init__(self, d: Driver, sampling_ms: int = SAMPLE_TIME_NORMAL): """ - Initialize a gesture object. - + 初始化手势对象 + Args: - d (Driver): The driver object to interact with. - sampling_ms (int): Sampling time for gesture operation points in milliseconds. Default is 50. + d: Driver 实例,用于与设备交互 + sampling_ms: 手势操作点的采样时间(毫秒),默认为 50 """ self.d = d self.steps: List[GestureStep] = [] @@ -28,74 +42,86 @@ def __init__(self, d: Driver, sampling_ms=50): def _validate_sampling_time(self, sampling_time: int) -> int: """ - Validate the input sampling time. - + 验证采样时间是否在有效范围内 + Args: - sampling_time (int): The given sampling time. - + sampling_time: 给定的采样时间 + Returns: - int: Valid sampling time within allowed range. + int: 有效范围内的采样时间 """ - if _Gesture.SAMPLE_TIME_MIN <= sampling_time <= _Gesture.SAMPLE_TIME_MAX: + if SAMPLE_TIME_MIN <= sampling_time <= SAMPLE_TIME_MAX: return sampling_time - return _Gesture.SAMPLE_TIME_NORMAL + return SAMPLE_TIME_NORMAL - def _release(self): + def _release(self) -> None: + """清空手势步骤列表""" self.steps = [] def start(self, x: Union[int, float], y: Union[int, float], interval: float = 0.5) -> '_Gesture': """ - Start gesture operation. - + 开始手势操作 + Args: - x: oordinate as a percentage or absolute value. - y: coordinate as a percentage or absolute value. - interval (float, optional): Duration to hold at start position in seconds. Default is 0.5. - + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + interval: 在起始位置停留的时间(秒),默认为 0.5 + Returns: - Gesture: Self instance to allow method chaining. + _Gesture: 当前实例,支持链式调用 + + Raises: + InjectGestureError: 手势已经开始时抛出 """ self._ensure_can_start() - self._add_step(x, y, "start", interval) + self._add_step(x, y, STEP_TYPE_START, interval) return self def move(self, x: Union[int, float], y: Union[int, float], interval: float = 0.5) -> '_Gesture': """ - Move to specified position. - + 移动到指定位置 + Args: - x: coordinate as a percentage or absolute value. - y: coordinate as a percentage or absolute value. - interval (float, optional): Duration of move in seconds. Default is 0.5. - + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + interval: 移动的持续时间(秒),默认为 0.5 + Returns: - Gesture: Self instance to allow method chaining. + _Gesture: 当前实例,支持链式调用 + + Raises: + InjectGestureError: 手势未开始时抛出 """ self._ensure_started() - self._add_step(x, y, "move", interval) + self._add_step(x, y, STEP_TYPE_MOVE, interval) return self def pause(self, interval: float = 1) -> '_Gesture': """ - Pause at current position for specified duration. - + 在当前位置暂停指定时间 + Args: - interval (float, optional): Duration to pause in seconds. Default is 1. - + interval: 暂停时间(秒),默认为 1 + Returns: - Gesture: Self instance to allow method chaining. + _Gesture: 当前实例,支持链式调用 + + Raises: + InjectGestureError: 手势未开始时抛出 """ self._ensure_started() pos = self.steps[-1].pos - self.steps.append(GestureStep(pos, "pause", interval)) + self.steps.append(GestureStep(pos, STEP_TYPE_PAUSE, interval)) return self @delay - def action(self): + def action(self) -> None: """ - Execute the gesture action. + 执行手势操作 + + 该方法会将所有已定义的手势步骤转换为触摸事件并发送到设备 """ - logger.info(f">>>Gesture steps: {self.steps}") + logger.info(f">>>执行手势步骤: {self.steps}") total_points = self._calculate_total_points() pointer_matrix = self._create_pointer_matrix(total_points) @@ -105,76 +131,82 @@ def action(self): self._release() - def _create_pointer_matrix(self, total_points: int): + def _create_pointer_matrix(self, total_points: int) -> Any: """ - Create a pointer matrix for the gesture. - + 创建手势操作的指针矩阵 + Args: - total_points (int): Total number of points. - + total_points: 总点数 + Returns: - PointerMatrix: Pointer matrix object. + Any: 指针矩阵对象 """ - fingers = 1 + fingers = 1 # 当前仅支持单指操作 api = "PointerMatrix.create" data: HypiumResponse = self.d._client.invoke(api, this=None, args=[fingers, total_points]) return data.result - def _inject_pointer_actions(self, pointer_matrix): + def _inject_pointer_actions(self, pointer_matrix: Any) -> None: """ - Inject pointer actions into the driver. - + 将指针操作注入到设备 + Args: - pointer_matrix (PointerMatrix): Pointer matrix to inject. + pointer_matrix: 要注入的指针矩阵 """ api = "Driver.injectMultiPointerAction" self.d._client.invoke(api, args=[pointer_matrix, 2000]) - def _add_step(self, x: int, y: int, step_type: str, interval: float): + def _add_step(self, x: Union[int, float], y: Union[int, float], step_type: str, interval: float) -> None: """ - Add a step to the gesture. - + 添加手势步骤 + Args: - x (int): x-coordinate of the point. - y (int): y-coordinate of the point. - step_type (str): Type of step ("start", "move", or "pause"). - interval (float): Interval duration in seconds. + x: X 坐标 + y: Y 坐标 + step_type: 步骤类型("start"、"move" 或 "pause") + interval: 时间间隔(秒) """ point: Point = self.d._to_abs_pos(x, y) step = GestureStep(point.to_tuple(), step_type, interval) self.steps.append(step) - def _ensure_can_start(self): + def _ensure_can_start(self) -> None: """ - Ensure that the gesture can start. + 确保手势可以开始 + + Raises: + InjectGestureError: 手势已经开始时抛出 """ if self.steps: - raise InjectGestureError("Can't start gesture twice") + raise InjectGestureError("不能重复开始手势") - def _ensure_started(self): + def _ensure_started(self) -> None: """ - Ensure that the gesture has started. + 确保手势已经开始 + + Raises: + InjectGestureError: 手势未开始时抛出 """ if not self.steps: - raise InjectGestureError("Please call gesture.start first") + raise InjectGestureError("请先调用 gesture.start") - def _generate_points(self, pointer_matrix, total_points): + def _generate_points(self, pointer_matrix: Any, total_points: int) -> None: """ - Generate points for the pointer matrix. - + 为指针矩阵生成点 + Args: - pointer_matrix (PointerMatrix): Pointer matrix to populate. - total_points (int): Total points to generate. + pointer_matrix: 要填充的指针矩阵 + total_points: 要生成的总点数 """ - - def set_point(point_index: int, point: Point, interval: int = None): + # 定义设置点的内部函数 + def set_point(point_index: int, point: Point, interval: Optional[int] = None) -> None: """ - Set a point in the pointer matrix. - + 在指针矩阵中设置点 + Args: - point_index (int): Index of the point. - point (Point): The point object. - interval (int, optional): Interval duration. + point_index: 点的索引 + point: 点对象 + interval: 时间间隔(可选) """ if interval is not None: point.x += 65536 * interval @@ -183,30 +215,33 @@ def set_point(point_index: int, point: Point, interval: int = None): point_index = 0 + # 处理所有手势步骤 for index, step in enumerate(self.steps): - if step.type == "start": + if step.type == STEP_TYPE_START: point_index = self._generate_start_point(step, point_index, set_point) - elif step.type == "move": + elif step.type == STEP_TYPE_MOVE: point_index = self._generate_move_points(index, step, point_index, set_point) - elif step.type == "pause": + elif step.type == STEP_TYPE_PAUSE: point_index = self._generate_pause_points(step, point_index, set_point) + # 填充剩余点 step = self.steps[-1] while point_index < total_points: set_point(point_index, Point(*step.pos)) point_index += 1 - def _generate_start_point(self, step, point_index, set_point): + def _generate_start_point(self, step: 'GestureStep', point_index: int, + set_point: Callable) -> int: """ - Generate start points. - + 生成起始点 + Args: - step (GestureStep): Gesture step. - point_index (int): Current point index. - set_point (function): Function to set the point in pointer matrix. - + step: 手势步骤 + point_index: 当前点索引 + set_point: 设置点的函数 + Returns: - int: Updated point index. + int: 更新后的点索引 """ set_point(point_index, Point(*step.pos), step.interval) point_index += 1 @@ -214,18 +249,19 @@ def _generate_start_point(self, step, point_index, set_point): set_point(point_index, Point(*pos)) return point_index + 1 - def _generate_move_points(self, index, step, point_index, set_point): + def _generate_move_points(self, index: int, step: 'GestureStep', + point_index: int, set_point: Callable) -> int: """ - Generate move points. - + 生成移动点 + Args: - index (int): Step index. - step (GestureStep): Gesture step. - point_index (int): Current point index. - set_point (function): Function to set the point in pointer matrix. - + index: 步骤索引 + step: 手势步骤 + point_index: 当前点索引 + set_point: 设置点的函数 + Returns: - int: Updated point index. + int: 更新后的点索引 """ last_step = self.steps[index - 1] offset_x = step.pos[0] - last_step.pos[0] @@ -234,6 +270,10 @@ def _generate_move_points(self, index, step, point_index, set_point): interval_ms = step.interval cur_steps = self._calculate_move_step_points(distance, interval_ms) + # 避免除零错误 + if cur_steps <= 0: + cur_steps = 1 + step_x = int(offset_x / cur_steps) step_y = int(offset_y / cur_steps) @@ -246,55 +286,57 @@ def _generate_move_points(self, index, step, point_index, set_point): point_index += 1 return point_index - def _generate_pause_points(self, step, point_index, set_point): + def _generate_pause_points(self, step: 'GestureStep', point_index: int, + set_point: Callable) -> int: """ - Generate pause points. - + 生成暂停点 + Args: - step (GestureStep): Gesture step. - point_index (int): Current point index. - set_point (function): Function to set the point in pointer matrix. - + step: 手势步骤 + point_index: 当前点索引 + set_point: 设置点的函数 + Returns: - int: Updated point index. + int: 更新后的点索引 """ - points = int(step.interval / self.sampling_ms) + # 计算需要的点数 + points = max(1, int(step.interval / self.sampling_ms)) for _ in range(points): - set_point(point_index, Point(*step.pos), int(step.interval / self.sampling_ms)) + set_point(point_index, Point(*step.pos), int(step.interval / points)) point_index += 1 - pos = step.pos[0] + 3, step.pos[1] + pos = step.pos[0] + 3, step.pos[1] # 微小移动以触发事件 set_point(point_index, Point(*pos)) return point_index + 1 def _calculate_total_points(self) -> int: """ - Calculate the total number of points needed for the gesture. - + 计算手势所需的总点数 + Returns: - int: Total points. + int: 总点数 """ total_points = 0 for index, step in enumerate(self.steps): - if step.type == "start": + if step.type == STEP_TYPE_START: total_points += 2 - elif step.type == "move": - total_points += self._calculate_move_step_points( - *self._calculate_move_distance(step, index)) - elif step.type == "pause": - points = int(step.interval / self.sampling_ms) + elif step.type == STEP_TYPE_MOVE: + distance, interval_ms = self._calculate_move_distance(step, index) + total_points += self._calculate_move_step_points(distance, interval_ms) + elif step.type == STEP_TYPE_PAUSE: + points = max(1, int(step.interval / self.sampling_ms)) total_points += points + 1 return total_points - def _calculate_move_distance(self, step, index): + def _calculate_move_distance(self, step: 'GestureStep', index: int) -> Tuple[int, float]: """ - Calculate move distance and interval. - + 计算移动距离和时间间隔 + Args: - step (GestureStep): Gesture step. - index (int): Step index. - + step: 手势步骤 + index: 步骤索引 + Returns: - tuple: Tuple (distance, interval_ms). + Tuple[int, float]: (距离, 时间间隔(毫秒)) """ last_step = self.steps[index - 1] offset_x = step.pos[0] - last_step.pos[0] @@ -305,39 +347,45 @@ def _calculate_move_distance(self, step, index): def _calculate_move_step_points(self, distance: int, interval_ms: float) -> int: """ - Calculate the number of move step points based on distance and time. - + 根据距离和时间计算移动步骤点数 + Args: - distance (int): Distance to move. - interval_ms (float): Move duration in milliseconds. - + distance: 移动距离 + interval_ms: 移动持续时间(毫秒) + Returns: - int: Number of move step points. + int: 移动步骤点数 """ if interval_ms < self.sampling_ms or distance < 1: return 1 nums = interval_ms / self.sampling_ms - return distance if nums > distance else int(nums) + return min(distance, int(nums)) class GestureStep: - """Class to store each step of a gesture, not to be used directly, use via Gesture class""" + """ + 手势步骤类 + + 存储手势的每个步骤,不直接使用,通过 Gesture 类使用 + """ - def __init__(self, pos: tuple, step_type: str, interval: float): + def __init__(self, pos: Tuple[int, int], step_type: str, interval: float): """ - Initialize a gesture step. - + 初始化手势步骤 + Args: - pos (tuple): Tuple containing x and y coordinates. - step_type (str): Type of step ("start", "move", "pause"). - interval (float): Interval duration in seconds. + pos: 包含 x 和 y 坐标的元组 + step_type: 步骤类型("start"、"move"、"pause") + interval: 时间间隔(秒) """ self.pos = pos[0], pos[1] - self.interval = int(interval * 1000) + self.interval = int(interval * 1000) # 转换为毫秒 self.type = step_type - def __repr__(self): + def __repr__(self) -> str: + """返回手势步骤的字符串表示""" return f"GestureStep(pos=({self.pos[0]}, {self.pos[1]}), type='{self.type}', interval={self.interval})" - def __str__(self): + def __str__(self) -> str: + """返回手势步骤的字符串表示""" return self.__repr__() \ No newline at end of file diff --git a/hmdriver2/_screenrecord.py b/hmdriver2/_screenrecord.py index 6207d69..b8f2170 100644 --- a/hmdriver2/_screenrecord.py +++ b/hmdriver2/_screenrecord.py @@ -1,36 +1,70 @@ # -*- coding: utf-8 -*- -import typing -import threading -import numpy as np import queue +import threading from datetime import datetime +from typing import List, Optional, Any import cv2 +import numpy as np from . import logger from ._client import HmClient from .driver import Driver from .exception import ScreenRecordError +# 常量定义 +JPEG_START_FLAG = b'\xff\xd8' # JPEG 图像开始标记 +JPEG_END_FLAG = b'\xff\xd9' # JPEG 图像结束标记 +VIDEO_FPS = 10 # 视频帧率 +VIDEO_CODEC = 'mp4v' # 视频编码格式 +QUEUE_TIMEOUT = 0.1 # 队列超时时间(秒) + class RecordClient(HmClient): + """ + 屏幕录制客户端 + + 继承自 HmClient,提供设备屏幕录制功能 + """ + def __init__(self, serial: str, d: Driver): + """ + 初始化屏幕录制客户端 + + Args: + serial: 设备序列号 + d: Driver 实例 + """ super().__init__(serial) self.d = d - self.video_path = None - self.jpeg_queue = queue.Queue() - self.threads: typing.List[threading.Thread] = [] - self.stop_event = threading.Event() + self.video_path: Optional[str] = None + self.jpeg_queue: queue.Queue = queue.Queue() + self.threads: List[threading.Thread] = [] + self.stop_event: threading.Event = threading.Event() def __enter__(self): + """上下文管理器入口""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器退出时停止录制""" self.stop() - def _send_msg(self, api: str, args: list): + def _send_msg(self, api: str, args: Optional[List[Any]] = None): + """ + 发送消息到设备 + + 重写父类方法,使用 Captures API + + Args: + api: API 名称 + args: API 参数列表,默认为空列表 + """ + if args is None: + args = [] + _msg = { "module": "com.ohos.devicetest.hypiumApiHelper", "method": "Captures", @@ -43,16 +77,32 @@ def _send_msg(self, api: str, args: list): super()._send_msg(_msg) def start(self, video_path: str): - logger.info("Start RecordClient connection") - + """ + 开始屏幕录制 + + Args: + video_path: 视频保存路径 + + Returns: + RecordClient: 当前实例,支持链式调用 + + Raises: + ScreenRecordError: 启动屏幕录制失败时抛出 + """ + logger.info("开始屏幕录制") + + # 连接设备 self._connect_sock() self.video_path = video_path + # 发送开始录制命令 self._send_msg("startCaptureScreen", []) - reply: str = self._recv_msg(1024, decode=True, print=False) + # 检查响应 + reply: str = self._recv_msg(decode=True, print=False) if "true" in reply: + # 创建并启动工作线程 record_th = threading.Thread(target=self._record_worker) writer_th = threading.Thread(target=self._video_writer) record_th.daemon = True @@ -61,74 +111,100 @@ def start(self, video_path: str): writer_th.start() self.threads.extend([record_th, writer_th]) else: - raise ScreenRecordError("Failed to start device screen capture.") + raise ScreenRecordError("启动设备屏幕录制失败") return self def _record_worker(self): - """Capture screen frames and save current frames.""" - - # JPEG start and end markers. - start_flag = b'\xff\xd8' - end_flag = b'\xff\xd9' + """ + 屏幕帧捕获工作线程 + + 捕获屏幕帧并保存当前帧 + """ buffer = bytearray() while not self.stop_event.is_set(): try: - buffer += self._recv_msg(4096 * 1024, decode=False, print=False) + buffer += self._recv_msg(decode=False, print=False) except Exception as e: - print(f"Error receiving data: {e}") + logger.error(f"接收数据时出错: {e}") break - start_idx = buffer.find(start_flag) - end_idx = buffer.find(end_flag) + # 查找 JPEG 图像的开始和结束标记 + start_idx = buffer.find(JPEG_START_FLAG) + end_idx = buffer.find(JPEG_END_FLAG) + + # 处理所有完整的 JPEG 图像 while start_idx != -1 and end_idx != -1 and end_idx > start_idx: - # Extract one JPEG image + # 提取一个 JPEG 图像 jpeg_image: bytearray = buffer[start_idx:end_idx + 2] self.jpeg_queue.put(jpeg_image) + # 从缓冲区中移除已处理的数据 buffer = buffer[end_idx + 2:] - # Search for the next JPEG image in the buffer - start_idx = buffer.find(start_flag) - end_idx = buffer.find(end_flag) + # 在缓冲区中查找下一个 JPEG 图像 + start_idx = buffer.find(JPEG_START_FLAG) + end_idx = buffer.find(JPEG_END_FLAG) def _video_writer(self): - """Write frames to video file.""" + """ + 视频写入工作线程 + + 将帧写入视频文件 + """ cv2_instance = None img = None while not self.stop_event.is_set(): try: - jpeg_image = self.jpeg_queue.get(timeout=0.1) + # 从队列获取 JPEG 图像 + jpeg_image = self.jpeg_queue.get(timeout=QUEUE_TIMEOUT) img = cv2.imdecode(np.frombuffer(jpeg_image, np.uint8), cv2.IMREAD_COLOR) except queue.Empty: pass + + # 跳过无效图像 if img is None or img.size == 0: continue + + # 首次获取有效图像时初始化视频写入器 if cv2_instance is None: height, width = img.shape[:2] - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - cv2_instance = cv2.VideoWriter(self.video_path, fourcc, 10, (width, height)) + fourcc = cv2.VideoWriter_fourcc(*VIDEO_CODEC) + cv2_instance = cv2.VideoWriter(self.video_path, fourcc, VIDEO_FPS, (width, height)) + # 写入帧 cv2_instance.write(img) + # 释放资源 if cv2_instance: cv2_instance.release() def stop(self) -> str: + """ + 停止屏幕录制 + + Returns: + str: 视频保存路径 + """ try: + # 设置停止事件,通知工作线程退出 self.stop_event.set() + + # 等待所有工作线程结束 for t in self.threads: t.join() + # 发送停止录制命令 self._send_msg("stopCaptureScreen", []) - self._recv_msg(1024, decode=True, print=False) + self._recv_msg(decode=True, print=False) + # 释放资源 self.release() - # Invalidate the cached property + # 使缓存的属性失效 self.d._invalidate_cache('screenrecord') except Exception as e: - logger.error(f"An error occurred: {e}") + logger.error(f"停止屏幕录制时出错: {e}") return self.video_path diff --git a/hmdriver2/_swipe.py b/hmdriver2/_swipe.py index aab92a0..9769d66 100644 --- a/hmdriver2/_swipe.py +++ b/hmdriver2/_swipe.py @@ -1,35 +1,69 @@ # -*- coding: utf-8 -*- -from typing import Union, Tuple +from typing import Union, Tuple, Optional, Literal from .driver import Driver -from .proto import SwipeDirection +from .proto import SwipeDirection, Point -class SwipeExt(object): +# 滑动方向的字符串类型 +SwipeDirectionStr = Literal["left", "right", "up", "down"] + +# 默认滑动速度(像素/秒) +DEFAULT_SWIPE_SPEED = 2000 +# 速度限制 +MIN_SWIPE_SPEED = 200 +MAX_SWIPE_SPEED = 40000 + + +class SwipeExt: + """ + 扩展滑动功能类 + + 提供了更灵活的滑动操作,支持方向、比例和区域定制 + """ + def __init__(self, d: Driver): + """ + 初始化滑动扩展功能 + + Args: + d: Driver 实例 + """ self._d = d - def __call__(self, - direction: Union[SwipeDirection, str], - scale: float = 0.8, - box: Union[Tuple, None] = None, - speed=2000): + def __call__( + self, + direction: Union[SwipeDirection, SwipeDirectionStr], + scale: float = 0.8, + box: Optional[Tuple[int, int, int, int]] = None, + speed: int = DEFAULT_SWIPE_SPEED + ) -> None: """ + 执行滑动操作 + Args: - direction (str): one of "left", "right", "up", "bottom" or SwipeDirection.LEFT - scale (float): percent of swipe, range (0, 1.0] - box (Tuple): None or (x1, x1, y1, x2, y2) - speed (int, optional): The swipe speed in pixels per second. Default is 2000. Range: 200-40000. If not within the range, set to default value of 2000. + direction: 滑动方向,可以是 "left", "right", "up", "down" 或 SwipeDirection 枚举 + scale: 滑动比例,范围 (0, 1.0],表示滑动距离占可滑动区域的比例 + box: 滑动区域,格式为 (x1, y1, x2, y2),默认为全屏 + speed: 滑动速度(像素/秒),默认为 2000,有效范围 200-40000 + Raises: - ValueError + ValueError: 参数无效时抛出 """ - def _swipe(_from, _to): + def _swipe(_from: Tuple[int, int], _to: Tuple[int, int]) -> None: + """执行从一点到另一点的滑动""" self._d.swipe(_from[0], _from[1], _to[0], _to[1], speed=speed) + # 验证 scale 参数 if scale <= 0 or scale > 1.0 or not isinstance(scale, (float, int)): raise ValueError("scale must be in range (0, 1.0]") + # 验证 speed 参数 + if speed < MIN_SWIPE_SPEED or speed > MAX_SWIPE_SPEED: + speed = DEFAULT_SWIPE_SPEED + + # 确定滑动区域 if box: x1, y1, x2, y2 = self._validate_and_convert_box(box) else: @@ -38,47 +72,61 @@ def _swipe(_from, _to): width, height = x2 - x1, y2 - y1 + # 计算偏移量,确保滑动在边缘留有间距 h_offset = int(width * (1 - scale) / 2) v_offset = int(height * (1 - scale) / 2) - if direction == SwipeDirection.LEFT: + # 根据方向确定滑动的起点和终点 + if direction in [SwipeDirection.LEFT, "left"]: start = (x2 - h_offset, y1 + height // 2) end = (x1 + h_offset, y1 + height // 2) - elif direction == SwipeDirection.RIGHT: + elif direction in [SwipeDirection.RIGHT, "right"]: start = (x1 + h_offset, y1 + height // 2) end = (x2 - h_offset, y1 + height // 2) - elif direction == SwipeDirection.UP: + elif direction in [SwipeDirection.UP, "up"]: start = (x1 + width // 2, y2 - v_offset) end = (x1 + width // 2, y1 + v_offset) - elif direction == SwipeDirection.DOWN: + elif direction in [SwipeDirection.DOWN, "down"]: start = (x1 + width // 2, y1 + v_offset) end = (x1 + width // 2, y2 - v_offset) else: - raise ValueError("Unknown SwipeDirection:", direction) + raise ValueError(f"Unknown SwipeDirection: {direction}") + # 执行滑动 _swipe(start, end) def _validate_and_convert_box(self, box: Tuple) -> Tuple[int, int, int, int]: """ - Validate and convert the box coordinates if necessay. - + 验证并转换区域坐标 + Args: - box (Tuple): The box coordinates as a tuple (x1, y1, x2, y2). - + box: 区域坐标元组 (x1, y1, x2, y2) + Returns: - Tuple[int, int, int, int]: The validated and converted box coordinates. + Tuple[int, int, int, int]: 验证并转换后的区域坐标 + + Raises: + ValueError: 坐标无效时抛出 """ + # 验证元组长度 if not isinstance(box, tuple) or len(box) != 4: raise ValueError("Box must be a tuple of length 4.") + x1, y1, x2, y2 = box + + # 验证坐标值 + if not all(isinstance(coord, (int, float)) for coord in box): + raise ValueError("All coordinates must be numeric.") + + # 验证坐标范围 if not (x1 >= 0 and y1 >= 0 and x2 > 0 and y2 > 0): raise ValueError("Box coordinates must be greater than 0.") + + # 验证坐标关系 if not (x1 < x2 and y1 < y2): raise ValueError("Box coordinates must satisfy x1 < x2 and y1 < y2.") - from .driver import Point + # 转换坐标到绝对位置 p1: Point = self._d._to_abs_pos(x1, y1) p2: Point = self._d._to_abs_pos(x2, y2) - x1, y1, x2, y2 = p1.x, p1.y, p2.x, p2.y - - return x1, y1, x2, y2 + return p1.x, p1.y, p2.x, p2.y diff --git a/hmdriver2/_uiobject.py b/hmdriver2/_uiobject.py index 31f58d0..6fbbb92 100644 --- a/hmdriver2/_uiobject.py +++ b/hmdriver2/_uiobject.py @@ -2,44 +2,73 @@ import enum import time -from typing import List, Union +from typing import List, Optional, Any from . import logger -from .utils import delay from ._client import HmClient from .exception import ElementNotFoundError from .proto import ComponentData, ByData, HypiumResponse, Point, Bounds, ElementInfo +from .utils import delay class ByType(enum.Enum): - id = "id" - key = "key" - text = "text" - type = "type" - description = "description" - clickable = "clickable" - longClickable = "longClickable" - scrollable = "scrollable" - enabled = "enabled" - focused = "focused" - selected = "selected" - checked = "checked" - checkable = "checkable" - isBefore = "isBefore" - isAfter = "isAfter" + """ + UI 元素查找类型枚举 + + 定义了可用于查找 UI 元素的属性类型 + """ + id = "id" # 元素 ID + key = "key" # 元素键值 + text = "text" # 元素文本 + type = "type" # 元素类型 + description = "description" # 元素描述 + clickable = "clickable" # 是否可点击 + longClickable = "longClickable" # 是否可长按 + scrollable = "scrollable" # 是否可滚动 + enabled = "enabled" # 是否启用 + focused = "focused" # 是否获得焦点 + selected = "selected" # 是否被选中 + checked = "checked" # 是否被勾选 + checkable = "checkable" # 是否可勾选 + isBefore = "isBefore" # 是否在指定元素之前 + isAfter = "isAfter" # 是否在指定元素之后 @classmethod - def verify(cls, value): + def verify(cls, value: str) -> bool: + """ + 验证属性类型是否有效 + + Args: + value: 要验证的属性类型 + + Returns: + bool: 属性类型有效返回 True,否则返回 False + """ return any(value == item.value for item in cls) class UiObject: + """ + UI 对象类,用于查找和操作 UI 元素 + + 提供了元素查找、属性获取和操作执行的功能 + """ + + # 默认超时时间(秒) DEFAULT_TIMEOUT = 2 def __init__(self, client: HmClient, **kwargs) -> None: + """ + 初始化 UI 对象 + + Args: + client: HmClient 实例 + **kwargs: 查找元素的条件 + """ self._client = client self._raw_kwargs = kwargs + # 提取特殊参数 self._index = kwargs.pop("index", 0) self._isBefore = kwargs.pop("isBefore", False) self._isAfter = kwargs.pop("isAfter", False) @@ -47,53 +76,106 @@ def __init__(self, client: HmClient, **kwargs) -> None: self._kwargs = kwargs self.__verify() - self._component: Union[ComponentData, None] = None # cache + self._component: Optional[ComponentData] = None # 缓存找到的组件 def __str__(self) -> str: - return f"UiObject [{self._raw_kwargs}" - - def __verify(self): + """返回 UiObject 的字符串表示""" + return f"UiObject {self._raw_kwargs}" + + def __verify(self) -> None: + """ + 验证查找条件是否有效 + + Raises: + ReferenceError: 查找条件无效时抛出 + """ for k, v in self._kwargs.items(): if not ByType.verify(k): - raise ReferenceError(f"{k} is not allowed.") + raise ReferenceError(f"{k} 不是有效的查找条件") @property def count(self) -> int: - eleements = self.__find_components() - return len(eleements) if eleements else 0 - - def __len__(self): + """ + 获取匹配条件的元素数量 + + Returns: + int: 元素数量 + """ + elements = self.__find_components() + return len(elements) if elements else 0 + + def __len__(self) -> int: + """支持使用 len() 函数获取元素数量""" return self.count - def exists(self, retries: int = 2, wait_time=1) -> bool: + def exists(self, retries: int = 2, wait_time: float = 1) -> bool: + """ + 检查元素是否存在 + + Args: + retries: 重试次数,默认为 2 + wait_time: 重试间隔时间(秒),默认为 1 + + Returns: + bool: 元素存在返回 True,否则返回 False + """ obj = self.find_component(retries, wait_time) - return True if obj else False - - def __set_component(self, component: ComponentData): + return obj is not None + + def __set_component(self, component: ComponentData) -> None: + """ + 设置找到的组件 + + Args: + component: 组件数据 + """ self._component = component - def find_component(self, retries: int = 1, wait_time=1) -> ComponentData: + def find_component(self, retries: int = 1, wait_time: float = 1) -> Optional[ComponentData]: + """ + 查找匹配条件的组件 + + Args: + retries: 重试次数,默认为 1 + wait_time: 重试间隔时间(秒),默认为 1 + + Returns: + Optional[ComponentData]: 找到的组件,未找到返回 None + """ for attempt in range(retries): components = self.__find_components() if components and self._index < len(components): self.__set_component(components[self._index]) return self._component - if attempt < retries: + if attempt < retries - 1: time.sleep(wait_time) - logger.info(f"Retry found element {self}") + logger.info(f"重试查找元素 {self}") return None - # useless - def __find_component(self) -> Union[ComponentData, None]: + def __find_component(self) -> Optional[ComponentData]: + """ + 查找单个匹配条件的组件 + + 该方法直接调用 Driver.findComponent API 查找单个组件 + + Returns: + Optional[ComponentData]: 找到的组件,未找到返回 None + """ by: ByData = self.__get_by() resp: HypiumResponse = self._client.invoke("Driver.findComponent", args=[by.value]) if not resp.result: return None return ComponentData(resp.result) - def __find_components(self) -> Union[List[ComponentData], None]: + def __find_components(self) -> Optional[List[ComponentData]]: + """ + 查找所有匹配条件的组件 + + Returns: + Optional[List[ComponentData]]: 找到的组件列表,未找到返回 None + """ by: ByData = self.__get_by() resp: HypiumResponse = self._client.invoke("Driver.findComponents", args=[by.value]) if not resp.result: @@ -105,92 +187,149 @@ def __find_components(self) -> Union[List[ComponentData], None]: return components def __get_by(self) -> ByData: + """ + 构建查找条件 + + Returns: + ByData: 查找条件对象 + """ + this = "On#seed" + + # 处理所有查找条件 for k, v in self._kwargs.items(): api = f"On.{k}" - this = "On#seed" - resp: HypiumResponse = self._client.invoke(api, this, args=[v]) + resp: HypiumResponse = self._client.invoke(api, this=this, args=[v]) this = resp.result + # 处理位置关系 if self._isBefore: - resp: HypiumResponse = self._client.invoke("On.isBefore", this="On#seed", args=[resp.result]) + resp: HypiumResponse = self._client.invoke("On.isBefore", this=this, args=[resp.result]) if self._isAfter: - resp: HypiumResponse = self._client.invoke("On.isAfter", this="On#seed", args=[resp.result]) + resp: HypiumResponse = self._client.invoke("On.isAfter", this=this, args=[resp.result]) return ByData(resp.result) - def __operate(self, api, args=[], retries: int = 2): + def __operate(self, api: str, args: Optional[List[Any]] = None, retries: int = 2) -> Any: + """ + 对元素执行操作 + + Args: + api: 要调用的 API + args: API 参数,默认为空列表 + retries: 重试次数,默认为 2 + + Returns: + Any: API 调用结果 + + Raises: + ElementNotFoundError: 元素未找到时抛出 + """ + if args is None: + args = [] + if not self._component: if not self.find_component(retries): - raise ElementNotFoundError(f"Element({self}) not found after {retries} retries") + raise ElementNotFoundError(f"未找到元素({self}),重试 {retries} 次后失败") resp: HypiumResponse = self._client.invoke(api, this=self._component.value, args=args) return resp.result @property def id(self) -> str: + """元素 ID""" return self.__operate("Component.getId") @property def key(self) -> str: + """元素键值""" return self.__operate("Component.getId") @property def type(self) -> str: + """元素类型""" return self.__operate("Component.getType") @property def text(self) -> str: + """元素文本""" return self.__operate("Component.getText") @property def description(self) -> str: + """元素描述""" return self.__operate("Component.getDescription") @property def isSelected(self) -> bool: + """元素是否被选中""" return self.__operate("Component.isSelected") @property def isChecked(self) -> bool: + """元素是否被勾选""" return self.__operate("Component.isChecked") @property def isEnabled(self) -> bool: + """元素是否启用""" return self.__operate("Component.isEnabled") @property def isFocused(self) -> bool: + """元素是否获得焦点""" return self.__operate("Component.isFocused") @property def isCheckable(self) -> bool: + """元素是否可勾选""" return self.__operate("Component.isCheckable") @property def isClickable(self) -> bool: + """元素是否可点击""" return self.__operate("Component.isClickable") @property def isLongClickable(self) -> bool: + """元素是否可长按""" return self.__operate("Component.isLongClickable") @property def isScrollable(self) -> bool: + """元素是否可滚动""" return self.__operate("Component.isScrollable") @property def bounds(self) -> Bounds: + """ + 元素边界 + + Returns: + Bounds: 元素边界对象 + """ _raw = self.__operate("Component.getBounds") return Bounds(**_raw) @property def boundsCenter(self) -> Point: + """ + 元素中心点坐标 + + Returns: + Point: 元素中心点坐标对象 + """ _raw = self.__operate("Component.getBoundsCenter") return Point(**_raw) @property def info(self) -> ElementInfo: + """ + 获取元素的完整信息 + + Returns: + ElementInfo: 包含元素所有属性的信息对象 + """ return ElementInfo( id=self.id, key=self.key, @@ -209,40 +348,108 @@ def info(self) -> ElementInfo: boundsCenter=self.boundsCenter) @delay - def click(self): + def click(self) -> Any: + """ + 点击元素 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.click") @delay - def click_if_exists(self): + def click_if_exists(self) -> Optional[Any]: + """ + 如果元素存在则点击 + + 与 click() 不同,该方法在元素不存在时不会抛出异常 + + Returns: + Optional[Any]: 操作结果,元素不存在时返回 None + """ try: return self.__operate("Component.click") except ElementNotFoundError: - pass + return None @delay - def double_click(self): + def double_click(self) -> Any: + """ + 双击元素 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.doubleClick") @delay - def long_click(self): + def long_click(self) -> Any: + """ + 长按元素 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.longClick") @delay - def drag_to(self, component: ComponentData): + def drag_to(self, component: ComponentData) -> Any: + """ + 将元素拖动到指定组件位置 + + Args: + component: 目标组件 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.dragTo", [component.value]) @delay - def input_text(self, text: str): + def input_text(self, text: str) -> Any: + """ + 在元素中输入文本 + + Args: + text: 要输入的文本 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.inputText", [text]) @delay - def clear_text(self): + def clear_text(self) -> Any: + """ + 清除元素中的文本 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.clearText") @delay - def pinch_in(self, scale: float = 0.5): + def pinch_in(self, scale: float = 0.5) -> Any: + """ + 在元素上执行捏合手势(缩小) + + Args: + scale: 缩放比例,默认为 0.5 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.pinchIn", [scale]) @delay - def pinch_out(self, scale: float = 2): + def pinch_out(self, scale: float = 2) -> Any: + """ + 在元素上执行张开手势(放大) + + Args: + scale: 缩放比例,默认为 2 + + Returns: + Any: 操作结果 + """ return self.__operate("Component.pinchOut", [scale]) diff --git a/hmdriver2/_xpath.py b/hmdriver2/_xpath.py index 49cc576..d2dcd77 100644 --- a/hmdriver2/_xpath.py +++ b/hmdriver2/_xpath.py @@ -1,105 +1,417 @@ # -*- coding: utf-8 -*- +import json import re -from typing import Dict -from lxml import etree from functools import cached_property +from typing import Dict, Optional, List, Any + +from lxml import etree from . import logger -from .proto import Bounds from .driver import Driver -from .utils import delay, parse_bounds from .exception import XmlElementNotFoundError +from .proto import Point +from .utils import parse_bounds, delay + +# XML相关常量 +XML_ROOT_TAG = "orgRoot" +XML_ATTRIBUTE_TYPE = "type" +XML_ATTRIBUTE_BOUNDS = "origBounds" +XML_ATTRIBUTE_ID = "id" +XML_ATTRIBUTE_TEXT = "text" +XML_ATTRIBUTE_DESCRIPTION = "description" + +# 布尔属性的默认值 +DEFAULT_BOOL_VALUE = "false" +TRUE_VALUE = "true" + +# 布尔属性列表 +BOOL_ATTRIBUTES = [ + "enabled", "focused", "selected", "checked", "checkable", + "clickable", "longClickable", "scrollable" +] class _XPath: - def __init__(self, d: Driver): + """ + XPath查询类,用于在UI层次结构中查找元素。 + + 该类提供了将JSON格式的层次结构转换为XML,并执行XPath查询的功能。 + + Attributes: + _d (Driver): Driver实例,用于与设备交互 + """ + + def __init__(self, d: Driver) -> None: + """ + 初始化XPath查询对象。 + + Args: + d: Driver实例,用于执行设备操作 + """ self._d = d def __call__(self, xpath: str) -> '_XMLElement': + """ + 执行XPath查询并返回匹配的元素。 + + Args: + xpath: XPath查询字符串 - hierarchy: Dict = self._d.dump_hierarchy() - if not hierarchy: - raise RuntimeError("hierarchy is empty") + Returns: + _XMLElement: 匹配XPath查询的XML元素 - xml = _XPath._json2xml(hierarchy) - result = xml.xpath(xpath) + Raises: + RuntimeError: 当层次结构为空或解析失败时抛出 + XmlElementNotFoundError: 当找不到匹配元素时抛出 + """ + hierarchy_str: str = self._d.dump_hierarchy() + if not hierarchy_str: + raise RuntimeError("层次结构为空") - if len(result) > 0: - node = result[0] - raw_bounds: str = node.attrib.get("bounds") # [832,1282][1125,1412] - bounds: Bounds = parse_bounds(raw_bounds) - logger.debug(f"{xpath} Bounds: {bounds}") - return _XMLElement(bounds, self._d) + try: + hierarchy: Dict[str, Any] = json.loads(hierarchy_str) + except json.JSONDecodeError as e: + raise RuntimeError(f"解析层次结构JSON失败: {e}") - return _XMLElement(None, self._d) + xml = self._json2xml(hierarchy) + + try: + result: List[etree.Element] = xml.xpath(xpath) + except etree.XPathError as e: + raise RuntimeError(f"XPath查询语法错误: {e}") + + # 返回第一个匹配的节点,如果未找到则返回None + node = result[0] if result else None + return _XMLElement(node, self._d) @staticmethod def _sanitize_text(text: str) -> str: - """Remove XML-incompatible control characters.""" + """ + 移除XML不兼容的控制字符。 + + Args: + text: 需要清理的文本 + + Returns: + str: 清理后的文本 + """ + if not isinstance(text, str): + text = str(text) return re.sub(r'[\x00-\x1F\x7F]', '', text) @staticmethod - def _json2xml(hierarchy: Dict) -> etree.Element: - """Convert JSON-like hierarchy to XML.""" - attributes = hierarchy.get("attributes", {}) + def _json2xml(hierarchy: Dict[str, Any]) -> etree.Element: + """ + 将JSON格式的层次结构转换为XML。 - # 过滤所有属性的值,确保无非法字符 - cleaned_attributes = {k: _XPath._sanitize_text(str(v)) for k, v in attributes.items()} + Args: + hierarchy: JSON格式的层次结构数据 - tag = cleaned_attributes.get("type", "orgRoot") or "orgRoot" + Returns: + etree.Element: 转换后的XML元素 + """ + attributes = hierarchy.get("attributes", {}) + cleaned_attributes = { + k: _XPath._sanitize_text(str(v)) + for k, v in attributes.items() + } + + tag = cleaned_attributes.get(XML_ATTRIBUTE_TYPE, XML_ROOT_TAG) or XML_ROOT_TAG xml = etree.Element(tag, attrib=cleaned_attributes) - + children = hierarchy.get("children", []) for item in children: xml.append(_XPath._json2xml(item)) - + return xml class _XMLElement: - def __init__(self, bounds: Bounds, d: Driver): - self.bounds = bounds + """ + XML元素类,用于处理UI控件的属性和操作。 + + 提供了控件属性的访问和基本的交互操作方法。 + + Attributes: + _node (Optional[etree.Element]): XML节点元素 + _d (Driver): Driver实例 + """ + + def __init__(self, node: Optional[etree.Element], d: Driver) -> None: + """ + 初始化XML元素。 + + Args: + node: XML节点元素 + d: Driver实例 + """ + self._node = node self._d = d - def _verify(self): - if not self.bounds: - raise XmlElementNotFoundError("xpath not found") + def _verify(self) -> None: + """ + 验证控件是否存在。 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + if self._node is None or not self._node.attrib: + raise XmlElementNotFoundError("未找到匹配的元素") @cached_property - def center(self): + def center(self) -> Point: + """ + 获取控件中心点坐标。 + + Returns: + Point: 控件中心点坐标 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ self._verify() - return self.bounds.get_center() + bounds = parse_bounds(self._node.attrib.get(XML_ATTRIBUTE_BOUNDS, "")) + return bounds.get_center() if bounds else Point(0, 0) def exists(self) -> bool: - return self.bounds is not None + """ + 检查控件是否存在。 + + Returns: + bool: 控件存在返回True,否则返回False + """ + return self._node is not None and bool(self._node.attrib) + + @property + def id(self) -> str: + """ + 控件的唯一标识符 + + Returns: + str: 控件ID + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() + return self._node.attrib.get(XML_ATTRIBUTE_ID, "") + + @property + def type(self) -> str: + """ + 控件的类型 + + Returns: + str: 控件类型 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() + return self._node.attrib.get(XML_ATTRIBUTE_TYPE, "") + + @property + def text(self) -> str: + """ + 控件的文本内容 + + Returns: + str: 控件文本 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() + return self._node.attrib.get(XML_ATTRIBUTE_TEXT, "") + + @property + def description(self) -> str: + """ + 控件的描述信息 + + Returns: + str: 控件描述 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() + return self._node.attrib.get(XML_ATTRIBUTE_DESCRIPTION, "") + + def _get_bool_attr(self, attr: str) -> bool: + """ + 获取布尔类型的属性值。 + + Args: + attr: 属性名 + + Returns: + bool: 属性值 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() + return self._node.attrib.get(attr, DEFAULT_BOOL_VALUE) == TRUE_VALUE + + @property + def enabled(self) -> bool: + """ + 控件是否启用 + + Returns: + bool: 启用状态 + """ + return self._get_bool_attr("enabled") + + @property + def focused(self) -> bool: + """ + 控件是否获得焦点 + + Returns: + bool: 焦点状态 + """ + return self._get_bool_attr("focused") + + @property + def selected(self) -> bool: + """ + 控件是否被选中 + + Returns: + bool: 选中状态 + """ + return self._get_bool_attr("selected") + + @property + def checked(self) -> bool: + """ + 控件是否被勾选 + + Returns: + bool: 勾选状态 + """ + return self._get_bool_attr("checked") + + @property + def checkable(self) -> bool: + """ + 控件是否可勾选 + + Returns: + bool: 可勾选状态 + """ + return self._get_bool_attr("checkable") + + @property + def clickable(self) -> bool: + """ + 控件是否可点击 + + Returns: + bool: 可点击状态 + """ + return self._get_bool_attr("clickable") + + @property + def long_clickable(self) -> bool: + """ + 控件是否可长按 + + Returns: + bool: 可长按状态 + """ + return self._get_bool_attr("longClickable") + + @property + def scrollable(self) -> bool: + """ + 控件是否可滚动 + + Returns: + bool: 可滚动状态 + """ + return self._get_bool_attr("scrollable") @delay - def click(self): + def click(self) -> None: + """ + 点击控件中心位置。 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + + Note: + 该操作会自动添加延迟以确保操作的稳定性 + """ + self._verify() x, y = self.center.x, self.center.y + logger.debug(f"点击坐标: ({x}, {y})") self._d.click(x, y) @delay - def click_if_exists(self): - + def click_if_exists(self) -> bool: + """ + 如果控件存在则点击。 + + Returns: + bool: 点击成功返回True,控件不存在返回False + + Note: + 该方法不会在控件不存在时抛出异常 + """ if not self.exists(): - logger.debug("click_exist: xpath not found") - return + logger.debug("控件不存在,跳过点击操作") + return False x, y = self.center.x, self.center.y + logger.debug(f"点击坐标: ({x}, {y})") self._d.click(x, y) + return True @delay - def double_click(self): + def double_click(self) -> None: + """ + 双击控件中心位置。 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() x, y = self.center.x, self.center.y + logger.debug(f"双击坐标: ({x}, {y})") self._d.double_click(x, y) @delay - def long_click(self): + def long_click(self) -> None: + """ + 长按控件中心位置。 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + """ + self._verify() x, y = self.center.x, self.center.y + logger.debug(f"长按坐标: ({x}, {y})") self._d.long_click(x, y) @delay - def input_text(self, text): + def input_text(self, text: str) -> None: + """ + 在控件中输入文本。 + + Args: + text: 要输入的文本内容 + + Raises: + XmlElementNotFoundError: 当控件不存在时抛出 + + Note: + 会先点击控件获取焦点,然后进行文本输入 + """ + self._verify() + logger.debug(f"输入文本: {text}") self.click() - self._d.input_text(text) \ No newline at end of file + self._d.input_text(text) diff --git a/hmdriver2/assets/so/arm64-v8a/agent.so b/hmdriver2/assets/so/arm64-v8a/agent.so new file mode 100644 index 0000000..6d4539a Binary files /dev/null and b/hmdriver2/assets/so/arm64-v8a/agent.so differ diff --git a/hmdriver2/assets/so/x86_64/agent.so b/hmdriver2/assets/so/x86_64/agent.so new file mode 100644 index 0000000..904f0f5 Binary files /dev/null and b/hmdriver2/assets/so/x86_64/agent.so differ diff --git a/hmdriver2/assets/uitest_agent_v1.0.7.so b/hmdriver2/assets/uitest_agent_v1.0.7.so deleted file mode 100644 index 8298f2c..0000000 Binary files a/hmdriver2/assets/uitest_agent_v1.0.7.so and /dev/null differ diff --git a/hmdriver2/assets/uitest_agent_v1.1.0.so b/hmdriver2/assets/uitest_agent_v1.1.0.so deleted file mode 100644 index e71a700..0000000 Binary files a/hmdriver2/assets/uitest_agent_v1.1.0.so and /dev/null differ diff --git a/hmdriver2/driver.py b/hmdriver2/driver.py index 179d10e..ec8476f 100644 --- a/hmdriver2/driver.py +++ b/hmdriver2/driver.py @@ -2,148 +2,249 @@ import json import uuid -import re -from typing import Type, Any, Tuple, Dict, Union, List, Optional from functools import cached_property # python3.8+ +from typing import Type, Tuple, Dict, Union, List, Optional, Any from . import logger -from .utils import delay from ._client import HmClient from ._uiobject import UiObject -from .hdc import list_devices from .exception import DeviceNotFoundError +from .hdc import list_devices from .proto import HypiumResponse, KeyCode, Point, DisplayRotation, DeviceInfo, CommandResult +from .utils import delay class Driver: + """ + Harmony OS 设备驱动类 + + 提供设备控制、应用管理、UI 操作等功能的主要接口。 + 采用单例模式,每个设备序列号对应一个实例。 + """ + + # 单例存储字典 _instance: Dict[str, "Driver"] = {} def __new__(cls: Type["Driver"], serial: Optional[str] = None) -> "Driver": """ - Ensure that only one instance of Driver exists per serial. - If serial is None, use the first serial from list_devices(). + 确保每个设备序列号只创建一个 Driver 实例 + + 如果 serial 为 None,使用 list_devices() 的第一个设备 + + Args: + serial: 设备序列号,为 None 时使用第一个可用设备 + + Returns: + Driver: 对应序列号的 Driver 实例 """ serial = cls._prepare_serial(serial) if serial not in cls._instance: instance = super().__new__(cls) cls._instance[serial] = instance - # Temporarily store the serial in the instance for initialization + # 临时存储序列号用于初始化 instance._serial_for_init = serial return cls._instance[serial] def __init__(self, serial: Optional[str] = None): """ - Initialize the Driver instance. Only initialize if `_initialized` is not set. + 初始化 Driver 实例 + + 只有在实例未初始化时才执行初始化 + + Args: + serial: 设备序列号,为 None 时使用第一个可用设备 + + Raises: + ValueError: 初始化时缺少序列号 """ if hasattr(self, "_initialized") and self._initialized: return - # Use the serial prepared in `__new__` + # 使用在 __new__ 中准备的序列号 serial = getattr(self, "_serial_for_init", serial) if serial is None: - raise ValueError("Serial number is required for initialization.") + raise ValueError("初始化时需要设备序列号") self.serial = serial self._client = HmClient(self.serial) self.hdc = self._client.hdc self._init_hmclient() - self._initialized = True # Mark the instance as initialized - del self._serial_for_init # Clean up temporary attribute + self._initialized = True # 标记实例已初始化 + del self._serial_for_init # 清理临时属性 @classmethod - def _prepare_serial(cls, serial: str = None) -> str: + def _prepare_serial(cls, serial: Optional[str] = None) -> str: """ - Prepare the serial. Use the first available device if serial is None. + 准备设备序列号 + + 如果未提供序列号,使用第一个可用设备 + + Args: + serial: 设备序列号,为 None 时使用第一个可用设备 + + Returns: + str: 准备好的设备序列号 + + Raises: + DeviceNotFoundError: 未找到设备或指定的设备不存在 """ devices = list_devices() if not devices: - raise DeviceNotFoundError("No devices found. Please connect a device.") + raise DeviceNotFoundError("未找到设备,请连接设备") if serial is None: - logger.info(f"No serial provided, using the first device: {devices[0]}") + logger.info(f"未提供序列号,使用第一个设备: {devices[0]}") return devices[0] if serial not in devices: - raise DeviceNotFoundError(f"Device [{serial}] not found") + raise DeviceNotFoundError(f"未找到设备 [{serial}]") return serial def __call__(self, **kwargs) -> UiObject: - + """ + 将 Driver 实例作为函数调用,返回 UiObject 实例 + + Args: + **kwargs: 传递给 UiObject 构造函数的参数 + + Returns: + UiObject: 创建的 UiObject 实例 + """ return UiObject(self._client, **kwargs) def __del__(self): + """ + 析构函数,清理资源 + """ Driver._instance.clear() if hasattr(self, '_client') and self._client: self._client.release() def _init_hmclient(self): + """初始化 HmClient 连接""" self._client.start() - def _invoke(self, api: str, args: List = []) -> HypiumResponse: + def _invoke(self, api: str, args: Optional[List[Any]] = None) -> HypiumResponse: + """ + 调用 Hypium API + + Args: + api: API 名称 + args: API 参数列表,默认为空列表 + + Returns: + HypiumResponse: API 调用响应 + """ + if args is None: + args = [] return self._client.invoke(api, this="Driver#0", args=args) @delay def start_app(self, package_name: str, page_name: Optional[str] = None): """ - Start an application on the device. - If the `package_name` is empty, it will retrieve main ability using `get_app_main_ability`. - + 启动应用 + + 如果未提供 page_name,将通过 get_app_main_ability 获取主 Ability + Args: - package_name (str): The package name of the application. - page_name (Optional[str]): Ability Name within the application to start. + package_name: 应用包名 + page_name: Ability 名称,默认为 None """ if not page_name: page_name = self.get_app_main_ability(package_name).get('name', 'MainAbility') self.hdc.start_app(package_name, page_name) def force_start_app(self, package_name: str, page_name: Optional[str] = None): + """ + 强制启动应用 + + 先返回主屏幕,停止应用,然后启动应用 + + Args: + package_name: 应用包名 + page_name: Ability 名称,默认为 None + """ self.go_home() self.stop_app(package_name) self.start_app(package_name, page_name) def stop_app(self, package_name: str): + """ + 停止应用 + + Args: + package_name: 应用包名 + """ self.hdc.stop_app(package_name) def clear_app(self, package_name: str): """ - Clear the application's cache and data. + 清除应用缓存和数据 + + Args: + package_name: 应用包名 """ - self.hdc.shell(f"bm clean -n {package_name} -c") # clear cache - self.hdc.shell(f"bm clean -n {package_name} -d") # clear data + self.hdc.shell(f"bm clean -n {package_name} -c") # 清除缓存 + self.hdc.shell(f"bm clean -n {package_name} -d") # 清除数据 def install_app(self, apk_path: str): + """ + 安装应用 + + Args: + apk_path: 应用安装包路径 + """ self.hdc.install(apk_path) def uninstall_app(self, package_name: str): + """ + 卸载应用 + + Args: + package_name: 应用包名 + """ self.hdc.uninstall(package_name) def list_apps(self) -> List: + """ + 列出设备上的应用 + + Returns: + List: 应用列表 + """ return self.hdc.list_apps() def has_app(self, package_name: str) -> bool: + """ + 检查设备上是否安装了指定应用 + + Args: + package_name: 应用包名 + + Returns: + bool: 应用存在返回 True,否则返回 False + """ return self.hdc.has_app(package_name) def current_app(self) -> Tuple[str, str]: """ - Get the current foreground application information. - + 获取当前前台应用信息 + Returns: - Tuple[str, str]: A tuple contain the package_name andpage_name of the foreground application. - If no foreground application is found, returns (None, None). + Tuple[str, str]: 包含应用包名和页面名称的元组 + 如果未找到前台应用,返回 (None, None) """ - return self.hdc.current_app() def get_app_info(self, package_name: str) -> Dict: """ - Get detailed information about a specific application. - + 获取应用详细信息 + Args: - package_name (str): The package name of the application to retrieve information for. - + package_name: 应用包名 + Returns: - Dict: A dictionary containing the application information. If an error occurs during parsing, - an empty dictionary is returned. + Dict: 包含应用信息的字典,解析错误时返回空字典 """ app_info = {} data: CommandResult = self.hdc.shell(f"bm dump -n {package_name}") @@ -155,88 +256,112 @@ def get_app_info(self, package_name: str) -> Dict: app_info = json.loads(json_output) except Exception as e: - logger.error(f"An error occurred:{e}") + logger.error(f"解析应用信息时出错: {e}") return app_info def get_app_abilities(self, package_name: str) -> List[Dict]: """ - Get the abilities of an application. - + 获取应用的 Abilities + Args: - package_name (str): The package name of the application. - + package_name: 应用包名 + Returns: - List[Dict]: A list of dictionaries containing the abilities of the application. + List[Dict]: 包含应用 Abilities 信息的字典列表 """ result = [] app_info = self.get_app_info(package_name) - hap_module_infos = app_info.get("hapModuleInfos") + hap_module_infos = app_info.get("hapModuleInfos", []) main_entry = app_info.get("mainEntry") for hap_module_info in hap_module_infos: - # 尝试读取moduleInfo + # 尝试读取 moduleInfo try: - ability_infos = hap_module_info.get("abilityInfos") - module_main = hap_module_info["mainAbility"] + ability_infos = hap_module_info.get("abilityInfos", []) + module_main = hap_module_info.get("mainAbility", "") except Exception as e: - logger.warning(f"Fail to parse moduleInfo item, {repr(e)}") + logger.warning(f"解析 moduleInfo 失败: {repr(e)}") continue - # 尝试读取abilityInfo + # 尝试读取 abilityInfo for ability_info in ability_infos: try: is_launcher_ability = False - skills = ability_info['skills'] - if len(skills) > 0 or "action.system.home" in skills[0]["actions"]: + skills = ability_info.get('skills', []) + if skills and "action.system.home" in skills[0].get("actions", []): is_launcher_ability = True icon_ability_info = { - "name": ability_info["name"], - "moduleName": ability_info["moduleName"], + "name": ability_info.get("name", ""), + "moduleName": ability_info.get("moduleName", ""), "moduleMainAbility": module_main, "mainModule": main_entry, "isLauncherAbility": is_launcher_ability } result.append(icon_ability_info) except Exception as e: - logger.warning(f"Fail to parse ability_info item, {repr(e)}") + logger.warning(f"解析 ability_info 失败: {repr(e)}") continue - logger.debug(f"all abilities: {result}") + logger.debug(f"所有 abilities: {result}") return result def get_app_main_ability(self, package_name: str) -> Dict: """ - Get the main ability of an application. - + 获取应用的主 Ability + Args: - package_name (str): The package name of the application to retrieve information for. - + package_name: 应用包名 + Returns: - Dict: A dictionary containing the main ability of the application. - + Dict: 包含应用主 Ability 信息的字典,未找到时返回空字典 """ - if not (abilities := self.get_app_abilities(package_name)): + abilities = self.get_app_abilities(package_name) + if not abilities: return {} for item in abilities: score = 0 - if (name := item["name"]) and name == item["moduleMainAbility"]: + name = item.get("name", "") + if name and name == item.get("moduleMainAbility", ""): score += 1 - if (module_name := item["moduleName"]) and module_name == item["mainModule"]: + module_name = item.get("moduleName", "") + if module_name and module_name == item.get("mainModule", ""): score += 1 item["score"] = score - abilities.sort(key=lambda x: (not x["isLauncherAbility"], -x["score"])) - logger.debug(f"main ability: {abilities[0]}") + abilities.sort(key=lambda x: (not x.get("isLauncherAbility", False), -x.get("score", 0))) + logger.debug(f"主 ability: {abilities[0]}") return abilities[0] @cached_property def toast_watcher(self): - + """ + 获取 Toast 监视器 + + Returns: + _Watcher: Toast 监视器实例 + """ obj = self class _Watcher: + """Toast 监视器内部类""" + def start(self) -> bool: + """ + 开始监视 Toast + + Returns: + bool: 成功返回 True + """ api = "Driver.uiEventObserverOnce" resp: HypiumResponse = obj._invoke(api, args=["toastShow"]) return resp.result - def get_toast(self, timeout: int = 3) -> str: + def get_toast(self, timeout: int = 3) -> Optional[str]: + """ + 获取 Toast 内容 + + Args: + timeout: 超时时间(秒),默认为 3 + + Returns: + Optional[str]: Toast 内容,未捕获到返回 None + """ api = "Driver.getRecentUiEvent" resp: HypiumResponse = obj._invoke(api, args=[timeout]) if resp.result: @@ -247,31 +372,72 @@ def get_toast(self, timeout: int = 3) -> str: @delay def go_back(self): + """按返回键""" self.hdc.send_key(KeyCode.BACK) @delay def go_home(self): + """按主页键""" self.hdc.send_key(KeyCode.HOME) + @delay + def go_recent(self): + """打开最近任务""" + self.press_keys(KeyCode.META_LEFT, KeyCode.TAB) + @delay def press_key(self, key_code: Union[KeyCode, int]): + """ + 按下单个按键 + + Args: + key_code: 按键代码,可以是 KeyCode 枚举或整数 + """ self.hdc.send_key(key_code) + @delay + def press_keys(self, key_code1: Union[KeyCode, int], key_code2: Union[KeyCode, int]): + """ + 按下组合键 + + Args: + key_code1: 第一个按键代码 + key_code2: 第二个按键代码 + """ + code1 = key_code1.value if isinstance(key_code1, KeyCode) else key_code1 + code2 = key_code2.value if isinstance(key_code2, KeyCode) else key_code2 + + api = "Driver.triggerCombineKeys" + self._invoke(api, args=[code1, code2]) + def screen_on(self): + """唤醒屏幕""" self.hdc.wakeup() def screen_off(self): + """关闭屏幕""" self.hdc.wakeup() self.press_key(KeyCode.POWER) @delay def unlock(self): + """ + 解锁屏幕 + + 先唤醒屏幕,然后从屏幕底部向上滑动 + """ self.screen_on() w, h = self.display_size self.swipe(0.5 * w, 0.8 * h, 0.5 * w, 0.2 * h, speed=6000) @cached_property def display_size(self) -> Tuple[int, int]: + """ + 获取屏幕尺寸 + + Returns: + Tuple[int, int]: 屏幕宽度和高度 + """ api = "Driver.getDisplaySize" resp: HypiumResponse = self._invoke(api) w, h = resp.result.get("x"), resp.result.get("y") @@ -279,16 +445,22 @@ def display_size(self) -> Tuple[int, int]: @cached_property def display_rotation(self) -> DisplayRotation: + """ + 获取屏幕旋转状态 + + Returns: + DisplayRotation: 屏幕旋转状态枚举值 + """ api = "Driver.getDisplayRotation" value = self._invoke(api).result return DisplayRotation.from_value(value) def set_display_rotation(self, rotation: DisplayRotation): """ - Sets the display rotation to the specified orientation. - + 设置屏幕旋转状态 + Args: - rotation (DisplayRotation): display rotation. + rotation: 屏幕旋转状态枚举值 """ api = "Driver.setDisplayRotation" self._invoke(api, args=[rotation.value]) @@ -296,10 +468,10 @@ def set_display_rotation(self, rotation: DisplayRotation): @cached_property def device_info(self) -> DeviceInfo: """ - Get detailed information about the device. - + 获取设备详细信息 + Returns: - DeviceInfo: An object containing various properties of the device. + DeviceInfo: 包含设备各种属性的对象 """ hdc = self.hdc return DeviceInfo( @@ -315,78 +487,107 @@ def device_info(self) -> DeviceInfo: @delay def open_url(self, url: str, system_browser: bool = True): + """ + 打开 URL + + Args: + url: 要打开的 URL + system_browser: 是否使用系统浏览器,默认为 True + """ if system_browser: - # Use the system browser + # 使用系统浏览器 self.hdc.shell(f"aa start -A ohos.want.action.viewData -e entity.system.browsable -U {url}") else: - # Default method + # 默认方法 self.hdc.shell(f"aa start -U {url}") def pull_file(self, rpath: str, lpath: str): """ - Pull a file from the device to the local machine. - + 从设备拉取文件到本地 + Args: - rpath (str): The remote path of the file on the device. - lpath (str): The local path where the file should be saved. + rpath: 设备上的文件路径 + lpath: 本地保存路径 """ self.hdc.recv_file(rpath, lpath) def push_file(self, lpath: str, rpath: str): """ - Push a file from the local machine to the device. - + 推送本地文件到设备 + Args: - lpath (str): The local path of the file. - rpath (str): The remote path where the file should be saved on the device. + lpath: 本地文件路径 + rpath: 设备上的保存路径 """ self.hdc.send_file(lpath, rpath) def screenshot(self, path: str) -> str: """ - Take a screenshot of the device display. - + 截取设备屏幕 + Args: - path (str): The local path to save the screenshot. - + path: 本地保存路径 + Returns: - str: The path where the screenshot is saved. + str: 截图保存路径 """ _uuid = uuid.uuid4().hex _tmp_path = f"/data/local/tmp/_tmp_{_uuid}.jpeg" self.shell(f"snapshot_display -f {_tmp_path}") self.pull_file(_tmp_path, path) - self.shell(f"rm -rf {_tmp_path}") # remove local path + self.shell(f"rm -rf {_tmp_path}") # 删除临时文件 return path def shell(self, cmd) -> CommandResult: + """ + 执行 Shell 命令 + + Args: + cmd: 要执行的命令 + + Returns: + CommandResult: 命令执行结果 + """ return self.hdc.shell(cmd) def _to_abs_pos(self, x: Union[int, float], y: Union[int, float]) -> Point: """ - Convert percentages to absolute screen coordinates. - + 将百分比坐标转换为绝对屏幕坐标 + Args: - x (Union[int, float]): X coordinate as a percentage or absolute value. - y (Union[int, float]): Y coordinate as a percentage or absolute value. - + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + Returns: - Point: A Point object with absolute screen coordinates. - """ - assert x >= 0 - assert y >= 0 - - w, h = self.display_size - - if x < 1: - x = int(w * x) - if y < 1: - y = int(h * y) + Point: 包含绝对屏幕坐标的 Point 对象 + + Raises: + AssertionError: 坐标为负数时抛出 + """ + assert x >= 0, "X 坐标不能为负数" + assert y >= 0, "Y 坐标不能为负数" + + # 只有在需要时才获取显示尺寸 + if x < 1 or y < 1: + w, h = self.display_size + + if x < 1: + x = w * x + if y < 1: + y = h * y + + # 只进行一次整数转换 return Point(int(x), int(y)) @delay def click(self, x: Union[int, float], y: Union[int, float]): - + """ + 点击屏幕 + + Args: + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + """ # self.hdc.tap(point.x, point.y) point = self._to_abs_pos(x, y) api = "Driver.click" @@ -394,35 +595,49 @@ def click(self, x: Union[int, float], y: Union[int, float]): @delay def double_click(self, x: Union[int, float], y: Union[int, float]): + """ + 双击屏幕 + + Args: + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + """ point = self._to_abs_pos(x, y) api = "Driver.doubleClick" self._invoke(api, args=[point.x, point.y]) @delay def long_click(self, x: Union[int, float], y: Union[int, float]): + """ + 长按屏幕 + + Args: + x: X 坐标,可以是百分比(0-1)或绝对值 + y: Y 坐标,可以是百分比(0-1)或绝对值 + """ point = self._to_abs_pos(x, y) api = "Driver.longClick" self._invoke(api, args=[point.x, point.y]) @delay - def swipe(self, x1, y1, x2, y2, speed=2000): + def swipe(self, x1: Union[int, float], y1: Union[int, float], + x2: Union[int, float], y2: Union[int, float], speed: int = 2000): """ - Perform a swipe action on the device screen. - + 在屏幕上滑动 + Args: - x1 (float): The start X coordinate as a percentage or absolute value. - y1 (float): The start Y coordinate as a percentage or absolute value. - x2 (float): The end X coordinate as a percentage or absolute value. - y2 (float): The end Y coordinate as a percentage or absolute value. - speed (int, optional): The swipe speed in pixels per second. Default is 2000. Range: 200-40000, - If not within the range, set to default value of 2000. + x1: 起始 X 坐标,可以是百分比(0-1)或绝对值 + y1: 起始 Y 坐标,可以是百分比(0-1)或绝对值 + x2: 结束 X 坐标,可以是百分比(0-1)或绝对值 + y2: 结束 Y 坐标,可以是百分比(0-1)或绝对值 + speed: 滑动速度(像素/秒),默认为 2000,范围:200-40000 + 如果超出范围,将设为默认值 2000 """ - point1 = self._to_abs_pos(x1, y1) point2 = self._to_abs_pos(x2, y2) if speed < 200 or speed > 40000: - logger.warning("`speed` is not in the range[200-40000], Set to default value of 2000.") + logger.warning("`speed` 不在范围 [200-40000] 内,设置为默认值 2000") speed = 2000 api = "Driver.swipe" @@ -431,8 +646,14 @@ def swipe(self, x1, y1, x2, y2, speed=2000): @cached_property def swipe_ext(self): """ + 获取扩展滑动功能 + + 用法示例: d.swipe_ext("up") d.swipe_ext("up", box=(0.2, 0.2, 0.8, 0.8)) + + Returns: + SwipeExt: 扩展滑动功能实例 """ from ._swipe import SwipeExt return SwipeExt(self) @@ -440,41 +661,58 @@ def swipe_ext(self): @delay def input_text(self, text: str): """ - Inputs text into the currently focused input field. - - Note: The input field must have focus before calling this method. - + 在当前焦点输入框中输入文本 + + 注意:调用此方法前,输入框必须已获得焦点 + Args: - text (str): input value + text: 要输入的文本 + + Returns: + HypiumResponse: API 调用响应 """ return self._invoke("Driver.inputText", args=[{"x": 1, "y": 1}, text]) - def dump_hierarchy(self) -> Dict: + def dump_hierarchy(self) -> str: """ - Dump the UI hierarchy of the device screen. - + 导出界面层次结构 + Returns: - Dict: The dumped UI hierarchy as a dictionary. + str: 界面层次结构的 JSON 字符串 """ - # return self._client.invoke_captures("captureLayout").result - return self.hdc.dump_hierarchy() + result = self._client.invoke_captures("captureLayout").result + if isinstance(result, str): + return result + return json.dumps(result, ensure_ascii=False) @cached_property def gesture(self): + """ + 获取手势操作功能 + + Returns: + _Gesture: 手势操作功能实例 + """ from ._gesture import _Gesture return _Gesture(self) @cached_property def screenrecord(self): + """ + 获取屏幕录制功能 + + Returns: + RecordClient: 屏幕录制功能实例 + """ from ._screenrecord import RecordClient return RecordClient(self.serial, self) - def _invalidate_cache(self, attribute_name): + def _invalidate_cache(self, attribute_name: str): """ - Invalidate the cached property. - + 使缓存的属性失效 + Args: - attribute_name (str): The name of the attribute to invalidate. + attribute_name: 要使失效的属性名 """ if attribute_name in self.__dict__: del self.__dict__[attribute_name] @@ -482,7 +720,13 @@ def _invalidate_cache(self, attribute_name): @cached_property def xpath(self): """ + 获取 XPath 查询功能 + + 用法示例: d.xpath("//*[@text='Hello']").click() + + Returns: + _XPath: XPath 查询功能实例 """ from ._xpath import _XPath return _XPath(self) diff --git a/hmdriver2/hdc.py b/hmdriver2/hdc.py index 000133f..8d95511 100644 --- a/hmdriver2/hdc.py +++ b/hmdriver2/hdc.py @@ -1,20 +1,38 @@ # -*- coding: utf-8 -*- -import tempfile + import json -import uuid -import shlex -import re import os +import re +import shlex import subprocess -from typing import Union, List, Dict, Tuple +import tempfile +import uuid +from typing import Union, List, Dict, Tuple, Optional, Any from . import logger -from .utils import FreePort -from .proto import CommandResult, KeyCode from .exception import HdcError, DeviceNotFoundError +from .proto import CommandResult, KeyCode +from .utils import FreePort + +# HDC 命令相关常量 +HDC_CMD = "hdc" +HDC_SERVER_HOST_ENV = "HDC_SERVER_HOST" +HDC_SERVER_PORT_ENV = "HDC_SERVER_PORT" + +# 键码相关常量 +MAX_KEY_CODE = 3200 def _execute_command(cmdargs: Union[str, List[str]]) -> CommandResult: + """ + 执行命令并返回结果 + + Args: + cmdargs: 要执行的命令,可以是字符串或命令参数列表 + + Returns: + CommandResult: 命令执行结果对象 + """ if isinstance(cmdargs, (list, tuple)): cmdline: str = ' '.join(list(map(shlex.quote, cmdargs))) elif isinstance(cmdargs, str): @@ -29,6 +47,7 @@ def _execute_command(cmdargs: Union[str, List[str]]) -> CommandResult: error = error.decode('utf-8') exit_code = process.returncode + # 检查输出中是否包含错误信息 if 'error:' in output.lower() or '[fail]' in output.lower(): return CommandResult("", output, -1) @@ -39,131 +58,298 @@ def _execute_command(cmdargs: Union[str, List[str]]) -> CommandResult: def _build_hdc_prefix() -> str: """ - Construct the hdc command prefix based on environment variables. + 根据环境变量构建 HDC 命令前缀 + + 如果设置了 HDC_SERVER_HOST 和 HDC_SERVER_PORT 环境变量, + 则使用这些值构建带有服务器连接信息的命令前缀。 + + Returns: + str: HDC 命令前缀 """ - host = os.getenv("HDC_SERVER_HOST") - port = os.getenv("HDC_SERVER_PORT") + host = os.getenv(HDC_SERVER_HOST_ENV) + port = os.getenv(HDC_SERVER_PORT_ENV) if host and port: - logger.debug(f"HDC_SERVER_HOST: {host}, HDC_SERVER_PORT: {port}") - return f"hdc -s {host}:{port}" - return "hdc" + logger.debug(f"{HDC_SERVER_HOST_ENV}: {host}, {HDC_SERVER_PORT_ENV}: {port}") + return f"{HDC_CMD} -s {host}:{port}" + return HDC_CMD def list_devices() -> List[str]: + """ + 列出所有已连接的设备 + + Returns: + List[str]: 设备序列号列表 + + Raises: + HdcError: HDC 命令执行失败时抛出 + """ devices = [] hdc_prefix = _build_hdc_prefix() result = _execute_command(f"{hdc_prefix} list targets") + if result.exit_code == 0 and result.output: lines = result.output.strip().split('\n') for line in lines: - if line.__contains__('Empty'): + if 'Empty' in line: continue devices.append(line.strip()) if result.exit_code != 0: - raise HdcError("HDC error", result.error) + raise HdcError("HDC 错误", result.error) return devices class HdcWrapper: + """ + HDC 命令包装类 + + 提供对 HDC 命令的封装,简化与 Harmony OS 设备的交互。 + """ + def __init__(self, serial: str) -> None: + """ + 初始化 HDC 包装器 + + Args: + serial: 设备序列号 + + Raises: + DeviceNotFoundError: 设备未找到时抛出 + """ self.serial = serial self.hdc_prefix = _build_hdc_prefix() if not self.is_online(): - raise DeviceNotFoundError(f"Device [{self.serial}] not found") + raise DeviceNotFoundError(f"未找到设备 [{self.serial}]") - def is_online(self): + def is_online(self) -> bool: + """ + 检查设备是否在线 + + Returns: + bool: 设备在线返回 True,否则返回 False + """ _serials = list_devices() - return True if self.serial in _serials else False + return self.serial in _serials def forward_port(self, rport: int) -> int: + """ + 设置端口转发 + + 将设备上的端口转发到本地端口 + + Args: + rport: 设备端口 + + Returns: + int: 本地端口 + + Raises: + HdcError: 端口转发失败时抛出 + """ lport: int = FreePort().get() result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport tcp:{lport} tcp:{rport}") if result.exit_code != 0: - raise HdcError("HDC forward port error", result.error) + raise HdcError("HDC 端口转发错误", result.error) return lport def rm_forward(self, lport: int, rport: int) -> int: + """ + 移除端口转发 + + Args: + lport: 本地端口 + rport: 设备端口 + + Returns: + int: 本地端口 + + Raises: + HdcError: 移除端口转发失败时抛出 + """ result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport rm tcp:{lport} tcp:{rport}") if result.exit_code != 0: - raise HdcError("HDC rm forward error", result.error) + raise HdcError("HDC 移除端口转发错误", result.error) return lport - def list_fport(self) -> List: + def list_fport(self) -> List[str]: """ - eg.['tcp:10001 tcp:8012', 'tcp:10255 tcp:8012'] + 列出所有端口转发 + + Returns: + List[str]: 端口转发列表,例如 ['tcp:10001 tcp:8012', 'tcp:10255 tcp:8012'] + + Raises: + HdcError: 列出端口转发失败时抛出 """ result = _execute_command(f"{self.hdc_prefix} -t {self.serial} fport ls") if result.exit_code != 0: - raise HdcError("HDC forward list error", result.error) + raise HdcError("HDC 列出端口转发错误", result.error) pattern = re.compile(r"tcp:\d+ tcp:\d+") return pattern.findall(result.output) - def send_file(self, lpath: str, rpath: str): + def send_file(self, lpath: str, rpath: str) -> CommandResult: + """ + 发送文件到设备 + + Args: + lpath: 本地文件路径 + rpath: 设备上的目标路径 + + Returns: + CommandResult: 命令执行结果 + + Raises: + HdcError: 发送文件失败时抛出 + """ result = _execute_command(f"{self.hdc_prefix} -t {self.serial} file send {lpath} {rpath}") if result.exit_code != 0: - raise HdcError("HDC send file error", result.error) + raise HdcError("HDC 发送文件错误", result.error) return result - def recv_file(self, rpath: str, lpath: str): + def recv_file(self, rpath: str, lpath: str) -> CommandResult: + """ + 从设备接收文件 + + Args: + rpath: 设备上的文件路径 + lpath: 本地保存路径 + + Returns: + CommandResult: 命令执行结果 + + Raises: + HdcError: 接收文件失败时抛出 + """ result = _execute_command(f"{self.hdc_prefix} -t {self.serial} file recv {rpath} {lpath}") if result.exit_code != 0: - raise HdcError("HDC receive file error", result.error) + raise HdcError("HDC 接收文件错误", result.error) return result - def shell(self, cmd: str, error_raise=True) -> CommandResult: - # ensure the command is wrapped in double quotes + def shell(self, cmd: str, error_raise: bool = True) -> CommandResult: + """ + 在设备上执行 Shell 命令 + + Args: + cmd: 要执行的 Shell 命令 + error_raise: 命令失败时是否抛出异常,默认为 True + + Returns: + CommandResult: 命令执行结果 + + Raises: + HdcError: 命令执行失败且 error_raise 为 True 时抛出 + """ + # 确保命令用双引号包裹 if cmd[0] != '\"': cmd = "\"" + cmd if cmd[-1] != '\"': cmd += '\"' result = _execute_command(f"{self.hdc_prefix} -t {self.serial} shell {cmd}") if result.exit_code != 0 and error_raise: - raise HdcError("HDC shell error", f"{cmd}\n{result.output}\n{result.error}") + raise HdcError("HDC Shell 命令错误", f"{cmd}\n{result.output}\n{result.error}") return result - def uninstall(self, bundlename: str): + def uninstall(self, bundlename: str) -> CommandResult: + """ + 卸载应用 + + Args: + bundlename: 应用包名 + + Returns: + CommandResult: 命令执行结果 + + Raises: + HdcError: 卸载应用失败时抛出 + """ result = _execute_command(f"{self.hdc_prefix} -t {self.serial} uninstall {bundlename}") if result.exit_code != 0: - raise HdcError("HDC uninstall error", result.output) + raise HdcError("HDC 卸载应用错误", result.output) return result - def install(self, apkpath: str): - # Ensure the path is properly quoted for Windows + def install(self, apkpath: str) -> CommandResult: + """ + 安装应用 + + Args: + apkpath: 应用安装包路径 + + Returns: + CommandResult: 命令执行结果 + + Raises: + HdcError: 安装应用失败时抛出 + """ + # 确保路径正确引用,特别是在 Windows 系统上 quoted_path = f'"{apkpath}"' result = _execute_command(f"{self.hdc_prefix} -t {self.serial} install {quoted_path}") if result.exit_code != 0: - raise HdcError("HDC install error", result.error) + raise HdcError("HDC 安装应用错误", result.error) return result def list_apps(self) -> List[str]: + """ + 列出设备上的所有应用 + + Returns: + List[str]: 应用列表 + """ result = self.shell("bm dump -a") raw = result.output.split('\n') - return [item.strip() for item in raw] + return [item.strip() for item in raw if item.strip()] def has_app(self, package_name: str) -> bool: + """ + 检查设备上是否安装了指定应用 + + Args: + package_name: 应用包名 + + Returns: + bool: 应用存在返回 True,否则返回 False + """ data = self.shell("bm dump -a").output - return True if package_name in data else False + return package_name in data - def start_app(self, package_name: str, ability_name: str): + def start_app(self, package_name: str, ability_name: str) -> CommandResult: + """ + 启动应用 + + Args: + package_name: 应用包名 + ability_name: Ability 名称 + + Returns: + CommandResult: 命令执行结果 + """ return self.shell(f"aa start -a {ability_name} -b {package_name}") - def stop_app(self, package_name: str): + def stop_app(self, package_name: str) -> CommandResult: + """ + 停止应用 + + Args: + package_name: 应用包名 + + Returns: + CommandResult: 命令执行结果 + """ return self.shell(f"aa force-stop {package_name}") - def current_app(self) -> Tuple[str, str]: + def current_app(self) -> Tuple[Optional[str], Optional[str]]: """ - Get the current foreground application information. - + 获取当前前台应用信息 + Returns: - Tuple[str, str]: A tuple contain the package_name andpage_name of the foreground application. - If no foreground application is found, returns (None, None). + Tuple[Optional[str], Optional[str]]: 包含应用包名和页面名称的元组 + 如果未找到前台应用,返回 (None, None) """ - - def __extract_info(output: str): + def __extract_info(output: str) -> List[Tuple[str, str]]: + """提取应用信息""" results = [] mission_blocks = re.findall(r'Mission ID #[\s\S]*?isKeepAlive: false\s*}', output) @@ -186,12 +372,17 @@ def __extract_info(output: str): results = __extract_info(output) return results[0] if results else (None, None) - def wakeup(self): + def wakeup(self) -> None: + """唤醒设备""" self.shell("power-shell wakeup") - def screen_state(self) -> str: + def screen_state(self) -> Optional[str]: """ - ["INACTIVE", "SLEEP, AWAKE"] + 获取屏幕状态 + + Returns: + Optional[str]: 屏幕状态,可能的值包括 "INACTIVE"、"SLEEP"、"AWAKE" 等 + 如果无法获取状态,返回 None """ data = self.shell("hidumper -s PowerManagerService -a -s").output pattern = r"Current State:\s*(\w+)" @@ -199,39 +390,96 @@ def screen_state(self) -> str: return match.group(1) if match else None - def wlan_ip(self) -> Union[str, None]: + def wlan_ip(self) -> Optional[str]: + """ + 获取设备的 WLAN IP 地址 + + Returns: + Optional[str]: WLAN IP 地址,如果未找到则返回 None + """ data = self.shell("ifconfig").output matches = re.findall(r'inet addr:(?!127)(\d+\.\d+\.\d+\.\d+)', data) return matches[0] if matches else None - def __split_text(self, text: str) -> str: + def __split_text(self, text: Optional[str]) -> Optional[str]: + """ + 从文本中提取第一行并去除前后空白 + + Args: + text: 输入文本 + + Returns: + Optional[str]: 处理后的文本,如果输入为 None 则返回 None + """ return text.split("\n")[0].strip() if text else None - def sdk_version(self) -> str: + def sdk_version(self) -> Optional[str]: + """ + 获取设备 SDK 版本 + + Returns: + Optional[str]: SDK 版本 + """ data = self.shell("param get const.ohos.apiversion").output return self.__split_text(data) - def sys_version(self) -> str: + def sys_version(self) -> Optional[str]: + """ + 获取设备系统版本 + + Returns: + Optional[str]: 系统版本 + """ data = self.shell("param get const.product.software.version").output return self.__split_text(data) - def model(self) -> str: + def model(self) -> Optional[str]: + """ + 获取设备型号 + + Returns: + Optional[str]: 设备型号 + """ data = self.shell("param get const.product.model").output return self.__split_text(data) - def brand(self) -> str: + def brand(self) -> Optional[str]: + """ + 获取设备品牌 + + Returns: + Optional[str]: 设备品牌 + """ data = self.shell("param get const.product.brand").output return self.__split_text(data) - def product_name(self) -> str: + def product_name(self) -> Optional[str]: + """ + 获取设备产品名称 + + Returns: + Optional[str]: 产品名称 + """ data = self.shell("param get const.product.name").output return self.__split_text(data) - def cpu_abi(self) -> str: + def cpu_abi(self) -> Optional[str]: + """ + 获取设备 CPU ABI + + Returns: + Optional[str]: CPU ABI + """ data = self.shell("param get const.product.cpu.abilist").output return self.__split_text(data) def display_size(self) -> Tuple[int, int]: + """ + 获取设备屏幕尺寸 + + Returns: + Tuple[int, int]: 屏幕宽度和高度,如果无法获取则返回 (0, 0) + """ data = self.shell("hidumper -s RenderService -a screen").output match = re.search(r'activeMode:\s*(\d+)x(\d+),\s*refreshrate=\d+', data) @@ -242,33 +490,81 @@ def display_size(self) -> Tuple[int, int]: return (0, 0) def send_key(self, key_code: Union[KeyCode, int]) -> None: + """ + 发送按键事件 + + Args: + key_code: 按键代码,可以是 KeyCode 枚举或整数 + + Raises: + HdcError: 按键代码无效时抛出 + """ if isinstance(key_code, KeyCode): key_code = key_code.value - MAX = 3200 - if key_code > MAX: - raise HdcError("Invalid HDC keycode") + if key_code > MAX_KEY_CODE: + raise HdcError("无效的 HDC 按键代码") self.shell(f"uitest uiInput keyEvent {key_code}") def tap(self, x: int, y: int) -> None: + """ + 点击屏幕 + + Args: + x: X 坐标 + y: Y 坐标 + """ self.shell(f"uitest uiInput click {x} {y}") - def swipe(self, x1, y1, x2, y2, speed=1000): + def swipe(self, x1: int, y1: int, x2: int, y2: int, speed: int = 1000) -> None: + """ + 在屏幕上滑动 + + Args: + x1: 起始 X 坐标 + y1: 起始 Y 坐标 + x2: 结束 X 坐标 + y2: 结束 Y 坐标 + speed: 滑动速度,默认为 1000 + """ self.shell(f"uitest uiInput swipe {x1} {y1} {x2} {y2} {speed}") - def input_text(self, x: int, y: int, text: str): + def input_text(self, x: int, y: int, text: str) -> None: + """ + 在指定位置输入文本 + + Args: + x: X 坐标 + y: Y 坐标 + text: 要输入的文本 + """ self.shell(f"uitest uiInput inputText {x} {y} {text}") def screenshot(self, path: str) -> str: + """ + 截取屏幕 + + Args: + path: 本地保存路径 + + Returns: + str: 截图保存路径 + """ _uuid = uuid.uuid4().hex _tmp_path = f"/data/local/tmp/_tmp_{_uuid}.jpeg" self.shell(f"snapshot_display -f {_tmp_path}") self.recv_file(_tmp_path, path) - self.shell(f"rm -rf {_tmp_path}") # remove local path + self.shell(f"rm -rf {_tmp_path}") # 删除临时文件 return path - def dump_hierarchy(self) -> Dict: + def dump_hierarchy(self) -> Dict[str, Any]: + """ + 导出界面层次结构 + + Returns: + Dict[str, Any]: 界面层次结构数据,如果解析失败则返回空字典 + """ _tmp_path = f"/data/local/tmp/{self.serial}_tmp.json" self.shell(f"uitest dumpLayout -p {_tmp_path}") @@ -280,7 +576,7 @@ def dump_hierarchy(self) -> Dict: with open(path, 'r', encoding='utf8') as file: data = json.load(file) except Exception as e: - logger.error(f"Error loading JSON file: {e}") + logger.error(f"加载 JSON 文件时出错: {e}") data = {} return data diff --git a/hmdriver2/proto.py b/hmdriver2/proto.py index e987cc9..9f108f5 100644 --- a/hmdriver2/proto.py +++ b/hmdriver2/proto.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import json -from enum import Enum -from typing import Union, List from dataclasses import dataclass, asdict +from enum import Enum +from typing import Union, List, Dict @dataclass @@ -64,8 +64,8 @@ class HypiumResponse: {"result":null,"exception":"Can not connect to AAMS, RET_ERR_CONNECTION_EXIST"} {"exception":{"code":401,"message":"(PreProcessing: APiCallInfoChecker)Illegal argument count"}} """ - result: Union[List, bool, str, None] = None - exception: Union[List, bool, str, None] = None + result: Union[List, Dict, bool, str, None] = None + exception: Union[List, Dict, bool, str, None] = None @dataclass @@ -160,51 +160,51 @@ class KeyCode(Enum): MUTE = 23 # 话筒静音键 BRIGHTNESS_UP = 40 # 亮度调节按键调亮 BRIGHTNESS_DOWN = 41 # 亮度调节按键调暗 - NUM_0 = 2000 # 按键’0’ - NUM_1 = 2001 # 按键’1’ - NUM_2 = 2002 # 按键’2’ - NUM_3 = 2003 # 按键’3’ - NUM_4 = 2004 # 按键’4’ - NUM_5 = 2005 # 按键’5’ - NUM_6 = 2006 # 按键’6’ - NUM_7 = 2007 # 按键’7’ - NUM_8 = 2008 # 按键’8’ - NUM_9 = 2009 # 按键’9’ - STAR = 2010 # 按键’*’ - POUND = 2011 # 按键’#’ + NUM_0 = 2000 # 按键'0' + NUM_1 = 2001 # 按键'1' + NUM_2 = 2002 # 按键'2' + NUM_3 = 2003 # 按键'3' + NUM_4 = 2004 # 按键'4' + NUM_5 = 2005 # 按键'5' + NUM_6 = 2006 # 按键'6' + NUM_7 = 2007 # 按键'7' + NUM_8 = 2008 # 按键'8' + NUM_9 = 2009 # 按键'9' + STAR = 2010 # 按键'*' + POUND = 2011 # 按键'#' DPAD_UP = 2012 # 导航键向上 DPAD_DOWN = 2013 # 导航键向下 DPAD_LEFT = 2014 # 导航键向左 DPAD_RIGHT = 2015 # 导航键向右 DPAD_CENTER = 2016 # 导航键确定键 - A = 2017 # 按键’A’ - B = 2018 # 按键’B’ - C = 2019 # 按键’C’ - D = 2020 # 按键’D’ - E = 2021 # 按键’E’ - F = 2022 # 按键’F’ - G = 2023 # 按键’G’ - H = 2024 # 按键’H’ - I = 2025 # 按键’I’ - J = 2026 # 按键’J’ - K = 2027 # 按键’K’ - L = 2028 # 按键’L’ - M = 2029 # 按键’M’ - N = 2030 # 按键’N’ - O = 2031 # 按键’O’ - P = 2032 # 按键’P’ - Q = 2033 # 按键’Q’ - R = 2034 # 按键’R’ - S = 2035 # 按键’S’ - T = 2036 # 按键’T’ - U = 2037 # 按键’U’ - V = 2038 # 按键’V’ - W = 2039 # 按键’W’ - X = 2040 # 按键’X’ - Y = 2041 # 按键’Y’ - Z = 2042 # 按键’Z’ - COMMA = 2043 # 按键’,’ - PERIOD = 2044 # 按键’.’ + A = 2017 # 按键'A' + B = 2018 # 按键'B' + C = 2019 # 按键'C' + D = 2020 # 按键'D' + E = 2021 # 按键'E' + F = 2022 # 按键'F' + G = 2023 # 按键'G' + H = 2024 # 按键'H' + I = 2025 # 按键'I' + J = 2026 # 按键'J' + K = 2027 # 按键'K' + L = 2028 # 按键'L' + M = 2029 # 按键'M' + N = 2030 # 按键'N' + O = 2031 # 按键'O' + P = 2032 # 按键'P' + Q = 2033 # 按键'Q' + R = 2034 # 按键'R' + S = 2035 # 按键'S' + T = 2036 # 按键'T' + U = 2037 # 按键'U' + V = 2038 # 按键'V' + W = 2039 # 按键'W' + X = 2040 # 按键'X' + Y = 2041 # 按键'Y' + Z = 2042 # 按键'Z' + COMMA = 2043 # 按键',' + PERIOD = 2044 # 按键'.' ALT_LEFT = 2045 # 左Alt键 ALT_RIGHT = 2046 # 右Alt键 SHIFT_LEFT = 2047 # 左Shift键 @@ -216,17 +216,17 @@ class KeyCode(Enum): ENVELOPE = 2053 # 电子邮件功能键,此键用于启动电子邮件应用程序。 ENTER = 2054 # 回车键 DEL = 2055 # 退格键 - GRAVE = 2056 # 按键’`’ - MINUS = 2057 # 按键’-’ - EQUALS = 2058 # 按键’=’ - LEFT_BRACKET = 2059 # 按键’[’ - RIGHT_BRACKET = 2060 # 按键’]’ - BACKSLASH = 2061 # 按键’\’ - SEMICOLON = 2062 # 按键’;’ - APOSTROPHE = 2063 # 按键’‘’(单引号) - SLASH = 2064 # 按键’/’ - AT = 2065 # 按键’@’ - PLUS = 2066 # 按键’+’ + GRAVE = 2056 # 按键'`' + MINUS = 2057 # 按键'-' + EQUALS = 2058 # 按键'=' + LEFT_BRACKET = 2059 # 按键'[' + RIGHT_BRACKET = 2060 # 按键']' + BACKSLASH = 2061 # 按键'\' + SEMICOLON = 2062 # 按键';' + APOSTROPHE = 2063 # 按键''' + SLASH = 2064 # 按键'/' + AT = 2065 # 按键'@' + PLUS = 2066 # 按键'+' MENU = 2067 # 菜单键 PAGE_UP = 2068 # 向上翻页键 PAGE_DOWN = 2069 # 向下翻页键 @@ -250,39 +250,39 @@ class KeyCode(Enum): MEDIA_CLOSE = 2087 # 多媒体键关闭 MEDIA_EJECT = 2088 # 多媒体键弹出 MEDIA_RECORD = 2089 # 多媒体键录音 - F1 = 2090 # 按键’F1’ - F2 = 2091 # 按键’F2’ - F3 = 2092 # 按键’F3’ - F4 = 2093 # 按键’F4’ - F5 = 2094 # 按键’F5’ - F6 = 2095 # 按键’F6’ - F7 = 2096 # 按键’F7’ - F8 = 2097 # 按键’F8’ - F9 = 2098 # 按键’F9’ - F10 = 2099 # 按键’F10’ - F11 = 2100 # 按键’F11’ - F12 = 2101 # 按键’F12’ + F1 = 2090 # 按键'F1' + F2 = 2091 # 按键'F2' + F3 = 2092 # 按键'F3' + F4 = 2093 # 按键'F4' + F5 = 2094 # 按键'F5' + F6 = 2095 # 按键'F6' + F7 = 2096 # 按键'F7' + F8 = 2097 # 按键'F8' + F9 = 2098 # 按键'F9' + F10 = 2099 # 按键'F10' + F11 = 2100 # 按键'F11' + F12 = 2101 # 按键'F12' NUM_LOCK = 2102 # 小键盘锁 - NUMPAD_0 = 2103 # 小键盘按键’0’ - NUMPAD_1 = 2104 # 小键盘按键’1’ - NUMPAD_2 = 2105 # 小键盘按键’2’ - NUMPAD_3 = 2106 # 小键盘按键’3’ - NUMPAD_4 = 2107 # 小键盘按键’4’ - NUMPAD_5 = 2108 # 小键盘按键’5’ - NUMPAD_6 = 2109 # 小键盘按键’6’ - NUMPAD_7 = 2110 # 小键盘按键’7’ - NUMPAD_8 = 2111 # 小键盘按键’8’ - NUMPAD_9 = 2112 # 小键盘按键’9’ - NUMPAD_DIVIDE = 2113 # 小键盘按键’/’ - NUMPAD_MULTIPLY = 2114 # 小键盘按键’*’ - NUMPAD_SUBTRACT = 2115 # 小键盘按键’-’ - NUMPAD_ADD = 2116 # 小键盘按键’+’ - NUMPAD_DOT = 2117 # 小键盘按键’.’ - NUMPAD_COMMA = 2118 # 小键盘按键’,’ + NUMPAD_0 = 2103 # 小键盘按键'0' + NUMPAD_1 = 2104 # 小键盘按键'1' + NUMPAD_2 = 2105 # 小键盘按键'2' + NUMPAD_3 = 2106 # 小键盘按键'3' + NUMPAD_4 = 2107 # 小键盘按键'4' + NUMPAD_5 = 2108 # 小键盘按键'5' + NUMPAD_6 = 2109 # 小键盘按键'6' + NUMPAD_7 = 2110 # 小键盘按键'7' + NUMPAD_8 = 2111 # 小键盘按键'8' + NUMPAD_9 = 2112 # 小键盘按键'9' + NUMPAD_DIVIDE = 2113 # 小键盘按键'/' + NUMPAD_MULTIPLY = 2114 # 小键盘按键'*' + NUMPAD_SUBTRACT = 2115 # 小键盘按键'-' + NUMPAD_ADD = 2116 # 小键盘按键'+' + NUMPAD_DOT = 2117 # 小键盘按键'.' + NUMPAD_COMMA = 2118 # 小键盘按键',' NUMPAD_ENTER = 2119 # 小键盘按键回车 - NUMPAD_EQUALS = 2120 # 小键盘按键’=’ - NUMPAD_LEFT_PAREN = 2121 # 小键盘按键’(’ - NUMPAD_RIGHT_PAREN = 2122 # 小键盘按键’)’ + NUMPAD_EQUALS = 2120 # 小键盘按键'=' + NUMPAD_LEFT_PAREN = 2121 # 小键盘按键'(' + NUMPAD_RIGHT_PAREN = 2122 # 小键盘按键')' VIRTUAL_MULTITASK = 2210 # 虚拟多任务键 SLEEP = 2600 # 睡眠键 ZENKAKU_HANKAKU = 2601 # 日文全宽/半宽键 @@ -431,18 +431,18 @@ class KeyCode(Enum): EJECTCLOSECD = 2813 # 弹出CD键 ISO = 2814 # ISO键 MOVE = 2815 # 移动键 - F13 = 2816 # 按键’F13’ - F14 = 2817 # 按键’F14’ - F15 = 2818 # 按键’F15’ - F16 = 2819 # 按键’F16’ - F17 = 2820 # 按键’F17’ - F18 = 2821 # 按键’F18’ - F19 = 2822 # 按键’F19’ - F20 = 2823 # 按键’F20’ - F21 = 2824 # 按键’F21’ - F22 = 2825 # 按键’F22’ - F23 = 2826 # 按键’F23’ - F24 = 2827 # 按键’F24’ + F13 = 2816 # 按键'F13' + F14 = 2817 # 按键'F14' + F15 = 2818 # 按键'F15' + F16 = 2819 # 按键'F16' + F17 = 2820 # 按键'F17' + F18 = 2821 # 按键'F18' + F19 = 2822 # 按键'F19' + F20 = 2823 # 按键'F20' + F21 = 2824 # 按键'F21' + F22 = 2825 # 按键'F22' + F23 = 2826 # 按键'F23' + F24 = 2827 # 按键'F24' PROG3 = 2828 # 程序键3 PROG4 = 2829 # 程序键4 DASHBOARD = 2830 # 仪表板 diff --git a/hmdriver2/utils.py b/hmdriver2/utils.py index cf4f75a..05fec21 100644 --- a/hmdriver2/utils.py +++ b/hmdriver2/utils.py @@ -1,60 +1,120 @@ # -*- coding: utf-8 -*- - -import time -import socket import re +import socket +import time from functools import wraps -from typing import Union +from typing import Optional, Callable, Any, TypeVar from .proto import Bounds +# 默认 UI 操作后的延迟时间(秒) +DEFAULT_DELAY_TIME = 0.6 + +# 端口范围 +PORT_RANGE_START = 10000 +PORT_RANGE_END = 20000 -def delay(func): +# 类型变量定义,用于泛型函数 +F = TypeVar('F', bound=Callable[..., Any]) + + +def delay(func: F) -> F: """ - After each UI operation, it is necessary to wait for a while to ensure the stability of the UI, - so as not to affect the next UI operation. + UI 操作后的延迟装饰器 + + 在每次 UI 操作后需要等待一段时间,确保 UI 稳定, + 避免影响下一次 UI 操作。 + + Args: + func: 要装饰的函数 + + Returns: + 装饰后的函数 """ - DELAY_TIME = 0.6 - @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: result = func(*args, **kwargs) - time.sleep(DELAY_TIME) + time.sleep(DEFAULT_DELAY_TIME) return result - return wrapper + + return wrapper # type: ignore class FreePort: - def __init__(self): - self._start = 10000 - self._end = 20000 + """ + 空闲端口管理类 + + 用于获取系统中未被占用的网络端口 + """ + + def __init__(self) -> None: + """初始化端口管理器""" + self._start = PORT_RANGE_START + self._end = PORT_RANGE_END self._now = self._start - 1 def get(self) -> int: - while True: + """ + 获取一个空闲端口 + + Returns: + int: 可用的端口号 + """ + attempts = 0 + max_attempts = self._end - self._start + + while attempts < max_attempts: + attempts += 1 self._now += 1 if self._now > self._end: self._now = self._start + if not self.is_port_in_use(self._now): return self._now + + raise RuntimeError(f"无法找到可用端口,已尝试 {max_attempts} 次") @staticmethod def is_port_in_use(port: int) -> bool: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + """ + 检查端口是否被占用 + + Args: + port: 要检查的端口号 + + Returns: + bool: 端口被占用返回 True,否则返回 False + """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', port)) == 0 + except (socket.error, OSError): + # 如果发生错误,保守地认为端口被占用 + return True -def parse_bounds(bounds: str) -> Union[Bounds, None]: +def parse_bounds(bounds: str) -> Optional[Bounds]: """ - Parse bounds string to Bounds. - bounds is str, like: "[832,1282][1125,1412]" + 解析边界字符串为 Bounds 对象 + + Args: + bounds: 边界字符串,格式如 "[832,1282][1125,1412]" + + Returns: + Optional[Bounds]: 解析成功返回 Bounds 对象,否则返回 None """ + if not bounds: + return None + result = re.match(r"\[(\d+),(\d+)\]\[(\d+),(\d+)\]", bounds) if result: g = result.groups() - return Bounds(int(g[0]), - int(g[1]), - int(g[2]), - int(g[3])) - return None \ No newline at end of file + try: + return Bounds(int(g[0]), + int(g[1]), + int(g[2]), + int(g[3])) + except (ValueError, IndexError): + return None + return None