Skip to content

Instantly share code, notes, and snippets.

@altescy
Created May 11, 2023 10:21
Show Gist options
  • Save altescy/915986a1e883e84108526ae4526e067d to your computer and use it in GitHub Desktop.
Save altescy/915986a1e883e84108526ae4526e067d to your computer and use it in GitHub Desktop.
import json
import logging
import time
import urllib.request
from contextlib import suppress
from http import HTTPStatus
from http.server import HTTPServer, SimpleHTTPRequestHandler
from multiprocessing import Process
from typing import Any, Callable, Type
from unittest.mock import patch
from urllib.error import HTTPError, URLError
from urllib.parse import ParseResult, parse_qs, urlparse
from slack_bolt.app.app import App, SlackAppDevelopmentServer
from slack_bolt.context.ack import Ack
from slack_bolt.context.say import Say
from slack_sdk import WebClient
def wait_for_server(url: str, timeout: float = 5.0) -> None:
    """Block until an HTTP server answers at ``url``.

    The URL is polled about every 0.1 seconds. Any HTTP response counts as
    "up", including error statuses (``HTTPError``), because receiving a status
    line already proves that something is listening.

    Args:
        url: Address to probe.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If no HTTP response arrived within ``timeout`` seconds.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while (remaining := deadline - time.monotonic()) > 0:
        try:
            # Bound every attempt so that a listener which accepts the
            # connection but never answers cannot stall past the deadline.
            # The ``with`` closes the response instead of leaking the socket.
            with urllib.request.urlopen(url, timeout=max(remaining, 0.1)):
                return
        except HTTPError as error:
            # An error status is still an answer from a running server.
            error.close()
            return
        except OSError:
            # URLError (e.g. connection refused) as well as resets/timeouts
            # raised while reading the response: not ready yet, retry shortly.
            time.sleep(0.1)
    raise TimeoutError(f"Server did not start in {timeout} seconds")
class MockHandler(SimpleHTTPRequestHandler):
protocol_version = "HTTP/1.1"
default_request_version = "HTTP/1.1"
logger = logging.getLogger(__name__)
received_requests: dict
def setup(self) -> None:
super().setup()
self.received_requests = {}
def is_valid_token(self) -> bool:
return "Authorization" in self.headers and str(self.headers["Authorization"]).startswith("Bearer xoxb-")
def is_valid_user_token(self) -> bool:
return "Authorization" in self.headers and str(self.headers["Authorization"]).startswith("Bearer xoxp-")
def set_common_headers(self) -> None:
self.send_header("content-type", "application/json;charset=utf-8")
self.send_header("connection", "close")
self.end_headers()
invalid_auth = {
"ok": False,
"error": "invalid_auth",
}
oauth_v2_access_response = """
{
"ok": true,
"access_token": "xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy",
"token_type": "bot",
"scope": "chat:write,commands",
"bot_user_id": "U0KRQLJ9H",
"app_id": "A0KRD7HC3",
"team": {
"name": "Slack Softball Team",
"id": "T9TK3CUKW"
},
"enterprise": {
"name": "slack-sports",
"id": "E12345678"
},
"authed_user": {
"id": "U1234",
"scope": "chat:write",
"access_token": "xoxp-1234",
"token_type": "user"
}
}
"""
oauth_v2_access_bot_refresh_response = """
{
"ok": true,
"app_id": "A0KRD7HC3",
"access_token": "xoxb-valid-refreshed",
"expires_in": 43200,
"refresh_token": "xoxe-1-valid-bot-refreshed",
"token_type": "bot",
"scope": "chat:write,commands",
"bot_user_id": "U0KRQLJ9H",
"team": {
"name": "Slack Softball Team",
"id": "T9TK3CUKW"
},
"enterprise": {
"name": "slack-sports",
"id": "E12345678"
}
}
"""
oauth_v2_access_user_refresh_response = """
{
"ok": true,
"app_id": "A0KRD7HC3",
"access_token": "xoxp-valid-refreshed",
"expires_in": 43200,
"refresh_token": "xoxe-1-valid-user-refreshed",
"token_type": "user",
"scope": "search:read",
"team": {
"name": "Slack Softball Team",
"id": "T9TK3CUKW"
},
"enterprise": {
"name": "slack-sports",
"id": "E12345678"
}
}
"""
bot_auth_test_response = """
{
"ok": true,
"url": "https://subarachnoid.slack.com/",
"team": "Subarachnoid Workspace",
"user": "bot",
"team_id": "T0G9PQBBK",
"user_id": "W23456789",
"bot_id": "BZYBOTHED"
}
"""
user_auth_test_response = """
{
"ok": true,
"url": "https://subarachnoid.slack.com/",
"team": "Subarachnoid Workspace",
"user": "some-user",
"team_id": "T0G9PQBBK",
"user_id": "W99999"
}
"""
def _handle(self) -> None:
parsed_path: ParseResult = urlparse(self.path)
path = parsed_path.path
self.received_requests[path] = self.received_requests.get(path, 0) + 1
try:
if path == "/webhook":
self.send_response(200)
self.set_common_headers()
self.wfile.write("OK".encode("utf-8"))
return
if path == "/received_requests.json":
self.send_response(200)
self.set_common_headers()
self.wfile.write(json.dumps(self.received_requests).encode("utf-8"))
return
body: dict = {"ok": True}
if path == "/oauth.v2.access":
if self.headers.get("authorization") is not None:
request_body = self._parse_request_body(
parsed_path=parsed_path,
content_len=int(self.headers.get("Content-Length") or 0),
)
self.logger.info(f"request body: {request_body}")
if request_body.get("grant_type") == "refresh_token":
refresh_token = request_body.get("refresh_token")
if refresh_token is not None:
if "bot-valid" in refresh_token:
self.send_response(200)
self.set_common_headers()
self.wfile.write(self.oauth_v2_access_bot_refresh_response.encode("utf-8"))
return
if "user-valid" in refresh_token:
self.send_response(200)
self.set_common_headers()
self.wfile.write(self.oauth_v2_access_user_refresh_response.encode("utf-8"))
return
elif request_body.get("code") is not None:
self.send_response(200)
self.set_common_headers()
self.wfile.write(self.oauth_v2_access_response.encode("utf-8"))
return
if self.is_valid_user_token():
if path == "/auth.test":
self.send_response(200)
self.set_common_headers()
self.wfile.write(self.user_auth_test_response.encode("utf-8"))
return
if self.is_valid_token():
if path == "/auth.test":
self.send_response(200)
self.set_common_headers()
self.wfile.write(self.bot_auth_test_response.encode("utf-8"))
return
request_body = self._parse_request_body(
parsed_path=parsed_path,
content_len=int(self.headers.get("Content-Length") or 0),
)
self.logger.info(f"request: {path} {request_body}")
header = self.headers["authorization"]
pattern = str(header).split("xoxb-", 1)[1]
if pattern.isnumeric():
self.send_response(int(pattern))
self.set_common_headers()
self.wfile.write("""{"ok":false}""".encode("utf-8"))
return
else:
body = self.invalid_auth
self.send_response(HTTPStatus.OK)
self.set_common_headers()
self.wfile.write(json.dumps(body).encode("utf-8"))
except Exception as e:
self.logger.error(str(e), exc_info=True)
raise
def do_GET(self) -> None:
self._handle()
def do_POST(self) -> None:
self._handle()
def _parse_request_body(self, parsed_path: ParseResult, content_len: int) -> dict:
post_body = self.rfile.read(content_len)
request_body: dict = {}
if post_body:
with suppress(UnicodeDecodeError):
post_body_decoded = post_body.decode("utf-8")
if post_body_decoded.startswith("{"):
request_body = json.loads(post_body)
else:
request_body = {k: v[0] for k, v in parse_qs(post_body_decoded).items()}
else:
if parsed_path and parsed_path.query:
request_body = {k: v[0] for k, v in parse_qs(parsed_path.query).items()}
return request_body
class MockSignatureVerifier:
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
def is_valid_request(
self,
body: str | bytes,
headers: dict[str, str],
) -> bool:
return True
def is_valid(
self,
body: str | bytes,
timestamp: str,
signature: str,
) -> bool:
return True
class SlackAppBuilder:
    """Callable factory producing the demo Bolt app with signature checks disabled."""

    def __init__(self, signing_secret: str, client: WebClient) -> None:
        """Remember the signing secret and the Web API client handed to ``App``."""
        self.signing_secret = signing_secret
        self.client = client

    def __call__(self) -> App:
        """Create the app and register the ``/hello`` slash command.

        The app is constructed while ``SignatureVerifier`` is replaced by
        ``MockSignatureVerifier``.
        """
        verifier_target = "slack_bolt.middleware.request_verification.request_verification.SignatureVerifier"
        with patch(verifier_target, MockSignatureVerifier):
            app = App(signing_secret=self.signing_secret, client=self.client)

        def hello(ack: Ack, say: Say, body: dict) -> None:
            # Print the payload, acknowledge the command, then reply.
            print("/hello command body:", body)
            ack()
            say("Hello!")

        app.command("/hello")(hello)
        return app
class ServerProcess(Process):
    """Child process that serves HTTP with the given request handler class.

    Attributes:
        START_TIMEOUT: Seconds ``start`` waits for the server to answer.
    """

    START_TIMEOUT = 5
    # Bind addresses meaning "all interfaces"; they are not connect targets.
    _WILDCARD_HOSTS = ("", "0.0.0.0", "::")

    def __init__(
        self,
        handler: Type[SimpleHTTPRequestHandler] = SimpleHTTPRequestHandler,
        host: str = "",
        port: int = 8000,
    ):
        """Store the server configuration; nothing is started yet.

        Args:
            handler: Request handler class passed to ``HTTPServer``.
            host: Address to bind; the empty string binds all interfaces.
            port: TCP port to bind.
        """
        super().__init__()
        self.handler = handler
        self.host = host
        self.port = port
        # Assigned inside run(), i.e. in the child process only.
        self.server: HTTPServer | None = None

    def _probe_url(self) -> str:
        """Return the URL used to check, from the parent, that the server is up."""
        # A wildcard bind address such as "" would produce "http://:8000",
        # which is not a portable address to connect to; use loopback instead.
        host = "127.0.0.1" if self.host in self._WILDCARD_HOSTS else self.host
        return f"http://{host}:{self.port}"

    def start(self) -> None:
        """Start the child process and wait until the server answers.

        Raises:
            TimeoutError: If the server did not answer within ``START_TIMEOUT``
                seconds; the child is terminated before the error propagates.
        """
        super().start()
        try:
            wait_for_server(self._probe_url(), timeout=self.START_TIMEOUT)
        except BaseException:
            # Do not leave an orphaned child behind when start-up fails.
            self.terminate()
            self.join(timeout=self.START_TIMEOUT)
            raise

    def run(self) -> None:
        """Serve requests until the process is stopped (executed in the child)."""
        # The context manager closes the listening socket if the loop exits.
        with HTTPServer((self.host, self.port), self.handler) as server:
            self.server = server
            server.serve_forever()

    def terminate(self) -> None:
        """Stop the server, if this object holds one, and terminate the process.

        ``self.server`` is only set by ``run``; in the process that called
        ``start`` it stays ``None``, so there only the terminate signal is sent.
        """
        if self.server:
            self.server.shutdown()
            self.server.server_close()
        super().terminate()
class SlackAppProcess(Process):
    """Child process running a Bolt app on ``SlackAppDevelopmentServer``.

    Attributes:
        START_TIMEOUT: Seconds ``start`` waits for the server to answer.
    """

    START_TIMEOUT = 5

    def __init__(
        self,
        builder: Callable[[], App],
        port: int = 3001,
        path: str = "/slack/events",
    ) -> None:
        """Store the configuration; the app is built later, inside ``run``.

        Args:
            builder: Zero-argument callable returning the app to serve.
            port: TCP port the development server listens on.
            path: URL path on which the server receives Slack requests.
        """
        super().__init__()
        self.builder = builder
        self.port = port
        self.path = path

    def start(self) -> None:
        """Start the child process and wait until the server answers on ``port``.

        Raises:
            TimeoutError: If the server did not answer within ``START_TIMEOUT``
                seconds; the child is terminated before the error propagates.
        """
        super().start()
        try:
            wait_for_server(f"http://localhost:{self.port}", timeout=self.START_TIMEOUT)
        except BaseException:
            # Do not leave an orphaned child behind when start-up fails.
            self.terminate()
            self.join(timeout=self.START_TIMEOUT)
            raise

    def run(self) -> None:
        """Build the app and serve it (executed in the child process)."""
        # Use the configured port/path; hard-coded values here would make the
        # constructor arguments ineffective and break the readiness probe in
        # start() for any non-default port.
        server = SlackAppDevelopmentServer(port=self.port, path=self.path, app=self.builder())
        server.start()
class ProcessContext:
    """Context manager that starts a process on entry and stops it on exit.

    Attributes:
        JOIN_TIMEOUT: Seconds to wait for the process to exit after
            ``terminate`` before it is killed.
    """

    JOIN_TIMEOUT = 5.0

    def __init__(self, process: Process) -> None:
        """Wrap ``process``; it is not started until the context is entered."""
        self.process = process

    def __enter__(self) -> Process:
        """Start the process and return it, so ``with ... as proc`` is usable."""
        self.process.start()
        return self.process

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Terminate the process and wait for it to exit.

        Joining reaps the child and makes sure resources it held (such as a
        listening port) are released before the ``with`` block is left.
        Exceptions raised inside the block are not suppressed.
        """
        self.process.terminate()
        self.process.join(timeout=self.JOIN_TIMEOUT)
        if self.process.is_alive():
            # The process ignored the terminate request; force it down.
            self.process.kill()
            self.process.join()
def main() -> None:
    """Run the end-to-end demo against local mock servers.

    Starts the mock Web API server on port 8000 and the Bolt app on port
    3001, then sends the app a ``url_verification`` request (checking the
    echoed challenge) followed by a ``/hello`` slash command.

    Raises:
        AssertionError: If the challenge is not echoed back as expected.
    """
    token = "xoxb-valid"
    signing_secret = "secret"
    events_url = "http://localhost:3001/slack/events"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }

    with ProcessContext(ServerProcess(MockHandler)):
        client = WebClient(token=token, base_url="http://localhost:8000")
        builder = SlackAppBuilder(signing_secret=signing_secret, client=client)
        with ProcessContext(SlackAppProcess(builder)):
            request = urllib.request.Request(
                events_url,
                data=json.dumps({"type": "url_verification", "challenge": "challenge"}).encode("utf-8"),
                headers=headers,
            )
            # ``with`` closes the HTTP response instead of leaking the socket.
            with urllib.request.urlopen(request) as response:
                response_body = json.loads(response.read().decode("utf-8"))
            # Raised explicitly so the check also runs under ``python -O``.
            if response_body != {"challenge": "challenge"}:
                raise AssertionError(f"unexpected url_verification response: {response_body!r}")

            request = urllib.request.Request(
                events_url,
                data=json.dumps(
                    {
                        "command": "/hello",
                        "text": "hoge",
                        "channel_id": "C111",
                    }
                ).encode("utf-8"),
                headers=headers,
            )
            # Only the side effects of the command matter; drain and close.
            with urllib.request.urlopen(request) as response:
                response.read()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment