unrealxr/libunreal/mcu_driver.py

145 lines
4.9 KiB
Python

from tempfile import TemporaryDirectory
from dataclasses import dataclass
from os import path, environ
from typing import Callable
from loguru import logger
from shutil import which
from time import sleep
from enum import Enum
import subprocess
import threading
import socket
import struct
import signal
import atexit
import os
class MCUCommandTypes(Enum):
ROLL = 0
PITCH = 1
YAW = 2
GENERIC_MESSAGE = 3
BRIGHTNESS_UP = 4
BRIGHTNESS_DOWN = 5
@dataclass
class MCUCallbackWrapper:
OnRollUpdate: Callable[[float], None]
OnPitchUpdate: Callable[[float], None]
OnYawUpdate: Callable[[float], None]
OnTextMessageRecieved: Callable[[str], None]
OnBrightnessUp: Callable[[int], None]
OnBrightnessDown: Callable[[int], None]
vendor_to_driver_table: dict[str, str] = {
"MRG": "xreal_ar_driver",
}
def find_executable_path_from_driver_name(driver_name) -> str:
# First try the normal driver path
try:
driver_path = path.join("drivers", driver_name)
file = open(driver_path)
file.close()
return driver_path
except OSError:
# Then search the system path
driver_path = which(driver_name)
if driver_path == "":
raise OSError("Could not find driver executable in driver directory or in PATH")
return driver_path
def start_mcu_event_listener(driver_vendor: str, events: MCUCallbackWrapper):
driver_executable = find_executable_path_from_driver_name(vendor_to_driver_table[driver_vendor])
created_temp_dir = TemporaryDirectory()
sock_path = path.join(created_temp_dir.name, "mcu_socket")
def on_socket_event(sock: socket.socket):
while True:
message_type = sock.recv(1)
if message_type[0] == MCUCommandTypes.ROLL.value:
roll_data = sock.recv(4)
roll_value = struct.unpack("!f", roll_data)[0]
if not isinstance(roll_value, float):
logger.warning("Expected roll value to be a float but got other type instead")
continue
events.OnRollUpdate(roll_value)
elif message_type[0] == MCUCommandTypes.PITCH.value:
pitch_data = sock.recv(4)
pitch_value = struct.unpack("!f", pitch_data)[0]
if not isinstance(pitch_value, float):
logger.warning("Expected pitch value to be a float but got other type instead")
continue
events.OnPitchUpdate(pitch_value)
elif message_type[0] == MCUCommandTypes.YAW.value:
yaw_data = sock.recv(4)
yaw_value = struct.unpack("!f", yaw_data)[0]
if not isinstance(yaw_value, float):
logger.warning("Expected yaw value to be a float but got other type instead")
continue
events.OnYawUpdate(yaw_value)
elif message_type[0] == MCUCommandTypes.GENERIC_MESSAGE.value:
length_bytes = sock.recv(4)
msg_len = struct.unpack("!I", length_bytes)[0]
msg_bytes = sock.recv(msg_len)
msg = msg_bytes.decode("utf-8", errors="replace")
events.OnTextMessageRecieved(msg)
elif message_type[0] == MCUCommandTypes.BRIGHTNESS_UP.value:
brightness_bytes = sock.recv(1)
events.OnBrightnessUp(int.from_bytes(brightness_bytes, byteorder='big'))
elif message_type[0] == MCUCommandTypes.BRIGHTNESS_DOWN.value:
brightness_bytes = sock.recv(1)
events.OnBrightnessDown(int.from_bytes(brightness_bytes, byteorder='big'))
else:
logger.warning(f"Unknown message type recieved: {str(message_type[0])}")
def start_socket_handout():
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(sock_path)
server.listen(1)
sock: socket.socket | None = None
while True:
sock, _ = server.accept()
threaded_connection_processing = threading.Thread(target=on_socket_event, args=(sock,), daemon=True)
threaded_connection_processing.start()
created_temp_dir.cleanup()
def spawn_child_process():
custom_env = environ.copy()
custom_env["UNREALXR_NREAL_DRIVER_SOCK"] = sock_path
process = subprocess.Popen([driver_executable], env=custom_env)
def kill_child():
if process.pid is None:
pass
else:
os.kill(process.pid, signal.SIGTERM)
atexit.register(kill_child)
threaded_socket_handout = threading.Thread(target=start_socket_handout, daemon=True)
threaded_socket_handout.start()
sleep(0.01) # Give the socket server time to initialize
threaded_child_process = threading.Thread(target=spawn_child_process, daemon=True)
threaded_child_process.start()