blob: 7dd8638c8319388d2c9674ae71d0089d825bfc44 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2021 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Simple layer to abstract the grpc implementation from users."""
import asyncio
import logging
import grpc
import os
from lab import managed_dut_pb2
from lab import dut_manager_pb2
from lab import dut_manager_pb2_grpc as grpc_stub
from chromiumos.longrunning import operations_pb2
from chromiumos.longrunning import operations_pb2_grpc
from moblab_common import moblabrpc_connector
class DutManagerRpcConnectorAsync(object):
    """Asyncio (grpc.aio) client wrapper for the DutManager grpc service.

    The channel and stubs are cached as class attributes, so all callers in
    the process share one channel until disconnect() clears the cache.
    """

    channel = None
    stub = None
    operations_stub = None

    @classmethod
    def connect(cls):
        """Connect to the grpc service, caching the connection.

        The target address is read from the DUT_MANAGER_GRPS_SERVICE_ADDRESS
        environment variable. The class-level cache is only updated once every
        stub has been created. A failed attempt therefore never leaves a
        half-initialised connection behind.

        Raises:
            MoblabRpcConnectorError: On a connection issue, including an unset
                or empty address environment variable.
        """
        if cls.channel and cls.stub and cls.operations_stub:
            return

        address = os.getenv("DUT_MANAGER_GRPS_SERVICE_ADDRESS")
        if not address:
            # grpc expects a string target. Fail with the documented error type
            # rather than letting grpc raise an unrelated exception on None.
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "No server found: DUT_MANAGER_GRPS_SERVICE_ADDRESS is not set"
            )

        channel = grpc.aio.insecure_channel(address)
        if not channel:
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "No server found %s" % address
            )
        stub = grpc_stub.DutManagerServiceStub(channel)
        if not stub:
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Unable to connect to DutManager service at %s" % address
            )
        operations_stub = operations_pb2_grpc.OperationsStub(channel)
        if not operations_stub:
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Unable to connect to Operations service at %s" % address
            )

        # Publish the connection only after everything above succeeded.
        cls.channel = channel
        cls.stub = stub
        cls.operations_stub = operations_stub

    @classmethod
    def disconnect(cls):
        """Invalidate the cached grpc server connection.

        The next connect() call builds a fresh channel and stubs.
        """
        cls.channel = None
        cls.stub = None
        cls.operations_stub = None

    @classmethod
    async def create_managed_dut(cls, ip):
        """Create and add managed DUT to the list of managed DUTs.

        Issues CreateManagedDut, then polls the returned long-running
        operation about once per second until it reports done.

        Args:
            ip: IP address of the DUT. It is also used as the display name.

        Returns:
            The `response` field of the finished operation.

        Raises:
            MoblabRpcConnectorError: If connecting or any rpc call failed.
        """
        cls.connect()
        request = dut_manager_pb2.CreateManagedDutRequest(
            name=managed_dut_pb2.NetworkIdentifier(ip_address=ip),
            display_name=ip,
        )
        try:
            operation = await cls.stub.CreateManagedDut(request)
            logging.debug("Operation returned: %s", operation)
            operation_request = operations_pb2.GetOperationRequest(
                name=operation.name
            )
            while not operation.done:
                # Wait before polling (not after), so no extra second is spent
                # sleeping once the operation has completed.
                await asyncio.sleep(1)
                operation = await cls.operations_stub.GetOperation(
                    operation_request
                )
                logging.debug("tick: Operation returned: %s", operation)
        except grpc.RpcError as e:
            cls.disconnect()  # Force reconnect on retry
            logging.exception("Error enrolling dut %s", ip)
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Error enrolling dut %s" % ip
            ) from e
        # NOTE(review): if the operation finished with an error result, this
        # presumably returns an empty response; confirm whether callers should
        # be told about operation failures explicitly.
        return operation.response

    @classmethod
    async def list_managed_duts(cls):
        """List all DUTs managed by DUT Manager.

        Returns:
            The `duts` field of the ListManagedDuts response (ManagedDut
            objects), or None if no response object was returned.

        Raises:
            MoblabRpcConnectorError: If connecting or the rpc call failed.
        """
        cls.connect()
        try:
            request = dut_manager_pb2.ListManagedDutsRequest()
            response = await cls.stub.ListManagedDuts(request)
            return response.duts if response else None
        except grpc.RpcError as e:
            cls.disconnect()  # Force reconnect on retry
            logging.exception("Error listing all duts")
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Error listing all duts"
            ) from e
class DutManagerRpcConnector(object):
    """Blocking (synchronous grpc) client wrapper for the DutManager service.

    The channel and stub are cached as class attributes, so all callers in
    the process share one channel until disconnect() clears the cache.
    """

    channel = None
    stub = None

    @classmethod
    def connect(cls):
        """Connect to the grpc service, caching the connection.

        The target address is read from the DUT_MANAGER_GRPS_SERVICE_ADDRESS
        environment variable. The class-level cache is only updated once the
        stub has been created. A failed attempt therefore never leaves a
        half-initialised connection behind.

        Raises:
            MoblabRpcConnectorError: On a connection issue, including an unset
                or empty address environment variable.
        """
        if cls.channel and cls.stub:
            return

        address = os.getenv("DUT_MANAGER_GRPS_SERVICE_ADDRESS")
        if not address:
            # grpc expects a string target. Fail with the documented error type
            # rather than letting grpc raise an unrelated exception on None.
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "No server found: DUT_MANAGER_GRPS_SERVICE_ADDRESS is not set"
            )

        channel = grpc.insecure_channel(address)
        if not channel:
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "No server found %s" % address
            )
        stub = grpc_stub.DutManagerServiceStub(channel)
        if not stub:
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Unable to connect to DutManager service at %s" % address
            )

        # Publish the connection only after everything above succeeded.
        cls.channel = channel
        cls.stub = stub

    @classmethod
    def disconnect(cls):
        """Invalidate the cached grpc server connection.

        The next connect() call builds a fresh channel and stub.
        """
        cls.channel = None
        cls.stub = None

    @classmethod
    def get_managed_dut(cls, dut_ip: str):
        """Get managed dut details.

        Args:
            dut_ip: IP address identifying the managed DUT.

        Returns:
            The `dut` field of the GetManagedDut response (a ManagedDut), or
            None if no response object was returned.

        Raises:
            MoblabRpcConnectorError: If connecting or the rpc call failed.
        """
        cls.connect()
        try:
            request = dut_manager_pb2.GetManagedDutRequest()
            request.name.ip_address = dut_ip
            response = cls.stub.GetManagedDut(request)
            return response.dut if response else None
        except grpc.RpcError as e:
            cls.disconnect()  # Force reconnect on retry
            logging.exception("Error reading the %s dut", dut_ip)
            # Format the message here. Passing (fmt, arg) as two exception
            # args would leave the placeholder unformatted.
            raise moblabrpc_connector.MoblabRpcConnectorError(
                "Error reading the %s dut" % dut_ip
            ) from e

    @classmethod
    def update_managed_dut(cls, dut, field_mask):
        """Update managed dut details.

        Args:
            dut: ManagedDut message carrying the new values. Its
                `name.ip_address` is used in error messages.
            field_mask: Update mask selecting which fields of `dut` to apply.

        Raises:
            MoblabRpcConnectorError: If connecting or the rpc call failed.
        """
        cls.connect()
        try:
            request = dut_manager_pb2.UpdateManagedDutRequest(
                dut=dut, update_mask=field_mask
            )
            cls.stub.UpdateManagedDut(request)
        except grpc.RpcError as e:
            cls.disconnect()  # Force reconnect on retry
            error_msg = "Error updating dut: {}".format(dut.name.ip_address)
            logging.exception(error_msg)
            raise moblabrpc_connector.MoblabRpcConnectorError(error_msg) from e

    @classmethod
    def delete_managed_dut(cls, dut_ip: str):
        """Delete dut from managed dut collection.

        Args:
            dut_ip: IP address identifying the managed DUT to delete.

        Raises:
            MoblabRpcConnectorError: If connecting or the rpc call failed.
        """
        cls.connect()
        try:
            request = dut_manager_pb2.DeleteManagedDutRequest(
                name=managed_dut_pb2.NetworkIdentifier(ip_address=dut_ip)
            )
            cls.stub.DeleteManagedDut(request)
        except grpc.RpcError as e:
            cls.disconnect()  # Force reconnect on retry
            error_msg = "Error deleting dut: {}".format(dut_ip)
            logging.exception(error_msg)
            raise moblabrpc_connector.MoblabRpcConnectorError(error_msg) from e