Demo Codes for Apache Beam cross-language JDBCIO
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Demo Codes for Apache Beam cross-language JDBCIO
from base64 import b64encode
import datetime
from decimal import Decimal
import logging
import os
import typing
import uuid
import apache_beam as beam
from apache_beam import coders
from apache_beam.io.gcp import bigquery_tools
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.io.iobase import Read
from apache_beam.io.jdbc import ReadFromJdbc, WriteToJdbc
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import Pipeline
from apache_beam.testing.synthetic_pipeline import SyntheticSource
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import MillisInstant
from apache_beam.utils.timestamp import Timestamp
_LOGGER = logging.getLogger(__name__)
JdbcWriteTestRow = typing.NamedTuple(
    "JdbcWriteTestRow",
    [
        ("f_id", int),
        ("f_real", float),
        ("f_fixedchar", str),
        ("f_varchar", str),
        ("f_bin", bytes),
        ("f_timestamp", Timestamp),
        ("f_decimal", Decimal),
    ],
)
def row_to_dict(row):
    return row._asdict()


def row_to_dict_patch_byte(row):
    as_dict = row._asdict()
    # For FILE_LOADS, incoming bytes need to be base64 encoded:
    as_dict['f_bin'] = b64encode(as_dict['f_bin'])
    # BigQuery expects a timezone-aware datetime for TIMESTAMP fields:
    as_dict['f_timestamp'] = as_dict['f_timestamp'].to_utc_datetime().replace(
        tzinfo=datetime.timezone.utc)
    return as_dict
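
# BigQuery destination schema used by the BigQuery sink pipelines below; the field
# names and types mirror the Postgres columns created in setup() (INTEGER, FLOAT,
# CHAR/VARCHAR, bytea, TIMESTAMP and DECIMAL map to INTEGER, FLOAT64, STRING, BYTES,
# TIMESTAMP and NUMERIC respectively).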
SCHEMA = {
    'fields': [{
        'name': 'f_id', 'type': 'INTEGER', 'mode': 'REQUIRED'
    }, {
        'name': 'f_real', 'type': 'FLOAT64', 'mode': 'NULLABLE'
    }, {
        'name': 'f_fixedchar', 'type': 'STRING', 'mode': 'NULLABLE'
    }, {
        'name': 'f_varchar', 'type': 'STRING', 'mode': 'NULLABLE'
    }, {
        'name': 'f_bin', 'type': 'BYTES', 'mode': 'NULLABLE'
    }, {
        'name': 'f_timestamp', 'type': 'TIMESTAMP', 'mode': 'NULLABLE'
    }, {
        'name': 'f_decimal', 'type': 'NUMERIC', 'mode': 'NULLABLE'
    }]
}
class SequenceSyntheticSource(SyntheticSource):
    # Override key generation so the key is the sequential record index, which
    # becomes the unique integer f_id column.
    def _gen_kv_pair(self, generator, index):
        return int(index), os.urandom(self.element_size)


def print_row(row):
    _LOGGER.info(f"row: {row}")
    return row
class JdbcXlang:
    ROW_COUNT = 100_000_000

    def __init__(self):
        self.table_name = 'jdbcxlang' + str(uuid.uuid4())[:8]
        self.dataset = '<gcp_project>:<dataset>'
        self.username = 'postgres'
        self.password = '***'
        self.driver = 'org.postgresql.Driver'
        self.endpoint = 'localhost:5432'
        # 'localhost' is swapped for 'host.docker.internal' in the pipelines below so
        # that the Dockerized Java SDK harness running the cross-language transform can
        # reach the database on the host. Append ?stringtype=unspecified if needed.
        self.jdbc_url = 'jdbc:postgresql://' + self.endpoint + '/postgres'
    def setup(self):
        """Setup steps before running the test, such as creating the table."""
        # A separate module is used to avoid pulling an unwanted sqlalchemy dependency
        # into the main session. Code available at
        # https://gist.github.com/Abacn/b3f234f0eb8515f000cea3464f874e4c
        import jdbc_test_tool
        jdbc_schema = ('(f_id INTEGER, f_real FLOAT, f_fixedchar CHAR(12), f_varchar VARCHAR(12), '
                       'f_bin bytea, f_timestamp TIMESTAMP, f_decimal DECIMAL(10,2))')
        jdbc_test_tool.create_table(
            table_name=self.table_name, schema=jdbc_schema,
            username=self.username, password=self.password, endpoint=self.endpoint)
        print('Created ' + self.table_name)
    def teardown(self):
        import jdbc_test_tool
        jdbc_test_tool.teardown_table(
            table_name=self.table_name, username=self.username,
            password=self.password, endpoint=self.endpoint)
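
    # For reference, a minimal sketch of the helper interface assumed above; the real
    # jdbc_test_tool module lives in the linked gist, and the sqlalchemy-based bodies
    # shown here are an illustration, not the authoritative implementation:
    #
    #   import sqlalchemy
    #
    #   def create_table(table_name, schema, username, password, endpoint):
    #       engine = sqlalchemy.create_engine(
    #           f'postgresql://{username}:{password}@{endpoint}/postgres')
    #       with engine.begin() as conn:
    #           conn.execute(sqlalchemy.text(f'CREATE TABLE {table_name} {schema}'))
    #
    #   def teardown_table(table_name, username, password, endpoint):
    #       engine = sqlalchemy.create_engine(
    #           f'postgresql://{username}:{password}@{endpoint}/postgres')
    #       with engine.begin() as conn:
    #           conn.execute(sqlalchemy.text(f'DROP TABLE {table_name}'))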
    def run_write(self):
        """Pipeline that writes data to JDBC."""
        options = PipelineOptions()
        input_spec = {"numRecords": self.ROW_COUNT, "keySizeBytes": 1, "valueSizeBytes": 10}
        coders.registry.register_coder(JdbcWriteTestRow, coders.RowCoder)
        with Pipeline(options=options) as p:
            _ = (
                p
                | Read(SequenceSyntheticSource(input_spec))
                | beam.Map(lambda i: JdbcWriteTestRow(
                    i[0],
                    i[0] + 0.1,
                    str(i[0] - 1) + ".23",
                    str(i[0] - 1) + ".25",
                    i[1],
                    Timestamp.now(),
                    Decimal(str(i[0] - 1) + ".23"))).with_output_types(JdbcWriteTestRow)
                | 'Write to jdbc' >> WriteToJdbc(
                    table_name=self.table_name,
                    driver_class_name=self.driver,
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
    def run_read(self):
        """Pipeline that reads data from JDBC."""
        options = PipelineOptions()
        # Register MillisInstant so the timestamp logical type produced by the Java
        # ReadFromJdbc transform can be decoded on the Python side.
        LogicalType.register_logical_type(MillisInstant)
        with Pipeline(options=options) as p:
            output = (
                p
                | 'Read from jdbc' >> ReadFromJdbc(
                    table_name=self.table_name,
                    driver_class_name=self.driver,
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
            if self.ROW_COUNT < 20:
                _ = output | beam.Map(print_row)
            _ = output | beam.combiners.Count.Globally() | beam.Map(print)
    def run_read_partition(self):
        """Pipeline that reads data from JDBC using partitioned reads."""
        LogicalType.register_logical_type(MillisInstant)
        options = PipelineOptions()
        with Pipeline(options=options) as p:
            output = (
                p
                | 'Partitioned Read from jdbc' >> ReadFromJdbc(
                    table_name=self.table_name,
                    driver_class_name=self.driver,
                    partition_column='f_id',
                    # Between 4 and 100 partitions, scaling with the row count.
                    partitions=max(4, min(100, int(self.ROW_COUNT / 100_000))),
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
            if self.ROW_COUNT < 20:
                _ = output | beam.Map(print_row)
            _ = output | beam.combiners.Count.Globally() | beam.Map(print)
    def run_jdbc_to_bigquery_storage_write(self):
        """Pipeline that reads data from JDBC and writes it to BigQuery using STORAGE_WRITE_API."""
        LogicalType.register_logical_type(MillisInstant)
        options = PipelineOptions()
        with Pipeline(options=options) as p:
            rows = (
                p
                | 'Read from jdbc' >> ReadFromJdbc(
                    table_name=self.table_name,
                    driver_class_name=self.driver,
                    partition_column='f_id',
                    partitions=max(4, min(100, int(self.ROW_COUNT / 100_000))),
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
            # rows is a PCollection of NamedTuples; convert each to a dict for WriteToBigQuery.
            _ = rows | beam.Map(row_to_dict) | WriteToBigQuery(
                table=self.dataset + '.' + self.table_name,
                method=WriteToBigQuery.Method.STORAGE_WRITE_API,
                schema=SCHEMA)
    def run_jdbc_to_bigquery_storage_file_load(self):
        """Pipeline that reads data from JDBC and writes it to BigQuery using FILE_LOADS."""
        LogicalType.register_logical_type(MillisInstant)
        options = PipelineOptions()
        with Pipeline(options=options) as p:
            rows = (
                p
                | 'Read from jdbc' >> ReadFromJdbc(
                    table_name=self.table_name,
                    driver_class_name=self.driver,
                    partition_column='f_id',
                    partitions=max(4, min(100, int(self.ROW_COUNT / 100_000))),
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
            # rows is a PCollection of NamedTuples; the default temp_file_format (JSON)
            # also works here.
            _ = rows | beam.Map(row_to_dict_patch_byte) | WriteToBigQuery(
                table=self.dataset + '.' + self.table_name,
                temp_file_format=bigquery_tools.FileFormat.AVRO,
                schema=SCHEMA)
    def run_jdbc_read_unsupported_types(self):
        """Write and read logical types unsupported by ReadFromJdbc by CASTing them in the query."""
        import jdbc_test_tool
        table_name = 'test_unsupported'
        jdbc_schema = '(f_id INTEGER, f_date DATE, f_time TIME, f_timestamp TIMESTAMP)'
        jdbc_test_tool.create_table(
            table_name=table_name, schema=jdbc_schema,
            username=self.username, password=self.password, endpoint=self.endpoint)
        jdbc_test_tool.query(
            f'INSERT INTO {table_name} VALUES'
            "(1, '2023-05-01', '01:01:01', NULL),"
            "(2, '2023-06-15', '12:34:56', '2023-06-15T12:34:56.789Z')")
        LogicalType.register_logical_type(MillisInstant)
        options = PipelineOptions()
        with Pipeline(options=options) as p:
            rows = (
                p
                | 'Read from jdbc' >> ReadFromJdbc(
                    query=f"SELECT f_id, CAST(f_date AS TEXT), CAST(f_time AS TEXT), f_timestamp FROM {table_name}",
                    table_name=table_name,
                    driver_class_name=self.driver,
                    jdbc_url=self.jdbc_url.replace('localhost', 'host.docker.internal'),
                    username=self.username,
                    password=self.password))
            _ = rows | beam.Map(print)
# Run locally:
#   python jdbcxlang.py --temp_location=gs://<bucket_name>/temp
# Run on Dataflow:
#   python jdbcxlang.py --runner=DataflowRunner --project=<gcp_project> --region=us-central1 \
#     --temp_location=gs://<bucket_name>/temp --save_main_session
# Optional worker settings:
#   --num_workers=1 --autoscaling_algorithm=NONE
#   --max_num_workers=10
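# Note (assumption): ReadFromJdbc/WriteToJdbc are cross-language transforms backed by a
# Java expansion service, so running this script locally is assumed to require a Java
# runtime plus Docker for the Java SDK harness, which is why the JDBC URL points at
# host.docker.internal instead of localhost.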
if __name__ == '__main__':
    logging.getLogger().setLevel('INFO')
    test_instance = JdbcXlang()
    test_instance.setup()
    test_instance.run_write()
    test_instance.run_read()
    test_instance.run_read_partition()
    test_instance.run_jdbc_to_bigquery_storage_write()
    test_instance.run_jdbc_to_bigquery_storage_file_load()
    test_instance.teardown()