import logging
import os
import uuid
from pathlib import Path

from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, udf
from pyspark.sql.types import StringType, StructField, StructType

import cml.data_v1 as cmldata

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
# Tika is a library that can extract text from a file in any of the many formats it supports
from tika import parser, detector, language

os.environ['TIKA_CLIENT_ONLY'] = "True"
os.environ['TIKA_SERVER_ENDPOINT'] = 'http://localhost:9998'
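# NOTE: client-only mode assumes a Tika server is already listening on the
# endpoint above. One way to start it locally (assuming the server jar from
# the Maven URL referenced below has been downloaded; this command is not
# part of the original script) is:
#   java -jar tika-server-standard-2.6.0.jar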
def pdfcontent(file):
    """
    Extracts the plain-text content of a file via the Tika server.
    :param file: path to the document file
    :return: the extracted text, or None if Tika found no content
    """
    # os.environ["TIKA_SERVER_JAR"]='https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/2.6.0/tika-server-standard-2.6.0.jar'
    # tika.initVM()
    return parser.from_file(file)["content"]
# Alternative extractor based on pypdf, kept for reference:
# def pdfcontent(file):
#     from pypdf import PdfReader
#     reader = PdfReader(file)
#     text = ''
#     for page in reader.pages:
#         text += page.extract_text()
#     return text
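# Quick sanity check of the extractor (the path is the same hard-coded test
# file used in _list_documents() below; adjust it to a file that exists):
#   print(pdfcontent("/home/cdsw/docs/test.pdf")[:200])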
class SVError(Exception):
    """
    Base class for errors raised by this job.
    """
    def __init__(self, message=None):
        """
        Constructor
        :param message: the error message.
        """
        # Pass the message to Exception so that str(e) is meaningful
        super().__init__(message)
        self.message = message
class TextExtractionError(SVError):
    """
    Exception raised when text could not be extracted from a file
    :param path: name of the file from which text could not be extracted
    :param message: explanation of the error
    :rtype: object
    """
    def __init__(self, path, message=None):
        if not message:
            message = (
                f"Could not extract text from the file. "
                f"Check if it is a Tika-supported document format: {path}"
            )
        super().__init__(message=message)
        self.path = path
class TextExtraction:
    """
    Extracts the plain text, MIME type and language of a single document,
    and assigns it a unique id.
    """
    def __init__(self, path: str):
        self.path_ = path
        self.subject = os.path.basename(os.path.dirname(path))
        self.text_ = self.to_text(self.path_)
        self.doctype_ = self.document_type(self.path_)
        self.language_ = language.from_buffer(self.text_)
        self.id_ = str(uuid.uuid4())
    @staticmethod
    def to_text(path: str) -> str:
        """
        Extracts plain-text from a file, in one of the Tika-supported formats
        :param path: path to the document file
        :return: text from document file
        """
        # Preconditions check for an existing, readable, non-empty file
        # check_valid_file(path)
        log.info(f"Parsing file: {path}")
        try:
            text_content: str = pdfcontent(path)
            if text_content is None:
                raise TextExtractionError(
                    path=path, message=f"No content found in file: {path}"
                )
            return text_content.strip()
        except TextExtractionError:
            # Already carries a meaningful message; don't re-wrap it
            raise
        except Exception as e:
            raise TextExtractionError(path, str(e)) from e
    @staticmethod
    def document_type(path: str) -> str:
        """
        Determines the MIME type of the file
        :param path: the filesystem path to the document.
        :return: the MIME-type, such as "application/pdf"
        """
        # Preconditions check for an existing, readable, non-empty file
        # check_valid_file(path)
        return detector.from_file(path)
    def __repr__(self):
        limit: int = min(100, len(self.text_))
        return f" Document type: {self.doctype_}\n Language: {self.language_}\n Text: {self.text_[:limit]}..."
# class TextExtractionJob(BootcampComputeJob):
class TextExtractionJob:
    """
    This class is the entry point for the text extraction job.
    Given a directory of documents, it reads all the files in the directory,
    and all the subdirectories recursively, and extracts plain text from each file.
    It then stores the extracted text in a database table.
    """
    def __init__(self):
        self.job_name = "TextExtractionJob"
        log.info(f'Initializing {self.job_name} job')
        CONNECTION_NAME = "eng-ml-dev-env-aws-dl"
        conn = cmldata.get_connection(CONNECTION_NAME)
        self.spark = conn.get_spark_session()
        # Schema of the struct returned by the text-extraction UDF
        self.text_struc = StructType([
            StructField("path", StringType(), True),
            StructField("subject", StringType(), True),
            StructField("text", StringType(), True),
            StructField("doctype", StringType(), True),
            StructField("language", StringType(), True),
            StructField("uuid", StringType(), True)
        ])
    @staticmethod
    def _udf_text_extraction(path):
        """
        A function that extracts text, its document-type and language
        from a file, given its path.
        """
        extraction = TextExtraction(path)
        # The keys must match the field names of text_struc: Spark maps a
        # dict returned from a UDF onto the declared StructType by field name.
        return {"path": path,
                "subject": extraction.subject,
                "text": extraction.text_,
                "doctype": extraction.doctype_,
                "language": extraction.language_,
                "uuid": extraction.id_
                }
    def run(self) -> None:
        """
        This method is the entry point for the compute job where
        the text is extracted from the documents, and stored in a database table.
        :return: None
        """
        log.info(f'Running {self.job_name} job')
        files_df = self._list_documents()
        log.info(f'Extracting text from {files_df.count()} files')
        df = self._extract_text(files_df)
        # self._persist(df=df, table='DOCUMENT')
    def _extract_text(self, files_df: DataFrame) -> DataFrame:
        """
        Extracts plain-text from each file in the DataFrame
        :param files_df: DataFrame containing the list of files
        :return: DataFrame containing the extracted text
        """
        # Step 1: Extract text from each file
        files_df = files_df.withColumn('extract',
                                       udf(self._udf_text_extraction,
                                           self.text_struc)(files_df.value))
        # Step 2: Extract the columns from the nested structure
        df = files_df.select('extract.language',
                             'extract.path',
                             'extract.subject',
                             'extract.doctype',
                             'extract.text',
                             'extract.uuid')
        # Step 3: Rename the columns
        df = df.withColumnRenamed("language", "LANGUAGE") \
            .withColumnRenamed("uuid", "UUID") \
            .withColumnRenamed("path", "PATH") \
            .withColumnRenamed("subject", "SUBJECT") \
            .withColumnRenamed("doctype", "DOCTYPE") \
            .withColumnRenamed("text", "TEXT")
        # Step 4: Add boolean columns that help in later processing
        df = df.withColumn('CHUNKED', lit(False))
        log.info(f'Extracted {df.count()} rows')
        # Step 5: Show the DataFrame
        df.show()
        return df
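    # Note: count() and show() above are separate Spark actions, so each one
    # re-runs the extraction UDF over every file. For anything beyond a
    # handful of documents, caching the result first is worth considering
    # (a suggestion, not part of the original job):
    #   df = df.cache()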
    def _list_documents(self) -> DataFrame:
        """
        Lists the files to process, and returns them as a DataFrame
        :return: DataFrame containing the list of files
        """
        # Step 1: For now, the file list is hard-coded to a single test document
        all_files = ["/home/cdsw/docs/test.pdf"]
        log.info(f'all docs: {all_files}')
        # Step 2: Read all file-names into a Spark DataFrame
        files = [str(file) for file in all_files]
        files_df = self.spark.createDataFrame(files, StringType())
        files_df.show(truncate=False)
        return files_df
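    # To actually walk a directory tree recursively, as the class docstring
    # describes, Step 1 could use the already-imported pathlib instead
    # (a sketch; "/home/cdsw/docs" is assumed to be the document root):
    #   all_files = [p for p in Path("/home/cdsw/docs").rglob("*") if p.is_file()]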
    def describe(self):
        return 'Extracts text from documents in a directory, and stores it in a database table'
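    # Hypothetical sketch of the persistence step referenced (commented out)
    # in run(); it is not part of the original job. It assumes the Spark
    # session is allowed to create or overwrite the target table.
    def _persist(self, df: DataFrame, table: str) -> None:
        df.write.mode("overwrite").saveAsTable(table)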
if __name__ == '__main__':
    job = TextExtractionJob()
    job.run()
    job.spark.stop()