Skip to content

Instantly share code, notes, and snippets.

@jacksmith15
Created August 10, 2020 11:07
Show Gist options
  • Save jacksmith15/bfb524ee0cc05913cd38a40e3cf6d28f to your computer and use it in GitHub Desktop.
faust-component-test.py
from random import randint
import faust
from simple_settings import settings
# The Faust application, configured from the ``FAUST`` mapping in simple_settings.
app = faust.App(**settings.FAUST)

# Fed by the ``source`` task with unkeyed integer values.
source_topic = app.topic("source", value_type=int)

# Fed by the ``square`` agent: key is the original integer, value is its square.
square_topic = app.topic("square", value_type=int, key_type=int)
@app.task
async def source():
    """Publish ten random integers to ``source_topic``.

    Registered with the app through ``@app.task``; sends values only (no key).
    """
    for _ in range(10):
        # ``random.randint`` requires both bounds: the original bare
        # ``randint()`` raised TypeError. The 0..100 range is an arbitrary choice.
        await source_topic.send(value=randint(0, 100))
@app.agent(source_topic)
async def square(messages: faust.StreamT[int]):
    """Consume integers from ``source_topic`` and publish their squares.

    Each result goes to ``square_topic`` keyed by the original integer, with
    the square as the value.
    """
    # ``StreamT`` was never imported (NameError when the annotation is
    # evaluated); reach it through the already-imported ``faust`` package.
    async for message in messages:
        await square_topic.send(key=message, value=message ** 2)
@app.agent(square_topic)
async def write(messages: faust.StreamT[int]):
    """Print every ``key -> value`` record received on ``square_topic``.

    ``square`` publishes the original integer as the key and its square as
    the value.
    """
    # Iterating a stream directly yields values only; ``items()`` yields the
    # ``(key, value)`` pairs this unpacking needs. ``squared`` also avoids
    # shadowing the module-level ``square`` agent.
    async for value, squared in messages.items():
        print(f"{value}^2 == {squared}")
from contextlib import contextmanager
from inspect import signature
from typing import Dict, Iterator, List, Optional, Sequence
from unittest.mock import patch

import pytest
from faust import App
from faust.channels import Channel
from faust.topics import Topic
class _MockTopic(Channel):
    """In-memory stand-in for a Faust ``Topic`` that records what is put on it.

    It is a plain ``Channel`` underneath. It keeps the original topic names in
    ``topics`` and appends every event passed to ``put`` to ``events``, so
    tests can iterate over (and count) what was published.
    """

    def __init__(self, app, topics: Optional[Sequence[str]] = None, **kwargs):
        """Create the mock.

        Args:
            app: The Faust app the channel belongs to.
            topics: Names of the topic(s) being mocked; stored for reference.
            **kwargs: Channel options; any not accepted by ``Channel.__init__``
                are silently dropped.
        """
        self.topics: List[str] = list(topics) if topics else []
        self.events: List = []
        # Filter against the initializer that ``super()`` actually calls.
        # Filtering by ``Topic``'s signature (as before) lets Topic-only
        # options through, and ``Channel.__init__`` rejects them with TypeError.
        params = signature(Channel).parameters
        kwargs = {key: value for key, value in kwargs.items() if key in params}
        super().__init__(app, **kwargs)

    async def put(self, value):
        """Record *value*, then hand it to the underlying channel's ``put``."""
        self.events.append(value)
        await super().put(value)

    def clear(self):
        """Forget all recorded events; the channel itself is not touched."""
        self.events = []

    def __iter__(self):
        """Iterate over recorded events in the order they were put."""
        return iter(self.events)

    def __len__(self):
        """Return the number of recorded events."""
        return len(self.events)
def mock_topic(app: App, topic: Topic):
    """Build a ``_MockTopic`` mirroring *topic*.

    The mock copies the topic names, event loop, schema and key/value types
    from *topic*.
    """
    mirrored = {
        "loop": topic.loop,
        "schema": topic.schema,
        "key_type": topic.key_type,
        "value_type": topic.value_type,
    }
    return _MockTopic(app, topics=topic.topics, **mirrored)
@contextmanager
def mock_agent_topics(app: App) -> Iterator[Dict[str, _MockTopic]]:
    """Temporarily point every Topic-backed agent of *app* at a ``_MockTopic``.

    Agents consuming the same topic name share one mock. Agents whose channel
    is not a ``Topic`` are left untouched. On exit, including when the
    with-body raises, every replaced channel is put back.

    Args:
        app: The Faust app whose registered agents should be re-wired.

    Yields:
        Mapping of ``Topic.get_topic_name()`` to the mock that replaced it.
    """
    replaced = []  # (agent, original Topic) pairs to restore on exit
    mocked: Dict[str, _MockTopic] = {}
    try:
        # Iterate over a snapshot of the agent registry.
        for agent in dict(app.agents).values():
            channel = agent.channel
            if not isinstance(channel, Topic):
                continue
            topic_name = channel.get_topic_name()
            if topic_name not in mocked:
                mocked[topic_name] = mock_topic(app, channel)
            replaced.append((agent, channel))
            agent.channel = mocked[topic_name]
        yield mocked
    finally:
        # Restore only the agents we changed. The previous version looked up
        # every agent and raised KeyError for non-Topic ones, and skipped
        # restoration entirely if the with-body raised.
        for agent, channel in replaced:
            agent.channel = channel
@pytest.fixture(scope="session", autouse=True)
def topics(event_loop):
    """Replace every agent's input topic with a ``_MockTopic`` for the session.

    Yields the ``{topic name: _MockTopic}`` mapping produced by
    ``mock_agent_topics``; the real topics are restored afterwards.
    """
    # ``dsx_app`` was never defined in this gist. The App under test is ``app``
    # in the ``app`` module (the module the tests import from and patch).
    # The alias avoids confusion with the ``app`` fixture below.
    from app import app as dsx_app

    # NOTE(review): stock pytest-asyncio's ``event_loop`` is function-scoped,
    # which pytest rejects for a session-scoped fixture. This presumably relies
    # on a session-scoped ``event_loop`` override elsewhere; confirm.
    with mock_agent_topics(dsx_app) as mock_topics:
        yield mock_topics
@pytest.fixture(scope="class", autouse=True)
def truncate_topics(topics):
    """Discard the events recorded on every mocked topic before the tests run."""
    mocked_topics = topics.values()
    for mocked in mocked_topics:
        mocked.clear()
@pytest.fixture(scope="session", autouse=True)
def app(event_loop, topics):
    """Prepare the Faust App under test for in-process use and yield it.

    Requests ``topics`` so the agent channels are already mocked when this
    runs.
    """
    # ``dsx_app`` was never defined in this gist; the App under test is ``app``
    # in the ``app`` module. The module import is unaffected by this fixture's name.
    from app import app as dsx_app

    dsx_app.finalize()
    dsx_app.conf.store = "memory://"
    # NOTE(review): presumably resumed so channels accept events without a
    # running worker; confirm.
    dsx_app.flow_control.resume()
    yield dsx_app
def _patched_topic(topics, topic_name, target):
    """Yield the mock for *topic_name* while it is patched over *target*."""
    mocked = topics[topic_name]
    with patch(target, new=mocked):
        yield mocked


@pytest.fixture(autouse=True)
def source_topic(topics):
    """Expose the mocked ``source`` topic as ``app.source_topic`` for one test."""
    yield from _patched_topic(topics, "source", "app.source_topic")


@pytest.fixture(autouse=True)
def square_topic(topics):
    """Expose the mocked ``square`` topic as ``app.square_topic`` for one test."""
    yield from _patched_topic(topics, "square", "app.square_topic")
import pytest

from app import source, square


# Coroutine test: needs the pytest-asyncio plugin, as the ``event_loop``
# fixture used by the conftest also implies.
@pytest.mark.asyncio
async def test_that_source_values_are_squared(source_topic, square_topic):
    """Every integer from ``source`` is republished keyed by itself with its square."""
    await source()
    source_values = [event.value for event in source_topic]
    # Guard against ``all()`` passing vacuously on an empty list.
    assert source_values, "source() published nothing"
    assert all(isinstance(value, int) for value in source_values)

    async with square.test_context() as agent:
        for event in source_topic:
            # ``put`` is a coroutine and must be awaited. Forward the payload,
            # not the Event wrapper, so ``square`` receives plain ints.
            await agent.put(event.value)

    square_keys = [event.key for event in square_topic]
    assert square_keys == source_values
    square_values = [event.value for event in square_topic]
    # ``**`` is exponentiation; the original ``^`` is bitwise XOR.
    assert square_values == [value ** 2 for value in source_values]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment