Created
January 16, 2020 15:33
-
-
Save saikocat/7ae4eefe6d686a9b1df24a4312e4ab26 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/harterrt/cookiecutter-python-etl/blob/master/README.md#benefits
# https://github.com/mozilla/python_mozetl/tree/master/tests
import pytest | |
from pyspark.sql import SparkSession | |
import json | |
@pytest.fixture(scope="session")
def spark():
    """Provide one local SparkSession shared by the whole test session.

    The session runs on the ``local`` master under the application name
    ``python_mozetl_test``. It is yielded to the tests and stopped once the
    fixture is torn down.
    """
    # No extra builder config is applied here; the original notes suggested
    # lowering parallelism to 2 and turning compression off as candidates.
    builder = SparkSession.builder.master("local").appName("python_mozetl_test")
    session = builder.getOrCreate()

    # The session time zone is left at Spark's default. To pin it to UTC+0:
    #     session.conf.set("spark.sql.session.timeZone", "UTC")
    yield session

    session.stop()
@pytest.fixture
def spark_context(spark):
    """Expose the SparkContext that backs the ``spark`` session fixture."""
    context = spark.sparkContext
    return context
@pytest.fixture(autouse=True)
def no_spark_stop(monkeypatch):
    """Replace ``SparkSession.stop`` with a no-op while a test runs.

    Applied automatically (``autouse``) so code under test that calls
    ``stop()`` cannot shut down the shared session.
    """

    def _ignore_stop(*_args, **_kwargs):
        # Accept any call signature and only report that the stop was skipped.
        print("Disabled spark.stop for testing")

    monkeypatch.setattr("pyspark.sql.SparkSession.stop", _ignore_stop)
@pytest.fixture
def df_equals(row_to_dict):
    """Return a predicate that compares two DataFrames ignoring row order.

    The returned ``func(this, that)`` collects both DataFrames, converts each
    row to a dict with ``row_to_dict``, sorts the dicts into a canonical order
    and returns ``True`` when the two sorted lists are equal.
    """

    def sort_key(record):
        # Lexicographic ordering of the JSON text gives a deterministic order.
        # ``default=str`` keeps values JSON cannot encode natively (dates,
        # timestamps, decimals, ...) from raising TypeError; it only affects
        # the sort key, never the values that are compared.
        return json.dumps(record, sort_keys=True, default=str)

    def to_comparable(df):
        return sorted(map(row_to_dict, df.collect()), key=sort_key)

    def func(this, that):
        return to_comparable(this) == to_comparable(that)

    return func
@pytest.fixture(scope="session")
def row_to_dict():
    """Provide a helper that turns a ``pyspark.Row`` into a plain dict.

    Dicts are easier to compare without regard to ordering. The helper
    passes ``recursive`` (default ``True``) straight through to
    ``Row.asDict``.
    """

    def as_plain_dict(row, recursive=True):
        converted = row.asDict(recursive=recursive)
        return converted

    return as_plain_dict
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment