diff --git a/tests/test_utils.py b/tests/test_utils.py index 4d15c6c8..1cdb19a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,8 @@ # coding: utf-8 # +import time +import threading import pytest from uiautomator2 import utils @@ -23,4 +25,28 @@ def foo(a, b, c=2): assert ret == 242 with pytest.raises(TypeError): - utils.inject_call(foo, 2) \ No newline at end of file + utils.inject_call(foo, 2) + + +def test_method_threadsafe(): + class A: + n = 0 + + @utils.method_threadsafe + def call(self): + v = self.n + time.sleep(.5) + self.n = v + 1 + + a = A() + th1 = threading.Thread(name="th1", target=a.call) + th2 = threading.Thread(name="th2", target=a.call) + th1.start() + th2.start() + th1.join() + th2.join() + + assert 2 == a.n + + + diff --git a/uiautomator2/__init__.py b/uiautomator2/__init__.py index e4a552aa..5aad6fab 100644 --- a/uiautomator2/__init__.py +++ b/uiautomator2/__init__.py @@ -17,7 +17,6 @@ import base64 import contextlib -import functools import hashlib import io import json @@ -31,12 +30,13 @@ import time import warnings import xml.dom.minidom -from collections import namedtuple, defaultdict +from collections import defaultdict, namedtuple from datetime import datetime from typing import List, Optional, Tuple, Union # import progress.bar import adbutils +import logzero import packaging import requests import six @@ -49,6 +49,7 @@ from urllib3.util.retry import Retry from . import xpath +from ._proto import HTTP_TIMEOUT, SCROLL_STEPS, Direction from ._selector import Selector, UiObject from .exceptions import (BaseError, ConnectError, GatewayError, JSONRPCError, NullObjectExceptionError, NullPointerExceptionError, @@ -59,10 +60,9 @@ # from .session import Session # noqa: F401 from .settings import Settings from .swipe import SwipeExt -from .utils import list2cmdline -from .version import __atx_agent_version__, __apk_version__ -from .watcher import Watcher, WatchContext -from ._proto import SCROLL_STEPS, Direction, HTTP_TIMEOUT +from .utils import list2cmdline, method_threadsafe +from .version import __apk_version__, __atx_agent_version__ +from .watcher import WatchContext, Watcher if six.PY2: FileNotFoundError = OSError @@ -70,9 +70,10 @@ DEBUG = False WAIT_FOR_DEVICE_TIMEOUT = int(os.getenv("WAIT_FOR_DEVICE_TIMEOUT", 20)) -# logger = logging.getLogger("uiautomator2") -logger = setup_logger("uiautomator2", level=logging.DEBUG) +log_format = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s [pid:%(process)d] %(message)s' +formatter = logzero.LogFormatter(fmt=log_format) +logger = setup_logger("uiautomator2", level=logging.DEBUG, formatter=formatter) _mswindows = (os.name == "nt") @@ -575,6 +576,7 @@ def _is_alive(self): except (requests.ReadTimeout, EnvironmentError): return False + # @method_threadsafe def reset_uiautomator(self, reason="unknown", depth=0): """ Reset uiautomator diff --git a/uiautomator2/utils.py b/uiautomator2/utils.py index b2f5a8b8..0f7c7533 100644 --- a/uiautomator2/utils.py +++ b/uiautomator2/utils.py @@ -4,6 +4,8 @@ import functools import inspect import shlex +import threading +import typing from typing import Union import six @@ -207,6 +209,17 @@ def _swipe(_from, _to): raise ValueError("Unknown direction:", direction) +def method_threadsafe(fn: typing.Callable): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + if not hasattr(self, "_lock"): + self._lock = threading.Lock() + + with self._lock: + return fn(self, *args, **kwargs) + + return inner + if __name__ == "__main__": for n in (1, 10000, 10000000, 10000000000): print(n, natualsize(n))