Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 46 additions & 67 deletions src/qasync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,84 +22,62 @@
import time
from concurrent.futures import Future
from queue import Queue
from typing import TYPE_CHECKING, Literal, Tuple, cast, get_args

logger = logging.getLogger(__name__)

QtModule = None

# If QT_API env variable is given, use that or fail trying
qtapi_env = os.getenv("QT_API", "").strip().lower()
if qtapi_env:
env_to_mod_map = {
"pyqt5": "PyQt5",
"pyqt6": "PyQt6",
"pyqt": "PyQt6",
"pyside6": "PySide6",
"pyside2": "PySide2",
"pyside": "PySide6",
}
if qtapi_env in env_to_mod_map:
QtModuleName = env_to_mod_map[qtapi_env]
else:
raise ImportError(
"QT_API environment variable set ({}) but not one of [{}].".format(
qtapi_env, ", ".join(env_to_mod_map.keys())
)
)

logger.info("Forcing use of {} as Qt Implementation".format(QtModuleName))
QtModule = importlib.import_module(QtModuleName)
# runtime preference order is the same as this literal
QtFlavor = Literal["PyQt6", "PyQt5", "PySide6", "PySide2"]
QT_ALL = cast(Tuple[QtFlavor, ...], get_args(QtFlavor))

# If a Qt lib is already imported, use that
if not QtModule:
for QtModuleName in ("PyQt5", "PyQt6", "PySide2", "PySide6"):
if QtModuleName in sys.modules:
QtModule = sys.modules[QtModuleName]
break

# Try importing qt libs
if not QtModule:
for QtModuleName in ("PyQt5", "PyQt6", "PySide2", "PySide6"):
def _get_qt_flavor() -> QtFlavor:
env = os.getenv("QT_API", "").strip().lower()
# prioritize env var
if env:
lookup = {name.lower(): name for name in QT_ALL}
try:
QtModule = importlib.import_module(QtModuleName)
name = lookup[env]
except KeyError as err:
raise ImportError(
f"QT_API={env!r} is not one of {', '.join(QT_ALL)}"
) from err
logger.info("Forcing use of %s as Qt implementation", name)
return cast(QtFlavor, name)
# if already imported, use it
for name in QT_ALL:
if name in sys.modules:
return cast(QtFlavor, name)
# use the first available on system
for name in QT_ALL:
try:
importlib.import_module(name)
return cast(QtFlavor, name)
except ImportError:
continue
else:
break

if not QtModule:
raise ImportError("No Qt implementations found")

QtCore = importlib.import_module(QtModuleName + ".QtCore", package=QtModuleName)
QtGui = importlib.import_module(QtModuleName + ".QtGui", package=QtModuleName)

if QtModuleName == "PyQt5":
from PyQt5 import QtWidgets
from PyQt5.QtCore import pyqtSlot as Slot

QApplication = QtWidgets.QApplication
AllEvents = QtCore.QEventLoop.ProcessEventsFlags(0x00)

elif QtModuleName == "PyQt6":
from PyQt6 import QtWidgets
from PyQt6.QtCore import pyqtSlot as Slot
if TYPE_CHECKING:
from PySide6 import QtCore, QtWidgets
from PySide6.QtCore import Slot

QApplication = QtWidgets.QApplication
AllEvents = QtCore.QEventLoop.ProcessEventsFlag(0x00)

elif QtModuleName == "PySide2":
from PySide2 import QtWidgets
from PySide2.QtCore import Slot

else:
qt_flavor = _get_qt_flavor()
QtCore = importlib.import_module(f"{qt_flavor}.QtCore")
QtWidgets = importlib.import_module(f"{qt_flavor}.QtWidgets")
QApplication = QtWidgets.QApplication
AllEvents = QtCore.QEventLoop.ProcessEventsFlags(0x00)

elif QtModuleName == "PySide6":
from PySide6 import QtWidgets
from PySide6.QtCore import Slot
# PyQt uses pyqtSlot, PySide uses Slot
Slot = getattr(QtCore, "pyqtSlot", None) or getattr(QtCore, "Slot", None)

QApplication = QtWidgets.QApplication
AllEvents = QtCore.QEventLoop.ProcessEventsFlags(0x00)
# PyQt6 uses ProcessEventsFlags, others use ProcessEventsFlag
Flags = getattr(QtCore.QEventLoop, "ProcessEventsFlags", None) or getattr(
QtCore.QEventLoop, "ProcessEventsFlag"
)
AllEvents = Flags(0x00)

from ._common import with_logger # noqa

Expand Down Expand Up @@ -825,19 +803,20 @@
try:
cls._logger.error(*args, **kwds)
except: # noqa E722
sys.stderr.write("{!r}, {!r}\n".format(args, kwds))

Check warning on line 806 in src/qasync/__init__.py

View workflow job for this annotation

GitHub Actions / collect coverage

Missing coverage

Missing coverage on line 806


from ._unix import _SelectorEventLoop # noqa

QSelectorEventLoop = type("QSelectorEventLoop", (_QEventLoop, _SelectorEventLoop), {})
if sys.platform == "win32":
from ._windows import _ProactorEventLoop # noqa: F401

if os.name == "nt":
from ._windows import _ProactorEventLoop
class QIOCPEventLoop(_QEventLoop, _ProactorEventLoop): ...

QIOCPEventLoop = type("QIOCPEventLoop", (_QEventLoop, _ProactorEventLoop), {})
QEventLoop = QIOCPEventLoop
else:
from ._unix import _SelectorEventLoop # noqa: F401

class QSelectorEventLoop(_QEventLoop, _SelectorEventLoop): ...

QEventLoop = QSelectorEventLoop


Expand Down
82 changes: 82 additions & 0 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
BSD License
"""

import importlib
import sys
import types

import pytest
from pytest import MonkeyPatch

from qasync import QT_ALL, _get_qt_flavor


def _purge_qt(mp: MonkeyPatch):
"""Ensure no Qt modules are loaded."""
for name in QT_ALL:
mp.delitem(sys.modules, name, raising=False)


def _stub_import(mp: MonkeyPatch, available=()):
"""Patch importlib.import_module to only 'exist' for certain modules."""

def fake_import(name):
if name in available:
return types.ModuleType(name)
raise ImportError

mp.setattr(importlib, "import_module", fake_import)


def test_env_exact():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp)
mp.setenv("QT_API", "PySide6")
assert _get_qt_flavor() == "PySide6"


def test_env_invalid_raises():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp)
mp.setenv("QT_API", "QT")
with pytest.raises(ImportError):
_get_qt_flavor()


def test_already_imported_precedence():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp)
mp.delenv("QT_API", raising=False)
mp.setitem(sys.modules, "PySide2", types.ModuleType("PySide2"))
mp.setitem(sys.modules, "PyQt5", types.ModuleType("PyQt5"))
assert _get_qt_flavor() == next(n for n in QT_ALL if n in ("PyQt5", "PySide2"))


def test_first_available_import():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp, available=("PySide6",))
mp.delenv("QT_API", raising=False)
assert _get_qt_flavor() == "PySide6"


def test_none_available_raises():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp)
mp.delenv("QT_API", raising=False)
with pytest.raises(ImportError):
_get_qt_flavor()


def test_env_overrides_imported():
with MonkeyPatch.context() as mp:
_purge_qt(mp)
_stub_import(mp)
mp.setitem(sys.modules, "PyQt6", types.ModuleType("PyQt6"))
mp.setenv("QT_API", "PySide2")
assert _get_qt_flavor() == "PySide2"
Loading