Created
May 11, 2023 10:21
-
-
Save altescy/915986a1e883e84108526ae4526e067d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import logging | |
import time | |
import urllib.request | |
from contextlib import suppress | |
from http import HTTPStatus | |
from http.server import HTTPServer, SimpleHTTPRequestHandler | |
from multiprocessing import Process | |
from typing import Any, Callable, Type | |
from unittest.mock import patch | |
from urllib.error import HTTPError, URLError | |
from urllib.parse import ParseResult, parse_qs, urlparse | |
from slack_bolt.app.app import App, SlackAppDevelopmentServer | |
from slack_bolt.context.ack import Ack | |
from slack_bolt.context.say import Say | |
from slack_sdk import WebClient | |
def wait_for_server(url: str, timeout: float = 5.0, interval: float = 0.1) -> None:
    """Block until an HTTP server answers at ``url``.

    Any HTTP response counts as "up", including error statuses such as 404,
    because only reachability matters. Connection failures are retried every
    ``interval`` seconds until ``timeout`` seconds have elapsed.

    Args:
        url: URL to probe.
        timeout: Total number of seconds to keep trying.
        interval: Seconds to sleep between failed attempts.

    Raises:
        TimeoutError: if no HTTP response arrived within ``timeout`` seconds.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while (remaining := deadline - time.monotonic()) > 0:
        try:
            # Bound each attempt so a server that accepts the connection but
            # never answers cannot block past the deadline; the with-block
            # closes the response and releases its socket.
            with urllib.request.urlopen(url, timeout=remaining):
                return
        except HTTPError as error:
            # The server answered, just not with a 2xx status: it is up.
            error.close()
            return
        except OSError:
            # URLError (e.g. connection refused), socket timeouts and errors
            # raised while reading the response (e.g. RemoteDisconnected) are
            # all OSError subclasses: the server is not ready yet.
            pass
        time.sleep(interval)
    raise TimeoutError(f"Server did not start in {timeout} seconds")
class MockHandler(SimpleHTTPRequestHandler): | |
protocol_version = "HTTP/1.1" | |
default_request_version = "HTTP/1.1" | |
logger = logging.getLogger(__name__) | |
received_requests: dict | |
def setup(self) -> None: | |
super().setup() | |
self.received_requests = {} | |
def is_valid_token(self) -> bool: | |
return "Authorization" in self.headers and str(self.headers["Authorization"]).startswith("Bearer xoxb-") | |
def is_valid_user_token(self) -> bool: | |
return "Authorization" in self.headers and str(self.headers["Authorization"]).startswith("Bearer xoxp-") | |
def set_common_headers(self) -> None: | |
self.send_header("content-type", "application/json;charset=utf-8") | |
self.send_header("connection", "close") | |
self.end_headers() | |
invalid_auth = { | |
"ok": False, | |
"error": "invalid_auth", | |
} | |
oauth_v2_access_response = """ | |
{ | |
"ok": true, | |
"access_token": "xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy", | |
"token_type": "bot", | |
"scope": "chat:write,commands", | |
"bot_user_id": "U0KRQLJ9H", | |
"app_id": "A0KRD7HC3", | |
"team": { | |
"name": "Slack Softball Team", | |
"id": "T9TK3CUKW" | |
}, | |
"enterprise": { | |
"name": "slack-sports", | |
"id": "E12345678" | |
}, | |
"authed_user": { | |
"id": "U1234", | |
"scope": "chat:write", | |
"access_token": "xoxp-1234", | |
"token_type": "user" | |
} | |
} | |
""" | |
oauth_v2_access_bot_refresh_response = """ | |
{ | |
"ok": true, | |
"app_id": "A0KRD7HC3", | |
"access_token": "xoxb-valid-refreshed", | |
"expires_in": 43200, | |
"refresh_token": "xoxe-1-valid-bot-refreshed", | |
"token_type": "bot", | |
"scope": "chat:write,commands", | |
"bot_user_id": "U0KRQLJ9H", | |
"team": { | |
"name": "Slack Softball Team", | |
"id": "T9TK3CUKW" | |
}, | |
"enterprise": { | |
"name": "slack-sports", | |
"id": "E12345678" | |
} | |
} | |
""" | |
oauth_v2_access_user_refresh_response = """ | |
{ | |
"ok": true, | |
"app_id": "A0KRD7HC3", | |
"access_token": "xoxp-valid-refreshed", | |
"expires_in": 43200, | |
"refresh_token": "xoxe-1-valid-user-refreshed", | |
"token_type": "user", | |
"scope": "search:read", | |
"team": { | |
"name": "Slack Softball Team", | |
"id": "T9TK3CUKW" | |
}, | |
"enterprise": { | |
"name": "slack-sports", | |
"id": "E12345678" | |
} | |
} | |
""" | |
bot_auth_test_response = """ | |
{ | |
"ok": true, | |
"url": "https://subarachnoid.slack.com/", | |
"team": "Subarachnoid Workspace", | |
"user": "bot", | |
"team_id": "T0G9PQBBK", | |
"user_id": "W23456789", | |
"bot_id": "BZYBOTHED" | |
} | |
""" | |
user_auth_test_response = """ | |
{ | |
"ok": true, | |
"url": "https://subarachnoid.slack.com/", | |
"team": "Subarachnoid Workspace", | |
"user": "some-user", | |
"team_id": "T0G9PQBBK", | |
"user_id": "W99999" | |
} | |
""" | |
def _handle(self) -> None: | |
parsed_path: ParseResult = urlparse(self.path) | |
path = parsed_path.path | |
self.received_requests[path] = self.received_requests.get(path, 0) + 1 | |
try: | |
if path == "/webhook": | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write("OK".encode("utf-8")) | |
return | |
if path == "/received_requests.json": | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(json.dumps(self.received_requests).encode("utf-8")) | |
return | |
body: dict = {"ok": True} | |
if path == "/oauth.v2.access": | |
if self.headers.get("authorization") is not None: | |
request_body = self._parse_request_body( | |
parsed_path=parsed_path, | |
content_len=int(self.headers.get("Content-Length") or 0), | |
) | |
self.logger.info(f"request body: {request_body}") | |
if request_body.get("grant_type") == "refresh_token": | |
refresh_token = request_body.get("refresh_token") | |
if refresh_token is not None: | |
if "bot-valid" in refresh_token: | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(self.oauth_v2_access_bot_refresh_response.encode("utf-8")) | |
return | |
if "user-valid" in refresh_token: | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(self.oauth_v2_access_user_refresh_response.encode("utf-8")) | |
return | |
elif request_body.get("code") is not None: | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(self.oauth_v2_access_response.encode("utf-8")) | |
return | |
if self.is_valid_user_token(): | |
if path == "/auth.test": | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(self.user_auth_test_response.encode("utf-8")) | |
return | |
if self.is_valid_token(): | |
if path == "/auth.test": | |
self.send_response(200) | |
self.set_common_headers() | |
self.wfile.write(self.bot_auth_test_response.encode("utf-8")) | |
return | |
request_body = self._parse_request_body( | |
parsed_path=parsed_path, | |
content_len=int(self.headers.get("Content-Length") or 0), | |
) | |
self.logger.info(f"request: {path} {request_body}") | |
header = self.headers["authorization"] | |
pattern = str(header).split("xoxb-", 1)[1] | |
if pattern.isnumeric(): | |
self.send_response(int(pattern)) | |
self.set_common_headers() | |
self.wfile.write("""{"ok":false}""".encode("utf-8")) | |
return | |
else: | |
body = self.invalid_auth | |
self.send_response(HTTPStatus.OK) | |
self.set_common_headers() | |
self.wfile.write(json.dumps(body).encode("utf-8")) | |
except Exception as e: | |
self.logger.error(str(e), exc_info=True) | |
raise | |
def do_GET(self) -> None: | |
self._handle() | |
def do_POST(self) -> None: | |
self._handle() | |
def _parse_request_body(self, parsed_path: ParseResult, content_len: int) -> dict: | |
post_body = self.rfile.read(content_len) | |
request_body: dict = {} | |
if post_body: | |
with suppress(UnicodeDecodeError): | |
post_body_decoded = post_body.decode("utf-8") | |
if post_body_decoded.startswith("{"): | |
request_body = json.loads(post_body) | |
else: | |
request_body = {k: v[0] for k, v in parse_qs(post_body_decoded).items()} | |
else: | |
if parsed_path and parsed_path.query: | |
request_body = {k: v[0] for k, v in parse_qs(parsed_path.query).items()} | |
return request_body | |
class MockSignatureVerifier: | |
def __init__(self, *args: Any, **kwargs: Any) -> None: | |
pass | |
def is_valid_request( | |
self, | |
body: str | bytes, | |
headers: dict[str, str], | |
) -> bool: | |
return True | |
def is_valid( | |
self, | |
body: str | bytes, | |
timestamp: str, | |
signature: str, | |
) -> bool: | |
return True | |
class SlackAppBuilder:
    """Zero-argument factory for the Bolt ``App`` under test.

    Each call builds a fresh ``App`` with the stored signing secret and Web API
    client, with ``MockSignatureVerifier`` patched in for request verification,
    and registers the ``/hello`` slash-command listener.
    """

    def __init__(
        self,
        signing_secret: str,
        client: WebClient,
    ) -> None:
        self.signing_secret = signing_secret
        self.client = client

    def __call__(self) -> App:
        verifier_target = "slack_bolt.middleware.request_verification.request_verification.SignatureVerifier"
        # NOTE(review): the patch only covers App construction; this assumes
        # Bolt instantiates its request verifier there — confirm when upgrading
        # slack_bolt.
        with patch(verifier_target, MockSignatureVerifier):
            app = App(signing_secret=self.signing_secret, client=self.client)
        self._register_listeners(app)
        return app

    @staticmethod
    def _register_listeners(app: App) -> None:
        """Attach the slash-command listeners exercised by the smoke test."""

        # Bolt injects ack/say/body by parameter name, so these names matter.
        @app.command("/hello")
        def hello(ack: Ack, say: Say, body: dict) -> None:
            print("/hello command body:", body)
            ack()
            say("Hello!")
class ServerProcess(Process):
    """Serve ``HTTPServer((host, port), handler)`` from a child process.

    ``start()`` blocks until the server answers HTTP requests and raises
    ``TimeoutError`` if it does not within ``START_TIMEOUT`` seconds. Stop it
    with the inherited ``terminate()``: the server object only exists in the
    child process, so there is nothing to shut down from the parent.
    """

    START_TIMEOUT = 5

    def __init__(
        self,
        handler: Type[SimpleHTTPRequestHandler] = SimpleHTTPRequestHandler,
        host: str = "",
        port: int = 8000,
    ):
        super().__init__()
        self.handler = handler
        self.host = host
        self.port = port
        # Assigned in run(), i.e. only inside the child process; the parent's
        # copy of this object always sees None.
        self.server: HTTPServer | None = None

    def start(self) -> None:
        super().start()
        # host "" means "bind every interface", which is not a meaningful URL
        # host ("http://:8000"); probe the IPv4 loopback instead, since
        # HTTPServer listens on IPv4.
        probe_host = self.host or "127.0.0.1"
        wait_for_server(f"http://{probe_host}:{self.port}", timeout=self.START_TIMEOUT)

    def run(self) -> None:
        # Runs in the child process. The with-block closes the listening socket
        # if serve_forever() exits with an exception.
        with HTTPServer((self.host, self.port), self.handler) as server:
            self.server = server
            server.serve_forever()
class SlackAppProcess(Process):
    """Serve the Bolt app produced by ``builder`` from a child process.

    The app is served by ``SlackAppDevelopmentServer`` on ``port`` at ``path``.
    ``start()`` blocks until the server answers HTTP requests and raises
    ``TimeoutError`` if it does not within ``START_TIMEOUT`` seconds.
    """

    START_TIMEOUT = 5

    def __init__(
        self,
        builder: Callable[[], App],
        port: int = 3001,
        path: str = "/slack/events",
    ) -> None:
        super().__init__()
        # Called in run(), so the App is constructed inside the child process.
        self.builder = builder
        self.port = port
        self.path = path

    def start(self) -> None:
        super().start()
        wait_for_server(f"http://localhost:{self.port}", timeout=self.START_TIMEOUT)

    def run(self) -> None:
        # Runs in the child process. Serve on the configured port/path: start()
        # probes self.port, so serving anywhere else makes it time out.
        server = SlackAppDevelopmentServer(port=self.port, path=self.path, app=self.builder())
        server.start()
class ProcessContext:
    """Context manager that runs ``process`` for the duration of a ``with`` block.

    ``__enter__`` starts the process and returns it. ``__exit__`` terminates it
    and waits for it to exit, killing it if it ignores termination, so no
    orphaned or zombie child is left behind.
    """

    # Seconds to wait for the child to exit after terminate() before kill().
    STOP_TIMEOUT = 5.0

    def __init__(self, process: Process) -> None:
        self.process = process

    def __enter__(self) -> Process:
        try:
            self.process.start()
        except BaseException:
            # __exit__ does not run when __enter__ raises. A start() override can
            # fail *after* the child was launched (e.g. a readiness wait timing
            # out); stop it here, otherwise the non-daemon child keeps running
            # and multiprocessing waits on it when the interpreter exits.
            if self.process.pid is not None:
                self._stop()
            raise
        return self.process

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self._stop()

    def _stop(self) -> None:
        """Terminate the child, reap it, and kill it if it is still alive after the timeout."""
        self.process.terminate()
        self.process.join(self.STOP_TIMEOUT)
        if self.process.is_alive():
            self.process.kill()
            self.process.join()
def main() -> None:
    """Smoke-test the Bolt app end to end against the mock Slack Web API.

    Starts the mock Web API (``MockHandler`` on port 8000) and the Bolt app
    (port 3001), checks that the ``url_verification`` challenge is echoed back,
    then sends a ``/hello`` slash command and checks it is acknowledged.
    """
    token = "xoxb-valid"
    signing_secret = "secret"
    events_url = "http://localhost:3001/slack/events"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    with ProcessContext(ServerProcess(MockHandler)):
        # Route the app's Web API calls to the mock server.
        client = WebClient(token=token, base_url="http://localhost:8000")
        builder = SlackAppBuilder(signing_secret=signing_secret, client=client)
        with ProcessContext(SlackAppProcess(builder)):
            request = urllib.request.Request(
                events_url,
                data=json.dumps({"type": "url_verification", "challenge": "challenge"}).encode("utf-8"),
                headers=headers,
            )
            # with-blocks close each response (and its socket) once read.
            with urllib.request.urlopen(request) as response:
                response_body = json.loads(response.read().decode("utf-8"))
            assert response_body == {"challenge": "challenge"}

            request = urllib.request.Request(
                events_url,
                data=json.dumps(
                    {
                        "command": "/hello",
                        "text": "hoge",
                        "channel_id": "C111",
                    }
                ).encode("utf-8"),
                headers=headers,
            )
            # urlopen raises HTTPError for non-2xx; also pin the exact ack status.
            with urllib.request.urlopen(request) as response:
                assert response.status == HTTPStatus.OK
if __name__ == "__main__":
    # Run the end-to-end smoke test only when executed as a script, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment