Skip to content

Instantly share code, notes, and snippets.

@z3z1ma
Last active January 25, 2025 22:10
Show Gist options
  • Save z3z1ma/4f4677f8c1f9f8edb14b334ecd5a7747 to your computer and use it in GitHub Desktop.
A simple language server implementation for SQLMesh using PyGLS
#!/usr/bin/env python
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
import asyncio
import gc
import io
import logging
import re
import typing as t
import weakref
from collections import defaultdict
from contextlib import suppress
from functools import lru_cache
from itertools import cycle
from pathlib import Path
import sqlmesh
from lsprotocol import types
from pygls.server import LanguageServer
from pygls.workspace import TextDocument
from sqlglot.errors import ParseError
from sqlmesh.core.dialect import format_model_expressions, parse
from sqlmesh.core.lineage import column_description
from sqlmesh.core.model import SqlModel
from sqlmesh.utils import type_is_known
# Module-level logger for this language server.
logger = logging.getLogger(__name__)
# Shared mutable state for the whole server process; all handlers read/write these.
WORKSPACE_DIAGNOSTICS: t.Dict[str, t.Tuple[t.Optional[int], t.List[types.Diagnostic]]] = {}
"""A mapping of document URIs to diagnostics."""
CONTEXTS: t.Dict[str, sqlmesh.Context] = {}
"""A mapping of workspace paths to SQLMesh contexts."""
PATHS_TO_MODELS: t.Dict[str, t.Tuple[sqlmesh.Context, sqlmesh.Model]] = {}
"""A mapping of file paths to SQLMesh (context, model) tuples."""
C_MUTEX = defaultdict(asyncio.Lock)
"""A locking mechanism for ensuring that context mutation is thread-safe."""
# NOTE(review): asyncio.get_event_loop() is deprecated when no loop is running
# (Python 3.10+) — consider asyncio.new_event_loop(); verify pygls compatibility.
loop = asyncio.get_event_loop()
server = LanguageServer("sqlmesh-lsp", "v0.1.0", loop=loop)
async def refresh_context_loop(context: sqlmesh.Context) -> None:
    """Poll the SQLMesh context every 10 seconds and reload it when needed.

    SQLMesh already ensures that the context is only refreshed when necessary so this
    is efficient and safe to do, even if the context is large. Mtime checks are used.
    """
    # Cycle 0..9 so a full garbage collection runs roughly once per 10 iterations,
    # reclaiming objects dropped by repeated context reloads.
    gc_iter = cycle(list(range(10)))
    while True:
        await asyncio.sleep(10.0)
        # NOTE(review): relies on the private `_loader` attribute — verify it still
        # exists in the installed sqlmesh version.
        if context._loader.reload_needed():
            async with C_MUTEX[context.path]:
                # Reload in a worker thread so the event loop stays responsive.
                await asyncio.to_thread(context.load)
                # Re-point path lookups at the freshly loaded models; weakref
                # proxies avoid pinning stale model objects across reloads.
                PATHS_TO_MODELS.update(
                    {str(model._path.resolve()): (context, weakref.proxy(model)) for model in context.models.values()}
                )
        if next(gc_iter) == 0:
            gc.collect()
# NOTE(review): grows without bound over a session; acceptable for an
# editor-lifetime process, but entries are never retried after a failed load.
_CACHE: t.Set[str] = set()
"""A cache of URIs for which we have already ensured a context exists."""
async def ensure_context_for_document(document: TextDocument) -> TextDocument:
    """Make sure a SQLMesh context is loaded for *document*, if one applies.

    Walks upward from the document toward the filesystem root, first looking for
    an ancestor that already owns a context, then for a SQLMesh config file from
    which a fresh context can be booted. Results are memoized per document URI.
    """
    if document.uri in _CACHE:
        return document
    _CACHE.add(document.uri)
    doc_path = Path(document.path).resolve()
    if doc_path.suffix not in (".sql", ".py"):
        return document
    # Fast path: the path itself or some ancestor already has a context.
    probe = doc_path
    while probe.parents:
        if str(probe) in CONTEXTS:
            return document
        probe = probe.parent
    # Slow path: hunt for a config file and boot a context from that directory.
    probe = doc_path
    loaded = False
    while probe.parents and not loaded:
        for candidate in (probe / "config.py", probe / "config.yml", probe / "config.yaml"):
            if candidate.exists():
                # Best-effort: a broken project config simply means we keep
                # walking up instead of crashing the editor session.
                with suppress(Exception):
                    handle = sqlmesh.Context(paths=probe)
                    loop.create_task(refresh_context_loop(handle))
                    CONTEXTS[str(probe)] = handle
                    PATHS_TO_MODELS.update(
                        {
                            str(model._path.resolve()): (handle, weakref.proxy(model))
                            for model in handle.models.values()
                        }
                    )
                    server.show_message(f"Context loaded for: {probe}")
                    loaded = True
                    break
        probe = probe.parent
    return document
@server.feature(types.TEXT_DOCUMENT_COMPLETION)
async def completions(ls: LanguageServer, params: types.CompletionParams):
    """Offer column completions sourced from the model's upstream dependencies."""
    document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return types.CompletionList(is_incomplete=False, items=[])
    items: t.List[types.CompletionItem] = []
    for dep in model.depends_on:
        upstream = context.models[dep]
        if not upstream.columns_to_types:
            continue
        for column, column_type in upstream.columns_to_types.items():
            # Attach the upstream model name plus its column docs, if any.
            description = column_description(context, dep, column) or "No description available"
            items.append(
                types.CompletionItem(
                    label=column,
                    label_details=types.CompletionItemLabelDetails(detail=column_type.sql()),
                    documentation=f"Source: {dep}\n\n" + description,
                    kind=types.CompletionItemKind.Field,
                )
            )
    return types.CompletionList(is_incomplete=False, items=items)
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams):
    """Reformat the entire document via SQLMesh's format_model_expressions."""
    document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return []
    default_dialect = context.default_dialect
    dialect = model.dialect if model and model.is_sql else default_dialect
    try:
        expressions = parse(document.source, default_dialect=dialect)
    except Exception:
        # Unparseable source: return no edits rather than mangling the file.
        return []
    try:
        formatted = format_model_expressions(expressions, dialect, **context.config.format.generator_options)
        if context.config.format.append_newline:
            formatted += "\n"
    except Exception as e:
        ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
        return []
    # Replace the whole document: the end position sits one line past the last,
    # which LSP clients clamp to the end of the document.
    whole_document = types.Range(
        types.Position(0, 0),
        types.Position(len(document.lines), len(document.lines[-1])),
    )
    return [types.TextEdit(range=whole_document, new_text=formatted)]
# Zero-width range at the very start of a document; used for file-level diagnostics.
_top_of_file = types.Range(start=types.Position(line=0, character=0), end=types.Position(line=0, character=0))
# re.compile wrapped in an LRU cache so hot patterns are compiled once per (pattern, flags) pair.
_cached_re_compile = t.cast(t.Callable[[str, re.RegexFlag], re.Pattern[str]], lru_cache(maxsize=1024)(re.compile))
def _iter_match_ranges_in_projection(term: str, source: str) -> t.Iterator[types.Range]:
    """Yield LSP ranges for each occurrence of *term* inside a SELECT projection.

    Scans every ``SELECT ... FROM`` projection in *source* with a regex (a cheap
    approximation of real SQL parsing) and converts the character offsets of each
    match of *term* into zero-based line/character positions.

    Args:
        term: Identifier to search for; matched case-insensitively, allowing
            optional surrounding double quotes or backticks.
        source: Full SQL text to search.
    """
    col_patt = _cached_re_compile(rf'\b["`]?({term})["`]?,?', re.IGNORECASE)
    projection_patt = _cached_re_compile(r"SELECT\s+(.*)\s+FROM", re.DOTALL | re.IGNORECASE)
    for p_match in projection_patt.finditer(source):
        if not p_match.group(1):
            continue
        proj_start, proj_end = p_match.span(1)
        proj_substr = source[proj_start:proj_end]
        for c_match in col_patt.finditer(proj_substr):
            if not c_match.group(1):
                continue
            col_start, col_end = c_match.span(1)
            start, end = proj_start + col_start, proj_start + col_end
            # Convert absolute offsets into 0-based (line, character) positions.
            # The match never spans a newline, so both offsets share one line.
            line = source.count("\n", 0, start)
            char_s = start - source.rfind("\n", 0, start) - 1
            # BUG FIX: the end column previously omitted the "- 1" used for the
            # start column. LSP range ends are exclusive, so the end offset needs
            # the same offset-minus-line-start conversion; without it the range
            # extended one character past the term.
            char_e = end - source.rfind("\n", 0, end) - 1
            yield types.Range(
                start=types.Position(line=line, character=char_s),
                end=types.Position(line=line, character=char_e),
            )
def _update_diagnostics(document: TextDocument) -> None:
    """Recompute and store diagnostics for *document* in WORKSPACE_DIAGNOSTICS.

    Produces three kinds of diagnostics:
      * parse errors reported by sqlglot,
      * renderer warnings captured from the ``sqlmesh.core.renderer`` logger,
      * warnings for projection columns whose type cannot be inferred.
    """
    WORKSPACE_DIAGNOSTICS[document.uri] = (document.version, diagnostics := [])
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return
    default_dialect = context.default_dialect
    dialect = model.dialect if model and model.is_sql else default_dialect
    try:
        _ = parse(document.source, default_dialect=dialect)
    except ParseError as e:
        for error in e.errors:
            line = error["line"]
            comments_before_line = [
                _l for _l in document.lines[:line] if _l.strip().startswith(("/*", "--"))
            ]  # This is just a hack to adjust the line number, not a proper solution but it works
            line -= len(comments_before_line)
            diagnostics.append(
                types.Diagnostic(
                    message=e.args[0],
                    severity=types.DiagnosticSeverity.Error,
                    range=types.Range(
                        start=types.Position(line=line, character=error["col"]),
                        end=types.Position(line=line, character=error["col"]),
                    ),
                )
            )
    if model is not None and isinstance(model, SqlModel):
        # Temporarily attach a capturing handler to the renderer logger so any
        # warnings emitted while rendering surface as diagnostics.
        sqlmesh_renderer_logger = logging.getLogger("sqlmesh.core.renderer")
        buf = io.StringIO()
        interceptor = logging.StreamHandler(stream=buf)
        interceptor.setLevel(logging.WARNING)
        interceptor.setFormatter(logging.Formatter("%(message)s"))
        sqlmesh_renderer_logger.addHandler(interceptor)
        try:
            # BUG FIX: if rendering raised, the interceptor previously stayed
            # attached to the renderer logger forever; always detach it.
            _ = model._query_renderer.render(execution_time="now")
        finally:
            sqlmesh_renderer_logger.removeHandler(interceptor)
        warnings = buf.getvalue().strip()
        if warnings:
            diagnostics.append(
                types.Diagnostic(message=warnings, severity=types.DiagnosticSeverity.Warning, range=_top_of_file)
            )
        setattr(model, "_columns_to_types", None)  # clear cached columns to types
    if model and model.columns_to_types:
        for column, type_ in model.columns_to_types.items():
            if not type_is_known(type_):
                for range_ in _iter_match_ranges_in_projection(column, document.source):
                    diagnostics.append(
                        types.Diagnostic(
                            message=f"Unknown type for column: {column} - add a type hint to final projection",
                            severity=types.DiagnosticSeverity.Warning,
                            range=range_,
                        )
                    )
@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
    """Update diagnostics on document open and refresh context if necessary."""
    document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
    path = Path(document.path)
    known_paths = PATHS_TO_MODELS.keys()
    # If this is a model file the owning context has never seen, force a reload
    # so the new model is registered before diagnostics run.
    for context in CONTEXTS.values():
        if path.is_relative_to(context.path) and path.suffix in (".sql", ".py") and str(path) not in known_paths:
            ls.show_message(f"Refreshing context with new file: {path}", types.MessageType.Info)
            async with C_MUTEX[context.path]:
                await asyncio.to_thread(context.load)
                PATHS_TO_MODELS.update(
                    {str(model._path.resolve()): (context, weakref.proxy(model)) for model in context.models.values()}
                )
    _update_diagnostics(document)
    # Re-publish diagnostics for every tracked document, not just the opened one.
    for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
        ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
@server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
async def did_close(ls: LanguageServer, params: types.DidCloseTextDocumentParams):
    """Drop any stored diagnostics for a document once it is closed."""
    # pop() with a default is a no-op when the URI was never tracked.
    WORKSPACE_DIAGNOSTICS.pop(params.text_document.uri, None)
@server.feature(types.TEXT_DOCUMENT_DID_SAVE)
async def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams):
    """Update diagnostics on document save."""
    document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
    context, _ = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is not None:
        # Zero the recorded mtime so the loader treats the saved file as stale.
        # NOTE(review): relies on private loader internals — verify against the
        # installed sqlmesh version.
        context._loader._path_mtimes[Path(document.path)] = 0.0
        async with C_MUTEX[context.path]:
            await asyncio.to_thread(context.load)
        # Re-point this document at the freshly loaded model object.
        for model in context.models.values():
            if model._path == Path(document.path):
                PATHS_TO_MODELS[document.path] = (context, weakref.proxy(model))
                break
    _update_diagnostics(document)
    # Re-publish diagnostics for every tracked document.
    for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
        ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
async def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams):
    """Update diagnostics on document change."""
    document = await ensure_context_for_document(ls.workspace.get_text_document(params.text_document.uri))
    _update_diagnostics(document)
    # Re-publish diagnostics for every tracked document so stale entries refresh too.
    for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
        ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
@server.feature(types.WORKSPACE_DID_CHANGE_WATCHED_FILES)
async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWatchedFilesParams):
    """React to watcher events: refresh diagnostics, prune deletions, load creations."""
    refreshed: t.Dict[t.Any, bool] = {}
    for change in params.changes:
        document = await ensure_context_for_document(ls.workspace.get_text_document(change.uri))
        if change.type == types.FileChangeType.Changed:
            _update_diagnostics(document)
            continue
        path = Path(document.path)
        known_paths = PATHS_TO_MODELS.keys()
        if change.type == types.FileChangeType.Deleted and str(path) in known_paths:
            # A deleted model only needs evicting from the cache, not a reload.
            del PATHS_TO_MODELS[str(path)]
            continue
        if change.type != types.FileChangeType.Created:
            continue
        # A brand-new model file: force-reload each owning context exactly once.
        for context in CONTEXTS.values():
            is_new_model_file = (
                path.is_relative_to(context.path)
                and path.suffix in (".sql", ".py")
                and str(path) not in known_paths
            )
            if not is_new_model_file or refreshed.get(context.path, False):
                continue
            ls.show_message(f"Refreshing context with new file: {path}", types.MessageType.Info)
            async with C_MUTEX[context.path]:
                await asyncio.to_thread(context.load)
                PATHS_TO_MODELS.update(
                    {
                        str(model._path.resolve()): (context, weakref.proxy(model))
                        for model in context.models.values()
                    }
                )
            refreshed[context.path] = True
@server.feature(types.TEXT_DOCUMENT_HOVER)
async def hover(ls: LanguageServer, params: types.HoverParams):
    """Show the documented description of the column under the cursor."""
    pos = params.position
    document = await ensure_context_for_document(ls.workspace.get_text_document(params.text_document.uri))
    word = document.word_at_position(pos)
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return
    description = None
    found_upstream = False
    for dep in model.depends_on:
        # Best-effort lookup: a dependency that cannot be described is skipped.
        with suppress(Exception):
            description = column_description(context, dep, word)
        if description:
            found_upstream = True
            break
    if not found_upstream:
        # No upstream hit: fall back to this model's own column documentation.
        with suppress(Exception):
            description = column_description(context, model.name, word)
    if description is None:
        return
    return types.Hover(
        contents=types.MarkupContent(kind=types.MarkupKind.Markdown, value=description),
        range=types.Range(
            start=types.Position(line=pos.line, character=0),
            end=types.Position(line=pos.line + 1, character=0),
        ),
    )
@server.feature(types.TEXT_DOCUMENT_DEFINITION)
async def definition(ls: LanguageServer, params: types.DefinitionParams):
    """Jump to the upstream projection that defines the column under the cursor."""
    pos = params.position
    document = await ensure_context_for_document(ls.workspace.get_text_document(params.text_document.uri))
    term = document.word_at_position(pos)
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return
    # Map each upstream column name to the file that declares it; later
    # dependencies win on name collisions, matching dict-update semantics.
    column_to_file: t.Dict[str, Path] = {}
    for dep in model.depends_on:
        upstream = context.models.get(dep)
        if not upstream:
            continue
        for col in upstream.columns_to_types or {}:
            column_to_file[col] = Path(upstream._path)
    target = column_to_file.get(term)
    if target is None:
        return
    # Return the first matching location in the target file's projection, if any.
    for match_range in _iter_match_ranges_in_projection(term, target.read_text()):
        return types.Location(uri=Path(target).resolve().as_uri(), range=match_range)
if __name__ == "__main__":
    # Serve over stdio, the standard transport editors use for LSP servers.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    server.start_io()
@sungchun12
Copy link

rad

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment