from __future__ import annotations
import asyncio
import traceback
import typing
import uuid
from functools import partial
import attrs
import lsprotocol.types as lsp
from esbonio import server
from esbonio.server import Uri
from esbonio.sphinx_agent import types
from .client import ClientState
from .config import SphinxConfig
from .config import register_structure_hooks
if typing.TYPE_CHECKING:
    # These names are only needed by annotations, so they are imported for the
    # benefit of static type checkers and never at runtime.
    from collections.abc import Callable

    from esbonio.server.features.project_manager import ProjectManager

    from .client import SphinxClient

    # Signature of the callable used to create new clients: it is called with
    # the manager and a ``SphinxConfig`` and returns a ``SphinxClient``.
    SphinxClientFactory = Callable[["SphinxManager", "SphinxConfig"], "SphinxClient"]
@attrs.define
class ClientCreatedNotification:
    """The payload of a ``sphinx/clientCreated`` notification"""

    id: str
    """The client's id"""

    scope: str
    """The scope at which the client was created."""

    config: SphinxConfig
    """The final configuration."""

    pid: int
    """The process id of the client process."""
@attrs.define
class AppCreatedNotification:
    """The payload of a ``sphinx/appCreated`` notification"""

    id: str
    """The client's id"""

    application: types.SphinxInfo
    """Details about the created application."""
@attrs.define
class ClientErroredNotification:
    """The payload of a ``sphinx/clientErrored`` notification"""

    id: str
    """The client's id"""

    error: str
    """Short description of the error."""

    detail: str
    """Detailed description of the error."""
@attrs.define
class ClientDestroyedNotification:
    """The payload of a ``sphinx/clientDestroyed`` notification."""

    id: str
    """The client's id"""
@attrs.define
class SphinxBuildTriggers:
    """Settings that decide which editor events cause a Sphinx build."""

    on_save: bool = True
    """Build whenever a document is saved."""

    on_change: bool | float = 2.0
    """Build after a document changes; a number is the delay to wait, ``False``
    turns the trigger off."""
@attrs.define
class ManagerConfig:
    """User facing settings consumed by the sphinx manager."""

    build_triggers: SphinxBuildTriggers = attrs.Factory(SphinxBuildTriggers)
    """Which events should result in a Sphinx build being triggered."""
@attrs.define
class RestartSphinxParams:
    """Parameters for the ``esbonio.sphinx.restart`` command.

    Identifies a single client by its ``id``.
    """

    id: str
    """The id of the sphinx client to restart"""
class SphinxManager(server.LanguageFeature):
    """Responsible for managing Sphinx application instances.

    Clients are tracked per configuration scope in :attr:`clients`. A client is
    created (or replaced) by :meth:`_create_or_replace_client`, which runs as a
    configuration subscription callback registered the first time
    :meth:`get_client` sees a uri from a new scope.
    """

    def __init__(
        self,
        client_factory: SphinxClientFactory,
        project_manager: ProjectManager,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.client_factory = client_factory
        """Used to create new Sphinx client instances."""

        self.project_manager = project_manager
        """The project manager instance to use."""

        self.clients: dict[str, SphinxClient | None] = {
            # Prevent any clients from being created in the global scope.
            "": None,
        }
        """Holds currently active Sphinx clients."""

        self.config: ManagerConfig = ManagerConfig()
        """The SphinxManager's configuration."""

        self._events = server.EventSource(self.logger)
        """The SphinxManager can emit events."""

        self._pending_builds: dict[str, asyncio.Task] = {}
        """Holds tasks that will trigger a build after a given delay if not cancelled."""

        self._progress_tokens: dict[str, str] = {}
        """Holds work done progress tokens."""

    def add_listener(self, event: str, handler):
        """Register ``handler`` to be called when the manager emits ``event``.

        The events triggered by this class are ``"build"`` (with the client and
        the build result) and ``"app-created"`` (with the client).
        """
        self._events.add_listener(event, handler)

    def initialize(self, params: lsp.InitializeParams):
        """Called once the initial handshake between client and server has finished."""
        register_structure_hooks(self.converter)
        self.configuration.subscribe(
            "esbonio.sphinx", ManagerConfig, self.update_configuration
        )

    def update_configuration(self, event: server.ConfigChangeEvent[ManagerConfig]):
        """Called when the user's configuration is updated."""
        self.config = event.value

    def _cancel_pending_build(self, client_id: str) -> None:
        """Cancel and forget the delayed build scheduled for ``client_id``, if any."""
        if (task := self._pending_builds.pop(client_id, None)) is not None:
            task.cancel()

    async def document_change(self, params: lsp.DidChangeTextDocumentParams):
        """Schedule a delayed build when a document changes.

        Does nothing if the uri cannot be parsed, there is no client for it or
        the ``on_change`` trigger is ``False``. Each change replaces any build
        that is still waiting for the same client, so the delay restarts.
        """
        if (uri := Uri.parse(params.text_document.uri)) is None:
            return

        client = await self.get_client(uri)
        if client is None:
            return

        if (delay := self.config.build_triggers.on_change) is False:
            return

        # Cancel any existing pending builds
        self._cancel_pending_build(client.id)

        self._pending_builds[client.id] = asyncio.create_task(
            self.trigger_build_after(
                uri,
                client.id,
                delay=max(float(delay), 1.0),  # Enforce a minimum 1s delay
            )
        )

    async def document_open(self, params: lsp.DidOpenTextDocumentParams):
        """Make sure a client exists for the project the opened document is in."""
        # Ensure that a Sphinx app instance is created the first time a document in a
        # given project is opened.
        if (uri := Uri.parse(params.text_document.uri)) is not None:
            await self.get_client(uri)

    async def document_save(self, params: lsp.DidSaveTextDocumentParams):
        """Trigger an immediate build when a document is saved.

        Does nothing if the uri cannot be parsed, there is no client for it or
        the ``on_save`` trigger is disabled.
        """
        if (uri := Uri.parse(params.text_document.uri)) is None:
            return

        client = await self.get_client(uri)
        if client is None:
            return

        if not self.config.build_triggers.on_save:
            return

        # Cancel any existing pending builds, this build supersedes them.
        self._cancel_pending_build(client.id)

        self.logger.debug("Build triggered on save")
        await self.trigger_build(uri)

    async def shutdown(self, params: None):
        """Called when the server is instructed to ``shutdown``."""
        # Delayed builds that have not fired yet are pointless once the clients
        # are stopped, cancel them rather than leaving the tasks pending.
        for pending in self._pending_builds.values():
            pending.cancel()
        self._pending_builds.clear()

        # Stop any existing clients.
        clients = [client for client in self.clients.values() if client]
        tasks = []
        for client in clients:
            self.logger.debug("Stopping SphinxClient: %s", client)
            tasks.append(asyncio.create_task(client.stop()))

        # ``return_exceptions=True`` so that one client failing to stop does not
        # prevent us from waiting on the remaining clients.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for client, result in zip(clients, results):
            if isinstance(result, BaseException):
                self.logger.error("Error stopping SphinxClient %s: %s", client, result)

    async def trigger_build_after(self, uri: Uri, app_id: str, delay: float):
        """Trigger a build for the given uri after the given delay.

        Parameters
        ----------
        uri
           The uri to trigger the build for

        app_id
           The key under which this task is stored in the pending builds

        delay
           The number of seconds to wait before building
        """
        await asyncio.sleep(delay)

        # Tolerate the entry having already been removed, a missing key should
        # not stop the build from running.
        self._pending_builds.pop(app_id, None)

        self.logger.debug("Build triggered after %ss delay", delay)
        await self.trigger_build(uri)

    async def trigger_build(self, uri: Uri):
        """Trigger a build for the relevant Sphinx application for the given uri.

        The build is skipped if there is no client, the client is not in the
        ``Running`` state or no project is known for the uri. Errors raised by
        the build are logged and shown to the user, they are not re-raised.
        """
        client = await self.get_client(uri)
        if client is None:
            return

        if client.state not in {ClientState.Running}:
            self.logger.debug("Skipping build, client is %s", client.state)
            return

        if (project := self.project_manager.get_project(uri)) is None:
            self.logger.debug("Skipping build, project is None")
            return

        # Pass through any unsaved content to the Sphinx agent.
        content_overrides: dict[str, str] = {}
        known_src_uris = await project.get_src_uris()

        for src_uri in known_src_uris:
            doc = self.server.workspace.get_text_document(str(src_uri))
            doc_version = doc.version or 0
            saved_version = getattr(doc, "saved_version", 0)

            # A version ahead of the last saved one means there are edits that
            # are not on disk yet.
            if saved_version < doc_version:
                content_overrides[str(src_uri)] = doc.source

        try:
            result = await client.build(content_overrides=content_overrides)

            # Notify listeners.
            self._events.trigger("build", client, result)
        except Exception as exc:
            # Keep a record in the log as well as telling the user.
            self.logger.error("Build failed: %s", exc, exc_info=True)
            self.server.window_show_message(
                lsp.ShowMessageParams(message=f"{exc}", type=lsp.MessageType.Error)
            )

    async def restart_client(self, client_id: str):
        """Restart the client with the given id

        An error is logged if there is no such client or the restart fails.
        """
        for client in self.clients.values():
            if client is None:
                continue

            if client.id != client_id:
                continue

            try:
                await client.restart()
            except Exception:
                self.logger.exception("Unable to restart sphinx client")

            break
        else:
            # Only reached when the loop did not find a matching client.
            self.logger.error("No client with id %r available to restart", client_id)

    async def get_client(self, uri: Uri) -> SphinxClient | None:
        """Given a uri, return the relevant sphinx client instance for it.

        Returns ``None`` when the uri's scope has no client. That includes the
        first call for a new scope, which only registers the configuration
        subscription that will go on to create the client.
        """
        scope = self.server.configuration.scope_for(uri)

        if scope not in self.clients:
            self.logger.debug("No client found, creating new subscription")
            self.server.configuration.subscribe(
                "esbonio.sphinx",
                SphinxConfig,
                partial(self._create_or_replace_client, uri),
                scope=uri,
            )

            # It's possible for this code path to be hit multiple times in quick
            # succession e.g. on a fresh server start with N .rst files already open,
            # creating the opportunity to accidentally spawn N duplicated clients!
            #
            # To prevent this, store a `None` at this scope, all being well it will be
            # replaced with the actual client instance when the
            # `_create_or_replace_client` callback runs.
            self.clients[scope] = None

            # The first few callers in a given scope will miss out, but that shouldn't
            # matter too much
            return None

        if (client := self.clients[scope]) is None:
            self.logger.debug("No applicable client for uri: %s", uri)
            return None

        return await client

    async def _create_or_replace_client(
        self, uri: Uri, event: server.ConfigChangeEvent[SphinxConfig]
    ):
        """Create or replace the sphinx client instance for the given config.

        Parameters
        ----------
        uri
           The uri for which the sphinx client was originally created for

        event
           The configuration change event
        """
        config = event.value

        # Do not try and create clients in the global scope
        if event.scope == "":
            return

        # If there was a previous client, stop it.
        if (previous_client := self.clients.pop(event.scope, None)) is not None:
            self.server.protocol.notify(
                "sphinx/clientDestroyed",
                ClientDestroyedNotification(id=previous_client.id),
            )
            self.server.run_task(previous_client.stop())

        resolved = config.resolve(uri, self.server.workspace, self.logger)
        if resolved is None:
            # Keep a placeholder so that `get_client` does not subscribe again.
            self.clients[event.scope] = None
            return

        self.clients[event.scope] = client = self.client_factory(self, resolved)
        self.logger.debug("Client created for scope %s", event.scope)

        client.add_listener("state-change", partial(self._on_state_change, event.scope))

        # Start the client
        await client

    def _on_state_change(
        self,
        scope: str,
        client: SphinxClient,
        old_state: ClientState,
        new_state: ClientState,
    ):
        """React to state changes in the client.

        Parameters
        ----------
        scope
           The scope the client was created for

        client
           The client whose state changed

        old_state
           The state the client was in

        new_state
           The state the client is in now
        """
        if new_state == ClientState.Starting:
            self.server.protocol.notify(
                "sphinx/clientCreated",
                ClientCreatedNotification(
                    id=client.id, scope=scope, config=client.config, pid=client.pid
                ),
            )

        if old_state == ClientState.Starting and new_state == ClientState.Running:
            if (sphinx_info := client.sphinx_info) is not None:
                self.project_manager.register_project(scope, client.db)
                self._events.trigger("app-created", client)

                self.server.protocol.notify(
                    "sphinx/appCreated",
                    AppCreatedNotification(id=client.id, application=sphinx_info),
                )

        # Progress reporting follows the client in and out of the Building state.
        if old_state == ClientState.Building:
            self.stop_progress(client)

        if new_state == ClientState.Building:
            self.server.run_task(self.start_progress(client))

        if new_state == ClientState.Errored:
            error = ""
            detail = ""

            if (exc := getattr(client, "exception", None)) is not None:
                error = f"{type(exc).__name__}: {exc}"
                detail = "".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                )

            self.server.window_show_message(
                lsp.ShowMessageParams(message=error, type=lsp.MessageType.Error)
            )

            self.server.protocol.notify(
                "sphinx/clientErrored",
                ClientErroredNotification(id=client.id, error=error, detail=detail),
            )

    async def start_progress(self, client: SphinxClient):
        """Start reporting work done progress for the given client."""

        # Make sure any existing progress tokens are cleaned up
        if (existing := self._progress_tokens.get(client.id)) is not None:
            self.logger.warning("Overwriting existing progress token: %r!", existing)
            self.stop_progress(client)

        token = str(uuid.uuid4())
        self.logger.debug("Starting progress: '%s'", token)

        try:
            await self.server.work_done_progress.create_async(token)
        except Exception as exc:
            self.logger.debug("Unable to create progress token: %s", exc)
            return

        # NOTE(review): if the build finishes while the token is being created,
        # `stop_progress` will already have run without a token to end, leaving
        # the progress begun below open until the next build replaces it.
        self._progress_tokens[client.id] = token
        self.server.work_done_progress.begin(
            token,
            lsp.WorkDoneProgressBegin(title="sphinx-build", cancellable=False),
        )

    def stop_progress(self, client: SphinxClient):
        """Stop reporting work done progress for the given client.

        Does nothing if there is no progress token stored for the client.
        """
        if (token := self._progress_tokens.pop(client.id, None)) is None:
            return

        self.logger.debug("Ending progress: %r", token)
        self.server.work_done_progress.end(
            token, lsp.WorkDoneProgressEnd(message="Finished")
        )

    def report_progress(self, client: SphinxClient, progress: types.ProgressParams):
        """Report progress done for the given client.

        Ignored unless the client is ``Running`` or ``Building`` and has a
        progress token.
        """
        if client.state not in {ClientState.Running, ClientState.Building}:
            return

        if (token := self._progress_tokens.get(client.id)) is None:
            return

        self.server.work_done_progress.report(
            token,
            lsp.WorkDoneProgressReport(
                message=progress.message,
                percentage=progress.percentage,
                cancellable=False,
            ),
        )