blob: bbdf5950365abbed854553e66937eb8b9b18cae2 [file] [log] [blame]
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import abc
import atexit
import contextlib
import logging
from typing import TYPE_CHECKING, Final, Iterator, Self, Tuple
from typing_extensions import override
from crossbench import exception
from crossbench.parse import NumberParser
if TYPE_CHECKING:
from crossbench.plt.base import Platform
class PortForwardException(Exception):
  """Error raised when a port forwarding request cannot be honored."""
class PortScope:
  """Tracks the ports that were forwarded while one scope is active.

  The PortManager performs the actual forwarding and owns all scopes. A
  PortScope only records which ports were forwarded or reverse-forwarded
  through it, so exactly those ports can be stopped again when the scope is
  closed. Scopes are chained via `parent_scope`; a port tracked by this scope
  or by any of its ancestors counts as used and cannot be forwarded twice.
  """

  def __init__(self,
               manager: PortManager,
               parent_scope: Self | None = None) -> None:
    assert parent_scope is not self
    self._manager: Final[PortManager] = manager
    self._parent_scope: Final[Self | None] = parent_scope
    # Maps local port -> remote port (or remote DevTools identifier).
    self._forwarded_ports: dict[int, int | str] = {}
    # Maps remote port -> local port.
    self._reverse_forwarded_ports: dict[int, int] = {}

  @property
  def parent_scope(self) -> Self | None:
    """The enclosing scope, or None for a scope without a parent."""
    return self._parent_scope

  @property
  def platform(self) -> Platform:
    """The platform of the manager this scope belongs to."""
    return self._manager.platform

  @property
  def is_nested(self) -> bool:
    """True if this scope has a parent scope."""
    return self._parent_scope is not None

  @property
  def is_empty(self) -> bool:
    """True if this scope itself tracks no forwarded ports of either kind."""
    return not (self._forwarded_ports or self._reverse_forwarded_ports)

  @property
  def forwarded_ports(self) -> Tuple[int, ...]:
    """Snapshot of the local ports forwarded through this scope."""
    return tuple(self._forwarded_ports)

  @property
  def reverse_forwarded_ports(self) -> Tuple[int, ...]:
    """Snapshot of the remote ports reverse-forwarded through this scope."""
    return tuple(self._reverse_forwarded_ports)

  @contextlib.contextmanager
  def nested(self) -> Iterator[PortScope]:
    """Opens a nested scope through the manager and yields it."""
    with self._manager.nested() as inner_scope:
      yield inner_scope

  def is_forwarded_port_used(self, local_port: int) -> bool:
    """True if this scope or any ancestor tracks `local_port`."""
    return self.lookup_forwarded_port(local_port) is not None

  def is_reverse_forwarded_port_used(self, remote_port: int) -> bool:
    """True if this scope or any ancestor tracks `remote_port`."""
    return self.lookup_reverse_forwarded_port(remote_port) is not None

  def is_own_forwarded_port_used(self, local_port: int) -> bool:
    """True if this scope itself (ignoring ancestors) tracks `local_port`."""
    return local_port in self._forwarded_ports

  def is_own_reverse_forwarded_port_used(self, remote_port: int) -> bool:
    """True if this scope itself (ignoring ancestors) tracks `remote_port`."""
    return remote_port in self._reverse_forwarded_ports

  def lookup_forwarded_port(self, local_port: int) -> PortScope | None:
    """Returns the closest scope in the chain tracking `local_port`, if any."""
    owners = (scope for scope in self
              if scope.is_own_forwarded_port_used(local_port))
    return next(owners, None)

  def lookup_reverse_forwarded_port(self, remote_port: int) -> PortScope | None:
    """Returns the closest scope in the chain tracking `remote_port`, if any."""
    owners = (scope for scope in self
              if scope.is_own_reverse_forwarded_port_used(remote_port))
    return next(owners, None)

  def __iter__(self) -> Iterator[Self]:
    """Iterates over this scope first, followed by each ancestor in turn."""
    scope: Self | None = self
    while scope:
      yield scope
      scope = scope.parent_scope

  def _parse_unused_local_port(self, local_port: int) -> int:
    """Validates `local_port` and ensures it is not tracked in the chain.

    Raises:
      PortForwardException: if this scope or an ancestor already forwards
        the port.
    """
    port = NumberParser.port_number_zero(local_port, "local_port")
    if self.is_forwarded_port_used(port):
      raise PortForwardException(f"Cannot forward local port {port} twice, "
                                 "it is already forwarded.")
    return port

  def forward(self, local_port: int, remote_port: int) -> int:
    """Forwards `remote_port` to `local_port` and tracks it in this scope.

    Returns:
      The local port reported back by the manager.

    Raises:
      PortForwardException: if the local port is already forwarded.
    """
    requested_port = self._parse_unused_local_port(local_port)
    actual_port = self._manager.forward(requested_port, remote_port)
    self._forwarded_ports[actual_port] = remote_port
    return actual_port

  def forward_devtools(self, local_port: int, remote_identifier: str) -> int:
    """Forwards a DevTools debugging port from a remote target to a local port.

    Args:
      local_port: The local port number to forward to. If 0, a free
        port will be chosen by the system.
      remote_identifier: A string identifying the remote DevTools socket or
        service. For Android, this is typically a localabstract socket name
        like "chrome_devtools_remote". For other platforms, it might be a
        remote port number or other service identifier.

    Returns:
      The local port number that was actually used for forwarding.

    Raises:
      PortForwardException: if the local port is already forwarded.
    """
    requested_port = self._parse_unused_local_port(local_port)
    actual_port = self._manager.forward_devtools(requested_port,
                                                 remote_identifier)
    self._forwarded_ports[actual_port] = remote_identifier
    return actual_port

  def stop_forward(self, local_port: int) -> None:
    """Stops forwarding a local port that is tracked by this very scope.

    Raises:
      PortForwardException: if this scope does not track `local_port`.
    """
    if not self.is_own_forwarded_port_used(local_port):
      raise PortForwardException(
          f"Cannot stop forwarding local port {local_port}, "
          "it was never forwarded.")
    # The entry is dropped before the manager is called, so the scope no
    # longer tracks the port even when the manager raises.
    del self._forwarded_ports[local_port]
    self._manager.stop_forward(local_port)

  def reverse_forward(self, remote_port: int, local_port: int) -> int:
    """Reverse-forwards `remote_port` to `local_port` and tracks it here.

    Returns:
      The remote port reported back by the manager.

    Raises:
      PortForwardException: if this scope or an ancestor already
        reverse-forwards the remote port.
    """
    requested_port = NumberParser.port_number_zero(remote_port, "remote_port")
    if self.is_reverse_forwarded_port_used(requested_port):
      raise PortForwardException(
          f"Cannot reverse forward remote port {requested_port} twice, "
          "it is already forwarded.")
    actual_port = self._manager.reverse_forward(requested_port, local_port)
    self._reverse_forwarded_ports[actual_port] = local_port
    return actual_port

  def stop_reverse_forward(self, remote_port: int) -> None:
    """Stops reverse-forwarding a remote port tracked by this very scope.

    Raises:
      PortForwardException: if this scope does not track `remote_port`.
    """
    if not self.is_own_reverse_forwarded_port_used(remote_port):
      raise PortForwardException(
          f"Cannot stop reverse forwarding remote port {remote_port}, "
          "it was never forwarded.")
    # Same ordering as stop_forward: untrack first, then ask the manager.
    del self._reverse_forwarded_ports[remote_port]
    self._manager.stop_reverse_forward(remote_port)
class PortManager(abc.ABC):
  """Keeps track of opened forwarded and reverse-forwarded ports.

  All ports are closed when the PortManager is stopped, either explicitly
  via stop() or through the atexit hook registered on construction.
  To limit the risk of leaking ports you can use the .nested() scope.

  Global PortManager (one instance per platform)
    - with PortScope 1:
        - forward port 1
        - forward port 2
        yield
        - disable port 1
        - disable port 2
    - with PortScope 2:
        ...

  Subclasses implement the actual (reverse) forwarding operations.
  """

  def __init__(self, platform: Platform, throw: bool = False) -> None:
    """Creates and immediately activates the manager.

    Args:
      platform: The platform whose ports are managed.
      throw: Passed on to the exception.Annotator instances that collect
        errors while ports are being stopped.
    """
    self._platform: Platform = platform
    self._throw: bool = throw
    self._is_active: bool = False
    # Keeps track of scoped ports.
    self._port_scope: PortScope = PortScope(self, None)
    self._start()

  def _start(self) -> None:
    """Marks the manager as active and registers stop() to run at exit."""
    assert not self._is_active, f"Cannot activate {self} twice"
    assert self._port_scope.is_empty, "Expected empty port scope"
    self._is_active = True
    atexit.register(self.stop)

  @property
  def scope(self) -> PortScope:
    """The currently active (innermost) port scope."""
    return self._port_scope

  @property
  def platform(self) -> Platform:
    """The platform this manager was created for."""
    return self._platform

  @contextlib.contextmanager
  def nested(self) -> Iterator[PortScope]:
    """Open a nested port scope, all forwarded ports that were opened
    during this scope will be closed when leaving the scope. """
    old_scope = self._port_scope
    self._port_scope = PortScope(self, self._port_scope)
    try:
      yield self._port_scope
    finally:
      try:
        self._stop_current_scoped_ports()
      finally:
        # Always restore the previous scope, even if stopping ports failed.
        self._port_scope = old_scope

  def assert_is_active(self) -> None:
    """Raises PortForwardException unless the manager is active."""
    if not self._is_active:
      raise PortForwardException("Need active PortManager")

  @property
  def is_empty(self) -> bool:
    """True if the current scope tracks no ports and is not nested."""
    return self._port_scope.is_empty and not self.has_nested_scopes

  @property
  def has_nested_scopes(self) -> bool:
    """True if the current scope has a parent scope."""
    return self._port_scope.is_nested

  def stop(self) -> None:
    """Stops all tracked port forwarding and deactivates the manager.

    Raises:
      PortForwardException: if the manager is not active.
    """
    self.assert_is_active()
    atexit.unregister(self.stop)
    self._stop_all()
    self._is_active = False

  def _stop_all(self) -> None:
    """Stops the ports of the current scope and of all its ancestors."""
    if self.has_nested_scopes:
      logging.error("Closing PortManager with open nested port scopes")
    exceptions = exception.Annotator(self._throw)
    # Iterating a PortScope yields the scope itself and then its parents.
    for port_scope in self._port_scope:
      with exceptions.capture("Stopping port forwarding"):
        self._stop_scoped_ports(port_scope, exceptions)
    exceptions.assert_success("Could not stop all port forwarding")

  def _stop_current_scoped_ports(self) -> None:
    """Stops only the ports tracked by the current scope."""
    exceptions = exception.Annotator(self._throw)
    self._stop_scoped_ports(self._port_scope, exceptions)
    exceptions.assert_success("Could not stop all port forwarding")

  def _stop_scoped_ports(self, port_scope: PortScope,
                         exceptions: exception.Annotator) -> None:
    """Stops every forwarded and reverse-forwarded port of `port_scope`.

    Args:
      port_scope: The scope whose tracked ports are stopped.
      exceptions: Annotator used to capture errors for individual ports.
    """
    for local_port in port_scope.forwarded_ports:
      with exceptions.capture(f"Stopping forwarding {local_port}"):
        port_scope.stop_forward(local_port)
    for remote_port in port_scope.reverse_forwarded_ports:
      with exceptions.capture(f"Stopping reverse forwarding {remote_port}"):
        port_scope.stop_reverse_forward(remote_port)
    # Verify the scope that was just torn down. Checking self._port_scope
    # here would only ever re-check the innermost scope when _stop_all walks
    # up to the parent scopes.
    assert port_scope.is_empty, "Expected empty PortScope"

  @abc.abstractmethod
  def forward(self, local_port: int, remote_port: int) -> int:
    """Forwards `remote_port` to `local_port` and returns the local port."""

  def forward_devtools(self, local_port: int, remote_identifier: str) -> int:
    """Forwards a remote DevTools target to a local port.

    The base implementation does not support this and always raises
    NotImplementedError.
    """
    del local_port, remote_identifier
    raise NotImplementedError(
        f"forward_devtools_port not implemented for {self}")

  @abc.abstractmethod
  def stop_forward(self, local_port: int) -> None:
    """Stops forwarding to `local_port`."""

  @abc.abstractmethod
  def reverse_forward(self, remote_port: int, local_port: int) -> int:
    """Reverse-forwards `remote_port` to `local_port`, returns the remote port.
    """

  @abc.abstractmethod
  def stop_reverse_forward(self, remote_port: int) -> None:
    """Stops reverse forwarding of `remote_port`."""
