diff --git a/cookbook/examplescript.py b/cookbook/examplescript.py
index 543369f7..674bbf7c 100644
--- a/cookbook/examplescript.py
+++ b/cookbook/examplescript.py
@@ -1,50 +1,58 @@
import asyncio
import base64
-from mobileadapt import mobileadapt
-from datetime import datetime
-from PIL import Image
import io
import os
+from datetime import datetime
+
from loguru import logger
-''' From the root directory use the following command to start the script:
+from PIL import Image
+
+from mobileadapt import mobileadapt
+
+""" From the root directory use the following command to start the script:
python example-scripts/examplescript.py
-'''
+"""
+
async def save_screenshot(screenshot_data, filename):
# Open the screenshot data as an image and save it
image = Image.open(io.BytesIO(screenshot_data))
image.save(filename)
+
async def main():
# Create an Android device instance
android_device = mobileadapt(platform="android")
-
+
# Initialize the device (starts the Appium session)
await android_device.start_device()
-
+
# Get the current state of the device
encoded_ui, screenshot, ui = await android_device.get_state()
logger.info(f"Current state: {encoded_ui}")
-
+
# Save the first screenshot
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# filename1 = os.path.join(os.path.dirname(__file__), f"screenshot_before_{timestamp}.png")
- #await save_screenshot(screenshot, filename1)
+ # await save_screenshot(screenshot, filename1)
# print(f"Screenshot saved as {filename1}")
-
+
# Perform a tap action at coordinates (100, 100)
await android_device.tap(100, 100)
-
+
# Get the state again after the tap action
new_encoded_ui, new_screenshot, new_ui = await android_device.get_state()
print("New state after tap:", new_encoded_ui)
-
+
# Save the second screenshot
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filename2 = os.path.join(os.path.dirname(__file__), f"screenshot_after_{timestamp}.png")
+ filename2 = os.path.join(
+ os.path.dirname(__file__), f"screenshot_after_{timestamp}.png"
+ )
await save_screenshot(new_screenshot, filename2)
print(f"Screenshot saved as {filename2}")
+
if __name__ == "__main__":
# Run the main function asynchronously
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/cookbook/examplescript2.py b/cookbook/examplescript2.py
index d49b958a..ce05f8a4 100644
--- a/cookbook/examplescript2.py
+++ b/cookbook/examplescript2.py
@@ -1,21 +1,25 @@
import asyncio
+import io
import os
from datetime import datetime
+
from PIL import Image
-import io
+
from mobileadapt import mobileadapt
+
async def save_screenshot(screenshot_data, filename):
image = Image.open(io.BytesIO(screenshot_data))
image.save(filename)
+
async def perform_actions(device):
# Tap actions
await device.tap(200, 300)
print("Tapped at (200, 300)")
await device.tap(100, 400)
print("Tapped at (100, 400)")
-
+
# Swipe actions
await device.swipe("up")
print("Swiped up")
@@ -30,6 +34,7 @@ async def perform_actions(device):
await device.input(150, 500, "Hello, MobileAdapt!")
print("Input text at (150, 500)")
+
async def main():
android_device = mobileadapt(platform="android")
await android_device.start_device()
@@ -37,7 +42,9 @@ async def main():
# Perform initial state capture
encoded_ui, screenshot, ui = await android_device.get_state()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filename = os.path.join(os.path.dirname(__file__), f"screenshot_initial_{timestamp}.png")
+ filename = os.path.join(
+ os.path.dirname(__file__), f"screenshot_initial_{timestamp}.png"
+ )
await save_screenshot(screenshot, filename)
print(f"Initial screenshot saved as {filename}")
print("Initial UI state:", encoded_ui)
@@ -46,10 +53,12 @@ async def main():
for i in range(3):
print(f"\nPerforming action set {i+1}")
await perform_actions(android_device)
-
+
encoded_ui, screenshot, ui = await android_device.get_state()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filename = os.path.join(os.path.dirname(__file__), f"screenshot_action{i+1}_{timestamp}.png")
+ filename = os.path.join(
+ os.path.dirname(__file__), f"screenshot_action{i+1}_{timestamp}.png"
+ )
await save_screenshot(screenshot, filename)
print(f"Screenshot after action set {i+1} saved as {filename}")
print(f"UI state after action set {i+1}:", encoded_ui)
@@ -65,10 +74,13 @@ async def main():
# Capture final state
encoded_ui, screenshot, ui = await android_device.get_state()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- filename = os.path.join(os.path.dirname(__file__), f"screenshot_final_{timestamp}.png")
+ filename = os.path.join(
+ os.path.dirname(__file__), f"screenshot_final_{timestamp}.png"
+ )
await save_screenshot(screenshot, filename)
print(f"Final screenshot saved as {filename}")
print("Final UI state:", encoded_ui)
+
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/mobileadapt/__init__.py b/mobileadapt/__init__.py
index b41f96c8..2025cbed 100644
--- a/mobileadapt/__init__.py
+++ b/mobileadapt/__init__.py
@@ -1,6 +1,17 @@
+from .core import MobileAdapt
from .device.device_factory import DeviceFactory
-def mobileadapt(platform: str, app_url: str = None, state_representation='aria', download_directory='default', session_id=None):
- return DeviceFactory.create_device(platform, app_url, state_representation, download_directory, session_id)
-__all__ = ['mobileadapt']
\ No newline at end of file
+def mobileadapt(
+ platform: str,
+ app_url: str = None,
+ state_representation="aria",
+ download_directory="default",
+ session_id=None,
+):
+ return DeviceFactory.create_device(
+ platform, app_url, state_representation, download_directory, session_id
+ )
+
+
+__all__ = ["mobileadapt", "MobileAdapt"]
diff --git a/cookbook/agentic_example.ipynb b/mobileadapt/device/android/__init__.py
similarity index 100%
rename from cookbook/agentic_example.ipynb
rename to mobileadapt/device/android/__init__.py
diff --git a/mobileadapt/device/android/android_device.py b/mobileadapt/device/android/android_device.py
index 52917c9a..15c88aa6 100644
--- a/mobileadapt/device/android/android_device.py
+++ b/mobileadapt/device/android/android_device.py
@@ -1,37 +1,41 @@
import base64
+import os
from datetime import datetime
-from mobileadapt.device.device import Device
+from typing import Tuple
+
+import cv2
from appium import webdriver
from appium.options.android import UiAutomator2Options
-from mobileadapt.device.android.android_view_hierarchy import ViewHierarchy
-from mobileadapt.utils.constants import XML_SCREEN_WIDTH, XML_SCREEN_HEIGHT
-from typing import Tuple
from loguru import logger
-import os
-import cv2
+
# Android Emulator Config
from mobileadapt.device.android.android_ui import UI
+from mobileadapt.device.android.android_view_hierarchy import ViewHierarchy
+from mobileadapt.device.device import Device
+from mobileadapt.utils.constants import XML_SCREEN_HEIGHT, XML_SCREEN_WIDTH
+
class AndroidDevice(Device):
- def __init__(self, app_package, download_directory='default', session_id=None):
+ def __init__(self, app_package, download_directory="default", session_id=None):
super().__init__(app_package)
self.download_directory = download_directory
self.session_id = session_id
self.desired_caps = {
- 'deviceName': 'Android Device',
- 'automationName': 'UiAutomator2',
- 'autoGrantPermission': True,
- 'newCommandTimeout': 600,
- 'mjpegScreenshotUrl': 'http://localhost:4723/stream.mjpeg',
-
+ "deviceName": "Android Device",
+ "automationName": "UiAutomator2",
+ "autoGrantPermission": True,
+ "newCommandTimeout": 600,
+ "mjpegScreenshotUrl": "http://localhost:4723/stream.mjpeg",
}
self.options = UiAutomator2Options().load_capabilities(self.desired_caps)
async def get_state(self) -> Tuple[str, bytes, UI]:
raw_appium_state = self.driver.page_source
- file_path = os.path.join(os.path.dirname(__file__), 'android_view_hierarchy.xml')
- xml_file = open(file_path, 'w')
+ file_path = os.path.join(
+ os.path.dirname(__file__), "android_view_hierarchy.xml"
+ )
+ xml_file = open(file_path, "w")
xml_file.write(raw_appium_state)
xml_file.close()
@@ -60,19 +64,16 @@ async def tap(self, x, y):
async def input(self, x, y, text):
await self.tap(x, y)
- self.driver.execute_script('mobile: type', {'text': text})
+ self.driver.execute_script("mobile: type", {"text": text})
async def drag(self, startX, startY, endX, endY):
self.driver.swipe(startX, startY, endX, endY, duration=1000)
async def scroll(self, direction):
- direction_map = {
- 'up': 'UP',
- 'down': 'DOWN',
- 'left': 'LEFT',
- 'right': 'RIGHT'
- }
- self.driver.execute_script('mobile: scroll', {'direction': direction_map[direction]})
+ direction_map = {"up": "UP", "down": "DOWN", "left": "LEFT", "right": "RIGHT"}
+ self.driver.execute_script(
+ "mobile: scroll", {"direction": direction_map[direction]}
+ )
async def swipe(self, direction):
window_size = self.driver.get_window_size()
@@ -80,14 +81,17 @@ async def swipe(self, direction):
top = window_size["height"] * 0.2
width = window_size["width"] * 0.6
height = window_size["height"] * 0.6
- self.driver.execute_script("mobile: swipeGesture", {
- "left": left,
- "top": top,
- "width": width,
- "height": height,
- "direction": direction,
- "percent": 1.0
- })
+ self.driver.execute_script(
+ "mobile: swipeGesture",
+ {
+ "left": left,
+ "top": top,
+ "width": width,
+ "height": height,
+ "direction": direction,
+ "percent": 1.0,
+ },
+ )
async def start_recording(self):
"""
@@ -133,21 +137,19 @@ async def stop_recording(self, save_path=None):
return save_path
async def stop_device(self):
- '''
+ """
Stops a test
- '''
+ """
pass
- def generate_set_of_mark(self,
- ui,
- image: bytes,
- position='top-left') -> bytes:
- '''
+
+ def generate_set_of_mark(self, ui, image: bytes, position="top-left") -> bytes:
+ """
Code to generate a set of mark for a given image and UI state
ui: UI object
image: bytes of the image
step_i: step number
position: position of the annotation, defaults to 'top-lefts', can also be 'center'
- '''
+ """
# Convert image bytes to numpy array
nparr = np.frombuffer(image, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@@ -161,7 +163,7 @@ def generate_set_of_mark(self,
ui.elements[element_id].bounding_box.x1,
ui.elements[element_id].bounding_box.y1,
ui.elements[element_id].bounding_box.x2,
- ui.elements[element_id].bounding_box.y2
+ ui.elements[element_id].bounding_box.y2,
]
# Calculate the area of the bounding box
@@ -171,19 +173,22 @@ def generate_set_of_mark(self,
if area > k:
# Draw a rectangle around the element
cv2.rectangle(
- img, (int(bounds[0]), int(bounds[1])),
- (int(bounds[2]), int(bounds[3])), (0, 0, 255), 5)
+ img,
+ (int(bounds[0]), int(bounds[1])),
+ (int(bounds[2]), int(bounds[3])),
+ (0, 0, 255),
+ 5,
+ )
text = str(element_id)
text_size = 2 # Fixed text size
font = cv2.FONT_HERSHEY_SIMPLEX
# Calculate the width and height of the text
- text_width, text_height = cv2.getTextSize(
- text, font, text_size, 2)[0]
+ text_width, text_height = cv2.getTextSize(text, font, text_size, 2)[0]
# Calculate the position of the text
- if position == 'top-left':
+ if position == "top-left":
text_x = int(bounds[0])
text_y = int(bounds[1]) + text_height
else: # Default to center
@@ -191,33 +196,51 @@ def generate_set_of_mark(self,
text_y = (int(bounds[1]) + int(bounds[3])) // 2 + text_height // 2
# Draw a black rectangle behind the text
- cv2.rectangle(img, (text_x, text_y - text_height),
- (text_x + text_width, text_y), (0, 0, 0), thickness=cv2.FILLED)
+ cv2.rectangle(
+ img,
+ (text_x, text_y - text_height),
+ (text_x + text_width, text_y),
+ (0, 0, 0),
+ thickness=cv2.FILLED,
+ )
# Draw the text in white
- cv2.putText(img, text, (text_x, text_y), font,
- text_size, (255, 255, 255), 4)
+ cv2.putText(
+ img, text, (text_x, text_y), font, text_size, (255, 255, 255), 4
+ )
# Convert the image to bytes
- _, img_encoded = cv2.imencode('.png', img)
+ _, img_encoded = cv2.imencode(".png", img)
img_bytes = img_encoded.tobytes()
return img_bytes
async def start_device(self):
- '''
+ """
TODO: implement
- '''
+ """
try:
- self.driver = webdriver.Remote('http://localhost:4723', options=self.options)
+ self.driver = webdriver.Remote(
+ "http://localhost:4723", options=self.options
+ )
except BaseException:
- self.desired_caps.pop('mjpegScreenshotUrl')
+ self.desired_caps.pop("mjpegScreenshotUrl")
self.options = UiAutomator2Options().load_capabilities(self.desired_caps)
- self.driver = webdriver.Remote('http://localhost:4723', options=self.options)
+ self.driver = webdriver.Remote(
+ "http://localhost:4723", options=self.options
+ )
# self.driver.start_recording_screen()
- self.driver.update_settings({'waitForIdleTimeout': 0, 'shouldWaitForQuiescence': False, 'maxTypingFrequency': 60})
+ self.driver.update_settings(
+ {
+ "waitForIdleTimeout": 0,
+ "shouldWaitForQuiescence": False,
+ "maxTypingFrequency": 60,
+ }
+ )
# self.driver.get_screenshot_as_base64()
+
+
# self.driver.execute_script('mobile: startScreenStreaming', {
# 'width': 1080,
# 'height': 1920,
@@ -228,6 +251,6 @@ async def start_device(self):
if __name__ == "__main__":
- ui = UI(os.path.join(os.path.dirname(__file__), 'android_view_hierarchy.xml'))
+ ui = UI(os.path.join(os.path.dirname(__file__), "android_view_hierarchy.xml"))
encoded_ui = ui.encoding()
logger.info(f"Encoded UI: {encoded_ui}")
diff --git a/mobileadapt/device/android/android_ui.py b/mobileadapt/device/android/android_ui.py
index fd8c73c2..31f0f2b4 100644
--- a/mobileadapt/device/android/android_ui.py
+++ b/mobileadapt/device/android/android_ui.py
@@ -1,69 +1,79 @@
CLASS_MAPPING = {
- 'TEXTVIEW': 'p',
- 'BUTTON': 'button',
- 'IMAGEBUTTON': 'button',
- 'IMAGEVIEW': 'img',
- 'EDITTEXT': 'input',
- 'CHECKBOX': 'input',
- 'CHECKEDTEXTVIEW': 'input',
- 'TOGGLEBUTTON': 'button',
- 'RADIOBUTTON': 'input',
- 'SPINNER': 'select',
- 'SWITCH': 'input',
- 'SLIDINGDRAWER': 'input',
- 'TABWIDGET': 'div',
- 'VIDEOVIEW': 'video',
- 'SEARCHVIEW': 'div'
+ "TEXTVIEW": "p",
+ "BUTTON": "button",
+ "IMAGEBUTTON": "button",
+ "IMAGEVIEW": "img",
+ "EDITTEXT": "input",
+ "CHECKBOX": "input",
+ "CHECKEDTEXTVIEW": "input",
+ "TOGGLEBUTTON": "button",
+ "RADIOBUTTON": "input",
+ "SPINNER": "select",
+ "SWITCH": "input",
+ "SLIDINGDRAWER": "input",
+ "TABWIDGET": "div",
+ "VIDEOVIEW": "video",
+ "SEARCHVIEW": "div",
}
from loguru import logger
+
from mobileadapt.device.android.android_view_hierarchy import ViewHierarchy
-from mobileadapt.utils.constants import XML_SCREEN_WIDTH, XML_SCREEN_HEIGHT
+from mobileadapt.utils.constants import XML_SCREEN_HEIGHT, XML_SCREEN_WIDTH
+
def sortchildrenby_viewhierarchy(view, attr="bounds"):
- if attr == 'bounds':
- bounds = [(ele.uiobject.bounding_box.x1, ele.uiobject.bounding_box.y1,
- ele.uiobject.bounding_box.x2, ele.uiobject.bounding_box.y2)
- for ele in view]
+ if attr == "bounds":
+ bounds = [
+ (
+ ele.uiobject.bounding_box.x1,
+ ele.uiobject.bounding_box.y1,
+ ele.uiobject.bounding_box.x2,
+ ele.uiobject.bounding_box.y2,
+ )
+ for ele in view
+ ]
sorted_bound_index = [
- bounds.index(i) for i in sorted(
- bounds, key=lambda x: (
- x[1], x[0]))]
+ bounds.index(i) for i in sorted(bounds, key=lambda x: (x[1], x[0]))
+ ]
sort_children = [view[i] for i in sorted_bound_index]
view[:] = sort_children
-
-class UI():
+class UI:
def __init__(self, xml_file):
self.xml_file = xml_file
self.elements = {}
def encoding(self):
- logger.info('reading hierarchy tree from {} ...'.format(
- self.xml_file.split('/')[-1]))
- with open(self.xml_file, 'r', encoding='utf-8') as f:
+ logger.info(
+ "reading hierarchy tree from {} ...".format(self.xml_file.split("/")[-1])
+ )
+ with open(self.xml_file, "r", encoding="utf-8") as f:
vh_data = f.read().encode()
vh = ViewHierarchy(
- screen_width=XML_SCREEN_WIDTH,
- screen_height=XML_SCREEN_HEIGHT)
+ screen_width=XML_SCREEN_WIDTH, screen_height=XML_SCREEN_HEIGHT
+ )
vh.load_xml(vh_data)
view_hierarchy_leaf_nodes = vh.get_leaf_nodes()
sortchildrenby_viewhierarchy(view_hierarchy_leaf_nodes)
- logger.debug('encoding the ui elements in hierarchy tree...')
- codes = ''
+ logger.debug("encoding the ui elements in hierarchy tree...")
+ codes = ""
# logger.info(view_hierarchy_leaf_nodes)
for _id, ele in enumerate(view_hierarchy_leaf_nodes):
obj_type = ele.uiobject.obj_type.name
text = ele.uiobject.text
- text = text.replace('\n', ' ')
- resource_id = ele.uiobject.resource_id if ele.uiobject.resource_id is not None else ''
+ text = text.replace("\n", " ")
+ resource_id = (
+ ele.uiobject.resource_id if ele.uiobject.resource_id is not None else ""
+ )
content_desc = ele.uiobject.content_desc
html_code = self.element_encoding(
- _id, obj_type, text, content_desc, resource_id)
+ _id, obj_type, text, content_desc, resource_id
+ )
codes += html_code
self.elements[_id] = ele.uiobject
codes = "\n" + codes + ""
@@ -71,29 +81,23 @@ def encoding(self):
# logger.info('Encoded UI\n' + codes)
return codes
- def element_encoding(
- self,
- _id,
- _obj_type,
- _text,
- _content_desc,
- _resource_id):
+ def element_encoding(self, _id, _obj_type, _text, _content_desc, _resource_id):
- _class = _resource_id.split('id/')[-1].strip()
+ _class = _resource_id.split("id/")[-1].strip()
_text = _text.strip()
assert _obj_type in CLASS_MAPPING.keys(), print(_obj_type)
tag = CLASS_MAPPING[_obj_type]
- if _obj_type in ['CHECKBOX', 'CHECKEDTEXTVIEW', 'SWITCH']:
+ if _obj_type in ["CHECKBOX", "CHECKEDTEXTVIEW", "SWITCH"]:
code = f' \n'
- code += f' \n'
- elif _obj_type == 'RADIOBUTTON':
+ code += f" \n"
+ elif _obj_type == "RADIOBUTTON":
code = f' \n'
- code += f' \n'
- elif _obj_type == 'SPINNER':
- code = f' \n'
+ code += f" \n"
+ elif _obj_type == "SPINNER":
+ code = f" \n"
code += f' \n'
- elif _obj_type == 'IMAGEVIEW':
+ elif _obj_type == "IMAGEVIEW":
if _class == "":
code = f' \n'
else:
@@ -106,4 +110,3 @@ def element_encoding(
_text = _content_desc if _text == "" else _text
code = f' <{tag} id={_id} class="{_class}">{_text}{tag}>\n'
return code
-
diff --git a/mobileadapt/device/android/android_view_hierarchy.py b/mobileadapt/device/android/android_view_hierarchy.py
new file mode 100644
index 00000000..742450d1
--- /dev/null
+++ b/mobileadapt/device/android/android_view_hierarchy.py
@@ -0,0 +1,754 @@
+import collections
+import json
+import re
+from enum import Enum
+
+import attr
+import numpy as np
+from lxml import etree
+from str2bool import str2bool as strtobool
+
+import mobileadapt.utils.constants as config
+
+
+class UIObjectType(Enum):
+ """Types of the different UI objects."""
+
+ UNKNOWN = 0
+ BUTTON = 1
+ CHECKBOX = 2
+ CHECKEDTEXTVIEW = 3
+ EDITTEXT = 4
+ IMAGEBUTTON = 5
+ IMAGEVIEW = 6
+ RADIOBUTTON = 7
+ SLIDINGDRAWER = 8
+ SPINNER = 9
+ SWITCH = 10
+ TABWIDGET = 11
+ TEXTVIEW = 12
+ TOGGLEBUTTON = 13
+ VIDEOVIEW = 14
+ SEARCHVIEW = 15
+
+
+class UIObjectGridLocation(Enum):
+ """The on-screen grid location (3x3 grid) of an UI object."""
+
+ TOP_LEFT = 0
+ TOP_CENTER = 1
+ TOP_RIGHT = 2
+ LEFT = 3
+ CENTER = 4
+ RIGHT = 5
+ BOTTOM_LEFT = 6
+ BOTTOM_CENTER = 7
+ BOTTOM_RIGHT = 8
+
+
+@attr.s
+class BoundingBox(object):
+ """The bounding box with horizontal/vertical coordinates of an UI object."""
+
+ x1 = attr.ib()
+ y1 = attr.ib()
+ x2 = attr.ib()
+ y2 = attr.ib()
+
+
+@attr.s
+class UIObject(object):
+ """Represents an UI object from the leaf node in the view hierarchy."""
+
+ obj_type = attr.ib()
+ obj_name = attr.ib()
+ word_sequence = attr.ib()
+ text = attr.ib()
+ resource_id = attr.ib()
+ android_class = attr.ib()
+ android_package = attr.ib()
+ content_desc = attr.ib()
+ clickable = attr.ib()
+ visible = attr.ib()
+ enabled = attr.ib()
+ focusable = attr.ib()
+ focused = attr.ib()
+ scrollable = attr.ib()
+ long_clickable = attr.ib()
+ selected = attr.ib()
+ bounding_box = attr.ib()
+ grid_location = attr.ib()
+ dom_location = attr.ib()
+ pointer = attr.ib()
+ neighbors = attr.ib()
+
+
+def _build_word_sequence(text, content_desc, resource_id):
+ """Returns a sequence of word tokens based on certain attributes.
+ Args:
+ text: `text` attribute of an element.
+ content_desc: `content_desc` attribute of an element.
+ resource_id: `resource_id` attribute of an element.
+ Returns:
+ A sequence of word tokens.
+ """
+ if text or content_desc:
+ return re.findall(r"[\w']+|[?.!/,;:]", text if text else content_desc)
+ else:
+ # logger.info(f"Resource ID: {resource_id}")
+ if resource_id is not None:
+ name = resource_id.split("/")[-1]
+ return filter(None, name.split("_"))
+ else:
+ return []
+
+
+def _build_object_type(android_class):
+ """Returns the object type based on `class` attribute.
+ Args:
+ android_class: `class` attribute of an element (Android class).
+ Returns:
+ The UIObjectType enum.
+ """
+ if android_class.startswith("android.widget"):
+ widget_type = android_class.split(".")[2]
+ for obj_type in UIObjectType:
+ if obj_type.name == widget_type.upper():
+ return obj_type
+ widget_type = android_class.split(".")[-1]
+ for obj_type in UIObjectType:
+ if obj_type.name in widget_type.upper():
+ return obj_type
+ return UIObjectType.BUTTON
+
+
+def _build_object_name(text, content_desc):
+ """Returns the object name based on `text` or `content_desc` attribute.
+ Args:
+ text: The `text` attribute.
+ content_desc: The `content_desc` attribute.
+ Returns:
+ The object name string.
+ """
+ return text if text else content_desc
+
+
+def _build_bounding_box(bounds):
+ """Returns the object bounding box based on `bounds` attribute.
+ Args:
+ bounds: The `bounds` attribute.
+ Returns:
+ The BoundingBox object.
+ """
+ match = re.compile(r"\[(\d+),(\d+)\]\[(\d+),(\d+)\]").match(bounds)
+ assert match, f"Invalid bounds format: {bounds}"
+ x1, y1, x2, y2 = map(int, match.groups())
+ return BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)
+
+
+def _build_clickable(element, tree_child_as_clickable=True):
+ """Returns whether the element is clickable or one of its ancestors is.
+ Args:
+ element: The etree.Element object.
+ tree_child_as_clickable: treat all tree children as clickable
+ Returns:
+ A boolean to indicate whether the element is clickable or one of its
+ ancestors is.
+ """
+ clickable = element.get("clickable")
+ if clickable == "false":
+ for node in element.iterancestors():
+ if node.get("clickable") == "true":
+ clickable = "true"
+ break
+
+ # Below code is try to fix that: some target UI have 'clickable==False'
+ # but it's clickable by human actually
+
+ # Checkable elemnts should also be treated as clickable
+ # Some menu items may have clickable==False but checkable==True
+ if element.get("checkable") == "true":
+ clickable = "true"
+ if tree_child_as_clickable:
+ p = element.getparent()
+ while p is not None:
+ if p.get("class") == "android.widget.ListView":
+ clickable = "true"
+ break
+ p = p.getparent()
+
+ return strtobool(clickable)
+
+
+def _pixel_distance(a_x1, a_x2, b_x1, b_x2):
+ """Calculates the pixel distance between bounding box a and b.
+ Args:
+ a_x1: The x1 coordinate of box a.
+ a_x2: The x2 coordinate of box a.
+ b_x1: The x1 coordinate of box b.
+ b_x2: The x2 coordinate of box b.
+ Returns:
+ The pixel distance between box a and b on the x axis. The distance
+ on the y axis can be calculated in the same way. The distance can be
+ positive number (b is right/bottom to a) and negative number
+ (b is left or top to a).
+ """
+ # if a and b are close enough, then we set the their distance to be 1
+ # because there are typically padding spaces inside an object's bounding
+ # box
+ if b_x1 <= a_x2 and a_x2 - b_x1 <= config.ADJACENT_BOUNDING_BOX_THRESHOLD:
+ return 1
+ if a_x1 <= b_x2 and b_x2 - a_x1 <= config.ADJACENT_BOUNDING_BOX_THRESHOLD:
+ return -1
+ # overlap
+ if (
+ (a_x1 <= b_x1 <= a_x2)
+ or (a_x1 <= b_x2 <= a_x2)
+ or (b_x1 <= a_x1 <= b_x2)
+ or (b_x1 <= a_x2 <= b_x2)
+ ):
+ return 0
+ elif b_x1 > a_x2:
+ return b_x1 - a_x2
+ else:
+ return b_x2 - a_x1
+
+
+def _grid_coordinate(x, width):
+ """Calculates the 3x3 grid coordinate on the x axis.
+ The grid coordinate on the y axis is calculated in the same way.
+ Args:
+ x: The x coordinate: [0, width).
+ width: The screen width.
+ Returns:
+ The grid coordinate: [0, 2].
+ Note that the screen is divided into 3x3 grid, so the grid coordinate
+ uses the number from 0, 1, 2.
+ """
+ assert 0 <= x <= width
+ grid_x_0 = width / 3
+ grid_x_1 = 2 * grid_x_0
+ if 0 <= x < grid_x_0:
+ grid_coordinate_x = 0
+ elif grid_x_0 <= x < grid_x_1:
+ grid_coordinate_x = 1
+ else:
+ grid_coordinate_x = 2
+ return grid_coordinate_x
+
+
+def _grid_location(bbox, screen_width, screen_height):
+ """Calculates the grid number of the UI object's bounding box.
+ The screen can be divided into 3x3 grid:
+ (0, 0) (0, 1) (0, 2) 0 1 2
+ (1, 0) (1, 1) (1, 2) ---> 3 4 5
+ (2, 0) (2, 1) (2, 2) 6 7 8
+ Args:
+ bbox: The bounding box of the UI object.
+ screen_width: The width of the screen associated with the hierarchy.
+ screen_height: The height of the screen associated with the hierarchy.
+ Returns:
+ The grid location number.
+ """
+ bbox_center_x = (bbox.x1 + bbox.x2) / 2
+ bbox_center_y = (bbox.y1 + bbox.y2) / 2
+ bbox_grid_x = _grid_coordinate(bbox_center_x, screen_width)
+ bbox_grid_y = _grid_coordinate(bbox_center_y, screen_height)
+ return UIObjectGridLocation(bbox_grid_y * 3 + bbox_grid_x)
+
+
+def get_view_hierarchy_leaf_relation(objects, _screen_width, _screen_height):
+ """Calculates adjacency relation from list of view hierarchy leaf nodes.
+ Args:
+ objects: a list of objects.
+ _screen_width, _screen_height: Screen width and height.
+ Returns:
+ An un-padded feature dictionary as follow:
+ 'v_distance': 2d numpy array of ui object vertical adjacency relation.
+ 'h_distance': 2d numpy array of ui object horizontal adjacency relation.
+ 'dom_distance': 2d numpy array of ui object dom adjacency relation.
+ """
+ vh_node_num = len(objects)
+ vertical_adjacency = np.zeros((vh_node_num, vh_node_num))
+ horizontal_adjacency = np.zeros((vh_node_num, vh_node_num))
+ for row in range(len(objects)):
+ for column in range(len(objects)):
+ if row == column:
+ h_dist = v_dist = 0
+ else:
+ node1 = objects[row]
+ node2 = objects[column]
+ h_dist, v_dist = normalized_pixel_distance(
+ node1, node2, _screen_width, _screen_height
+ )
+ # print(node1.text, node2.text, v_dist)
+ vertical_adjacency[row][column] = v_dist
+ horizontal_adjacency[row][column] = h_dist
+ return {"v_distance": vertical_adjacency, "h_distance": horizontal_adjacency}
+
+
+def _get_single_direction_neighbors(object_idx, ui_v_dist, ui_h_dist):
+ """Gets four 'single direction neighbors' for one target ui_object.
+ If B is A's bottom/top 'single direction neighbor', it means B is the
+ vertical closest neighbor among all object whose horizontal distance to A is
+ smaller than margin threshold. Same with left/right direction neighbor.
+ Args:
+ object_idx: index number of target ui_object in ui_object_list
+ ui_v_dist: ui objects' vertical distances. shape=[num_ui_obj, num_ui_obj]
+ ui_h_dist: ui objects' horizontal distances. shape=[num_ui_obj, num_ui_obj]
+ Returns:
+ a dictionary, keys are NeighborContextDesc Instance, values are neighbor
+ object index.
+ """
+ neighbor_dict = {}
+ vertical_dist = ui_v_dist[object_idx]
+ horizontal_dist = ui_h_dist[object_idx]
+ bottom_neighbors = np.array(
+ [
+ idx
+ for idx in range(len(vertical_dist))
+ if vertical_dist[idx] > 0
+ and abs(horizontal_dist[idx]) < config.NORM_HORIZONTAL_NEIGHBOR_MARGIN
+ ]
+ )
+ top_neighbors = np.array(
+ [
+ idx
+ for idx in range(len(vertical_dist))
+ if vertical_dist[idx] < 0
+ and abs(horizontal_dist[idx]) < config.NORM_HORIZONTAL_NEIGHBOR_MARGIN
+ ]
+ )
+ right_neighbors = np.array(
+ [
+ idx
+ for idx in range(len(horizontal_dist))
+ if horizontal_dist[idx] > 0
+ and abs(vertical_dist[idx]) < config.NORM_VERTICAL_NEIGHBOR_MARGIN
+ ]
+ )
+ left_neighbors = np.array(
+ [
+ idx
+ for idx in range(len(horizontal_dist))
+ if horizontal_dist[idx] < 0
+ and abs(vertical_dist[idx]) < config.NORM_VERTICAL_NEIGHBOR_MARGIN
+ ]
+ )
+
+ if bottom_neighbors.size:
+ neighbor_dict["top"] = bottom_neighbors[
+ np.argmin(vertical_dist[bottom_neighbors])
+ ]
+ if top_neighbors.size:
+ neighbor_dict["bottom"] = top_neighbors[np.argmax(vertical_dist[top_neighbors])]
+ if right_neighbors.size:
+ neighbor_dict["left"] = right_neighbors[
+ np.argmin(horizontal_dist[right_neighbors])
+ ]
+ if left_neighbors.size:
+ neighbor_dict["right"] = left_neighbors[
+ np.argmax(horizontal_dist[left_neighbors])
+ ]
+
+ return neighbor_dict
+
+
+def normalized_pixel_distance(node1, node2, _screen_width, _screen_height):
+ """Calculates normalized pixel distance between this node and other node.
+ Args:
+ node1, node2: Another object.
+ _screen_width, _screen_height: Screen width and height.
+ Returns:
+ Normalized pixel distance on both horizontal and vertical direction.
+ """
+ h_distance = _pixel_distance(
+ _build_bounding_box(node1.get("bounds")).x1,
+ _build_bounding_box(node1.get("bounds")).x2,
+ _build_bounding_box(node2.get("bounds")).x1,
+ _build_bounding_box(node2.get("bounds")).x2,
+ )
+ v_distance = _pixel_distance(
+ _build_bounding_box(node1.get("bounds")).y1,
+ _build_bounding_box(node1.get("bounds")).y2,
+ _build_bounding_box(node2.get("bounds")).y1,
+ _build_bounding_box(node2.get("bounds")).y2,
+ )
+
+ return float(h_distance) / _screen_width, float(v_distance) / _screen_height
+
+
+def _build_neighbors(node, view_hierarchy_leaf_nodes, _screen_width, _screen_height):
+ """Builds the neighbours from view_hierarchy.
+ Args:
+ node: The current etree root node.
+ view_hierarchy_leaf_nodes: All of the etree nodes.
+ _screen_width, _screen_height: Screen width and height.
+ Returns:
+ Neighbour directions and object pointers.
+ """
+ if view_hierarchy_leaf_nodes is None:
+ return None
+ vh_relation = get_view_hierarchy_leaf_relation(
+ view_hierarchy_leaf_nodes, _screen_width, _screen_height
+ )
+ _neighbor = _get_single_direction_neighbors(
+ view_hierarchy_leaf_nodes.index(node),
+ vh_relation["v_distance"],
+ vh_relation["h_distance"],
+ )
+ for k, v in _neighbor.items():
+ _neighbor[k] = view_hierarchy_leaf_nodes[v].get("pointer")
+ return _neighbor
+
+
+def _build_etree_from_json(root, json_dict):
+ """Builds the element tree from json_dict.
+ Args:
+ root: The current etree root node.
+ json_dict: The current json_dict corresponding to the etree root node.
+ """
+ # set node attributes
+ if root is None or json_dict is None:
+ return
+ x1, y1, x2, y2 = json_dict.get("bounds", [0, 0, 0, 0])
+ root.set("bounds", "[%d,%d][%d,%d]" % (x1, y1, x2, y2))
+ root.set("class", json_dict.get("class", ""))
+ # XML element cannot contain NULL bytes.
+ root.set("text", json_dict.get("text", "").replace("\x00", ""))
+ root.set("resource-id", json_dict.get("resource-id", ""))
+ content_desc = json_dict.get("content-desc", [None])
+ root.set(
+ "content-desc",
+ "" if content_desc[0] is None else content_desc[0].replace("\x00", ""),
+ )
+ root.set("package", json_dict.get("package", ""))
+ root.set("visible", str(json_dict.get("visible-to-user", True)))
+ root.set("enabled", str(json_dict.get("enabled", False)))
+ root.set("focusable", str(json_dict.get("focusable", False)))
+ root.set("focused", str(json_dict.get("focused", False)))
+ root.set(
+ "scrollable",
+ str(
+ json_dict.get("scrollable-horizontal", False)
+ or json_dict.get("scrollable-vertical", False)
+ ),
+ )
+ root.set("clickable", str(json_dict.get("clickable", False)))
+ root.set("long-clickable", str(json_dict.get("long-clickable", False)))
+ root.set("selected", str(json_dict.get("selected", False)))
+ root.set("pointer", str(json_dict.get("pointer", "")))
+ if "children" not in json_dict: # leaf node
+ return
+ for child in json_dict["children"]:
+ # some json file has 'null' as one of the children.
+ if child:
+ child_node = etree.Element("node")
+ root.append(child_node)
+ _build_etree_from_json(child_node, child)
+
+
+class LeafNode(object):
+ """Represents a leaf node in the view hierarchy data from xml."""
+
+ def __init__(
+ self,
+ element,
+ all_elements=None,
+ dom_location=None,
+ screen_width=config.SCREEN_WIDTH,
+ screen_height=config.SCREEN_HEIGHT,
+ ):
+ """Constructor.
+ Args:
+ element: The etree.Element object.
+ all_elements: All the etree.Element objects in the view hierarchy.
+ dom_location: [depth, preorder-index, postorder-index] of element.
+ screen_width: The width of the screen associated with the element.
+ screen_height: The height of the screen associated with the element.
+ """
+ assert not element.findall(".//node")
+ self.element = element
+ self._screen_width = screen_width
+ self._screen_height = screen_height
+ # logger.info(f"element: {element}")
+ bbox = _build_bounding_box(element.get("bounds"))
+ self.uiobject = UIObject(
+ obj_type=_build_object_type(element.get("class")),
+ obj_name=_build_object_name(
+ element.get("text"), element.get("content-desc")
+ ),
+ word_sequence=_build_word_sequence(
+ element.get("text"),
+ element.get("content-desc"),
+ element.get("resource-id"),
+ ),
+ text=element.get("text"),
+ resource_id=element.get("resource-id"),
+ android_class=element.get("class"),
+ android_package=element.get("package"),
+ content_desc=element.get("content-desc"),
+ clickable=_build_clickable(element),
+ visible=strtobool(element.get("visible", default="true")),
+ enabled=strtobool(element.get("enabled")),
+ focusable=strtobool(element.get("focusable")),
+ focused=strtobool(element.get("focused")),
+ scrollable=strtobool(element.get("scrollable")),
+ long_clickable=strtobool(element.get("long-clickable")),
+ selected=strtobool(element.get("selected")),
+ bounding_box=bbox,
+ grid_location=_grid_location(bbox, self._screen_width, self._screen_height),
+ dom_location=dom_location,
+ pointer=element.get("pointer"),
+ neighbors=_build_neighbors(
+ element, all_elements, self._screen_width, self._screen_height
+ ),
+ )
+
+ def dom_distance(self, other_node):
+ """Calculates dom distance between this node and other node.
+ Args:
+ other_node: Another LeafNode object.
+ Returns:
+ The dom distance in between two leaf nodes: defined as the number of
+ nodes on the path from one leaf node to the other on the tree.
+ """
+ intersection = [
+ node
+ for node in self.element.iterancestors()
+ if node in other_node.element.iterancestors()
+ ]
+ assert intersection
+ ancestor_list = list(self.element.iterancestors())
+ other_ancestor_list = list(other_node.element.iterancestors())
+ return (
+ ancestor_list.index(intersection[0])
+ + other_ancestor_list.index(intersection[0])
+ + 1
+ )
+
+
+class DomLocationKey(Enum):
+ """Keys of dom location info."""
+
+ DEPTH = 0
+ PREORDER_INDEX = 1
+ POSTORDER_INDEX = 2
+
+
+class ViewHierarchy(object):
+ """Represents the view hierarchy data from UIAutomator dump."""
+
+ def __init__(
+ self, screen_width=config.SCREEN_WIDTH, screen_height=config.SCREEN_HEIGHT
+ ):
+ """Constructor.
+ Args:
+ screen_width: The pixel width of the screen for the view hierarchy.
+ screen_height: The pixel height of the screen for the view hierarchy.
+ """
+ self._root = None
+ self._root_element = None
+ self._all_visible_leaves = []
+ self._dom_location_dict = None
+ self._preorder_index = 0
+ self._postorder_index = 0
+ self._screen_width = screen_width
+ self._screen_height = screen_height
+
+ def load_xml(self, xml_content):
+ """Builds the etree from xml content.
+ Args:
+ xml_content: The string containing xml content.
+ """
+ self._root = etree.XML(xml_content)
+ self._root_element = self._root[0]
+
+ self._all_visible_leaves = self._get_visible_leaves()
+
+ # dom_location_dict:
+ # dict of {id(element): [depth, preorder-index, postorder-index]}
+ # Note: for leaves of any tree, the following equation is always true:
+ #
+ # depth == preorder-index - postorder-index (depth is # of ancestors)
+ #
+ self._dom_location_dict = self._calculate_dom_location()
+
+ def load_json(self, json_content):
+ """Builds the etree from json content.
+ Args:
+ json_content: The string containing json content.
+ """
+ json_dict = json.loads(json_content)
+ if json_dict is None:
+ raise ValueError("empty json file.")
+ self._root = etree.Element("hierarchy", rotation="0")
+ self._root_element = etree.Element("node")
+ self._root.append(self._root_element)
+ _build_etree_from_json(self._root_element, json_dict["activity"]["root"])
+
+ self._all_visible_leaves = self._get_visible_leaves()
+ self._dom_location_dict = self._calculate_dom_location()
+
+ def get_leaf_nodes(self):
+ """Returns a list of all the leaf Nodes."""
+ return [
+ LeafNode(
+ element,
+ self._all_visible_leaves,
+ self._dom_location_dict[id(element)],
+ self._screen_width,
+ self._screen_height,
+ )
+ for element in self._all_visible_leaves
+ ]
+
+ def get_ui_objects(self):
+ """Returns a list of all ui objects represented by leaf nodes."""
+ return [
+ LeafNode(
+ element,
+ self._all_visible_leaves,
+ self._dom_location_dict[id(element)],
+ self._screen_width,
+ self._screen_height,
+ ).uiobject
+ for element in self._all_visible_leaves
+ ]
+
+ def dedup(self, click_x_and_y):
+ """Dedup UI objects with same text or content_desc.
+ Args:
+ click_x_and_y: the event x and y (like: click pos in screen)
+ """
+ click_x, click_y = click_x_and_y
+
+ # Map of {'name': [list of UI objects with this name]}
+ name_element_map = collections.defaultdict(list)
+ for element in self._all_visible_leaves:
+ name = _build_object_name(element.get("text"), element.get("content_desc"))
+ name_element_map[name].append(element)
+
+ def delete_element(element):
+ element.getparent().remove(element)
+
+ for name, elements in name_element_map.items():
+ if not name:
+ continue
+ # Search if the event (x, y) happens in one of these objects
+ target_index = None
+ for index, element in enumerate(elements):
+ box = _build_bounding_box(element.get("bounds"))
+ if box.x1 <= click_x <= box.x2 and box.y1 <= click_y <= box.y2:
+ target_index = index
+
+ if target_index is None: # target UI obj is not in this elements
+ for ele in elements[1:]:
+ delete_element(ele)
+ else: # if target UI obj is one of them, delete the rest UI objs
+ for ele in elements[:target_index] + elements[target_index + 1 :]:
+ delete_element(ele)
+
+ print(
+ "Dedup: %d -> %d"
+ % (len(self._all_visible_leaves), len(self._get_visible_leaves()))
+ )
+
+ self._all_visible_leaves = self._get_visible_leaves()
+ self._dom_location_dict = self._calculate_dom_location()
+
+ def _get_visible_leaves(self):
+ """Gets all the visible leaves from view hierarchy.
+ Returns:
+ all_visible_leaves: The list of all the visible leaf elements.
+ """
+
+ all_elements = [element for element in self._root.iter("*")]
+
+ all_visible_leaves = [
+ element
+ for element in all_elements
+ if self._is_leaf(element)
+ and strtobool(element.attrib.get("displayed", default="true"))
+ and self._is_within_screen_bound(element)
+ ]
+ return all_visible_leaves
+
+ def _calculate_dom_location(self):
+ """Calculate [depth, preorder-index, postorder-index] of all leaf nodes.
+ This method is NOT thread safe if multiple threads call this method of same
+ ViewHierarchy object: This method keeps updating self._preorder_index
+ and self._postorder_index when call pre/post travel method recursively.
+ All leaf elements will be filted and cached in self._all_visible_leaves.
+ This is necessary because dom_location_dict use id(element) as keys, if
+ call _root.iter('*') every time, the id(element) will not be a fixed value
+ even for same element in XML.
+ Returns:
+ dom_location_dict, dict of
+ {id(element): [depth, preorder-index, postorder-index]}
+ """
+ dom_location_dict = collections.defaultdict(lambda: [None, None, None])
+ # Calculate the depth of all leaf nodes.
+ for element in self._all_visible_leaves:
+ ancestors = [node for node in element.iterancestors()]
+ dom_location_dict[id(element)][DomLocationKey.DEPTH.value] = len(ancestors)
+
+ # Calculate the pre/post index by calling pre/post iteration
+ # recursively.
+ self._preorder_index = 0
+ self._pre_order_iterate(self._root, dom_location_dict)
+ self._postorder_index = 0
+ self._post_order_iterate(self._root, dom_location_dict)
+ return dom_location_dict
+
+ def _pre_order_iterate(self, element, dom_location_dict):
+ """Preorder travel on hierarchy tree.
+ Args:
+ element: etree element which will be visited now.
+ dom_location_dict: dict of
+ {id(element): [depth, preorder-index, postorder-index]}
+ """
+ if self._is_leaf(element):
+ dom_location_dict[id(element)][
+ DomLocationKey.PREORDER_INDEX.value
+ ] = self._preorder_index
+ self._preorder_index += 1
+
+ for child in element:
+ if child.getparent() == element:
+ self._pre_order_iterate(child, dom_location_dict)
+
+ def _post_order_iterate(self, element, dom_location_dict):
+ """Postorder travel on hierarchy tree.
+ Args:
+ element: etree element which will be visited now.
+ dom_location_dict: dict of
+ {id(element): [depth, preorder-index, postorder-index]}
+ """
+ for child in element:
+ if child.getparent() == element:
+ self._post_order_iterate(child, dom_location_dict)
+
+ if self._is_leaf(element):
+ dom_location_dict[id(element)][
+ DomLocationKey.POSTORDER_INDEX.value
+ ] = self._postorder_index
+ self._postorder_index += 1
+
+ def _is_leaf(self, element):
+ """Whether an etree element is leaf in hierachy tree."""
+
+ return not element.findall(".//*")
+
+ def _is_within_screen_bound(self, element):
+ """Whether an etree element's bounding box is within screen boundary."""
+ bbox = _build_bounding_box(element.attrib.get("bounds"))
+ in_x = (0 <= bbox.x1 <= self._screen_width) and (
+ 0 <= bbox.x2 <= self._screen_width
+ )
+ in_y = (0 <= bbox.y1 <= self._screen_height) and (
+ 0 <= bbox.y2 <= self._screen_height
+ )
+ x1_less_than_x2 = bbox.x1 < bbox.x2
+ y1_less_than_y2 = bbox.y1 < bbox.y2
+ return in_x and in_y and x1_less_than_x2 and y1_less_than_y2
diff --git a/mobileadapt/device/device.py b/mobileadapt/device/device.py
index c654f425..08c4ca0e 100644
--- a/mobileadapt/device/device.py
+++ b/mobileadapt/device/device.py
@@ -7,16 +7,16 @@ def __init__(self, app_package):
@abstractmethod
def start_device(self):
- '''
+ """
Function to start device
- '''
+ """
pass
@abstractmethod
def stop_device(self):
- '''
+ """
Function to stop device
- '''
+ """
pass
@abstractmethod
diff --git a/mobileadapt/device/device_factory.py b/mobileadapt/device/device_factory.py
index 0b2d84df..55e72f30 100644
--- a/mobileadapt/device/device_factory.py
+++ b/mobileadapt/device/device_factory.py
@@ -1,21 +1,32 @@
# device/device_factory.py
# from .device import Device
+from loguru import logger
+
from mobileadapt.device.android.android_device import AndroidDevice
from mobileadapt.device.ios_device import IOSDevice
-from loguru import logger
class DeviceFactory:
@staticmethod
- def create_device(platform: str, app_url: str,
- state_representation='aria', download_directory='default', session_id=None):
- if platform == 'android':
- return AndroidDevice(app_package=app_url, download_directory=download_directory, session_id=session_id)
- elif platform == 'ios':
+ def create_device(
+ platform: str,
+ app_url: str,
+ state_representation="aria",
+ download_directory="default",
+ session_id=None,
+ ):
+ if platform == "android":
+ return AndroidDevice(
+ app_package=app_url,
+ download_directory=download_directory,
+ session_id=session_id,
+ )
+ elif platform == "ios":
return IOSDevice(app_url)
- elif platform == 'web':
- raise NotImplementedError("Web device is not implemented in open source version check out https://revyl.ai")
+ elif platform == "web":
+ raise NotImplementedError(
+ "Web device is not implemented in open source version check out https://revyl.ai"
+ )
else:
- raise ValueError(
- "Invalid type. Expected one of: 'android', 'web'.")
+ raise ValueError("Invalid type. Expected one of: 'android', 'web'.")
diff --git a/mobileadapt/device/ios_device.py b/mobileadapt/device/ios_device.py
index 45818d0f..60c42101 100644
--- a/mobileadapt/device/ios_device.py
+++ b/mobileadapt/device/ios_device.py
@@ -2,7 +2,7 @@
class IOSDevice(Device):
- def __init__(self, app_start_url=''):
+ def __init__(self, app_start_url=""):
pass
def get_state(self):
diff --git a/mobileadapt/mobileadapt.py b/mobileadapt/mobileadapt.py
deleted file mode 100644
index 6e4942ab..00000000
--- a/mobileadapt/mobileadapt.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from device.device_factory import DeviceFactory
-
-'''
-This file defines the MobileAdapt class, which provides a high-level interface
-for interacting with mobile devices, and includes a factory function for creating instances.
-'''
-class MobileAdapt:
- def __init__(self, platform: str, app_url: str):
- """
- Initialize the MobileAdapt instance.
-
- Args:
- platform (str): The mobile platform (e.g., 'android' or 'ios').
- app_url (str): The URL or path to the mobile application.
- """
- self.device = DeviceFactory.create_device(platform, app_url)
-
- async def initialize(self):
- """
- Initialize the device by starting it.
-
- This method should be called before performing any other operations.
- """
- await self.device.start_device()
-
- async def get_state(self):
- """
- Retrieve the current state of the device.
-
- Returns:
- The current state representation of the device UI.
- """
- return await self.device.get_state()
-
- async def tap(self, x: int, y: int):
- """
- Perform a tap action on the device screen.
-
- Args:
- x (int): The x-coordinate of the tap location.
- y (int): The y-coordinate of the tap location.
- """
- await self.device.tap(x, y)
-
- async def input(self, x: int, y: int, text: str):
- """
- Input text at a specific location on the device screen.
-
- Args:
- x (int): The x-coordinate of the input location.
- y (int): The y-coordinate of the input location.
- text (str): The text to be input.
- """
- await self.device.input(x, y, text)
-
- async def swipe(self, direction: str):
- """
- Perform a swipe action on the device screen.
-
- Args:
- direction (str): The direction of the swipe (e.g., 'up', 'down', 'left', 'right').
- """
- await self.device.swipe(direction)
-
-def mobileadapt(platform: str, app_url: str = None):
- """
- Create and return a MobileAdapt instance.
-
- Args:
- platform (str): The mobile platform (e.g., 'android' or 'ios').
- app_url (str, optional): The URL or path to the mobile application.
-
- Returns:
- MobileAdapt: An instance of the MobileAdapt class.
- """
- return MobileAdapt(platform, app_url)
\ No newline at end of file
diff --git a/mobileadapt/utils/constants.py b/mobileadapt/utils/constants.py
index 62aac522..54d2f788 100644
--- a/mobileadapt/utils/constants.py
+++ b/mobileadapt/utils/constants.py
@@ -1,4 +1,3 @@
-
# Android Emulator Config
SCREEN_WIDTH = 1080
SCREEN_HEIGHT = 1920
diff --git a/pyproject.toml b/pyproject.toml
index dcf435ea..56be3cbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,4 +38,9 @@ opencv-python = "^4.10.0.84"
pytest = "^6.2"
[tool.poetry.urls]
-"Bug Tracker" = "https://github.com/RevylAI/Mobileadapt/issues"
\ No newline at end of file
+"Bug Tracker" = "https://github.com/RevylAI/Mobileadapt/issues"
+[tool.poetry.group.dev.dependencies]
+black = "^24.8.0"
+isort = "^5.13.2"
+mypy = "^1.11.2"
+
diff --git a/scripts/format.sh b/scripts/format.sh
new file mode 100755
index 00000000..7fa7c1e6
--- /dev/null
+++ b/scripts/format.sh
@@ -0,0 +1,11 @@
+#!/bin/sh
+cd "$(dirname "$0")" || exit 1
+cd ..
+
+
+printf "\nFormatting Python 🧹\n"
+poetry run black .
+
+printf "\nSorting imports 🧹\n"
+poetry run isort .
+