class LocalPortManager(PortManager):
  """PortManager for local platforms, where nothing is really forwarded.

  Forwarding requests are only accepted when the local and remote port are
  identical; the validated port is handed back unchanged. Stopping a
  forwarding is a no-op.
  """

  def __init__(self, platform: Platform, throw: bool = False) -> None:
    super().__init__(platform, throw)
    self.platform.assert_is_local()

  @override
  def forward(self, local_port: int, remote_port: int) -> int:
    """ Forwards a device remote_port to a local port.

    Raises:
      ValueError: if `local_port` and `remote_port` differ.
    """
    if local_port != remote_port:
      raise ValueError("Cannot forward a remote port on a local platform.")
    validated_port = NumberParser.port_number(local_port, "local_port")
    self.platform.assert_is_local()
    return validated_port

  @override
  def stop_forward(self, local_port: int) -> None:
    """Does nothing, no forwarding was set up for a local platform."""
    del local_port

  @override
  def reverse_forward(self, remote_port: int, local_port: int) -> int:
    """Validates and returns `remote_port`.

    Raises:
      ValueError: if `local_port` and `remote_port` differ.
    """
    if local_port != remote_port:
      raise ValueError("Cannot forward a remote port on a local platform.")
    return NumberParser.port_number(remote_port, "remote_port")

  @override
  def stop_reverse_forward(self, remote_port: int) -> None:
    """Does nothing, no forwarding was set up for a local platform."""
    del remote_port