Skip to content

Instantly share code, notes, and snippets.

@pertschuk
Created December 4, 2019 01:33
Show Gist options
  • Save pertschuk/f7f71513540335c83f4284a3acbb02a1 to your computer and use it in GitHub Desktop.
Save pertschuk/f7f71513540335c83f4284a3acbb02a1 to your computer and use it in GitHub Desktop.
import itertools
import logging

from google.cloud import language
from google.cloud.language import enums
from google.cloud.language import types
SUBSET_SIZE = 10000  # the number of passages to classify
COLLECTION_PATH = './collectionandqueries/collection.tsv'
OUTPUT_PATH = './categories.tsv'

logger = logging.getLogger(__name__)


def classify_collection(collection, outfile, classify, limit=SUBSET_SIZE):
    """Classify the first *limit* passages of a tab-separated collection.

    Each input line is expected to be ``doc_id<TAB>passage text``. For every
    category returned by *classify*, one line
    ``doc_id<TAB>category name<TAB>confidence`` is written to *outfile*.

    Args:
        collection: Iterable of text lines (e.g. an open file).
        outfile: Writable text stream that receives the result rows.
        classify: Callable taking the passage text and returning an iterable
            of category objects exposing ``name`` and ``confidence``.
        limit: Maximum number of input lines to read. Malformed lines and
            failed classifications count toward it. ``None`` reads all lines.

    Malformed lines and per-passage classification errors are logged and
    skipped, so one bad passage does not abort the whole batch.
    """
    # islice reads exactly `limit` lines (the original `i > SUBSET_SIZE`
    # check let one extra passage through) and stops reading the file early.
    for lineno, raw in enumerate(itertools.islice(collection, limit), start=1):
        # partition keeps any tabs inside the passage text; the newline is
        # stripped so it is not sent to the classifier.
        doc_id, sep, doc_text = raw.rstrip('\n').partition('\t')
        if not sep:
            logger.warning('skipping malformed line %d: %r', lineno, raw)
            continue
        try:
            categories = list(classify(doc_text))
        except Exception as err:  # best-effort batch: log and move on
            logger.warning('classification failed for doc %s: %s', doc_id, err)
            continue
        # Writes stay outside the try so output errors are not hidden.
        outfile.writelines(
            f'{doc_id}\t{cat.name}\t{cat.confidence}\n' for cat in categories
        )


def main():
    """Classify the first SUBSET_SIZE passages with the Cloud Natural Language API.

    Reads COLLECTION_PATH and overwrites OUTPUT_PATH with the result rows.
    NOTE(review): relies on the ``enums``/``types`` modules imported at the
    top of the file, which are part of the google-cloud-language 1.x API.
    """
    client = language.LanguageServiceClient()

    def classify(text):
        document = types.Document(
            content=text,
            type=enums.Document.Type.PLAIN_TEXT)
        return client.classify_text(document).categories

    with open(COLLECTION_PATH, encoding='utf-8') as collection, \
            open(OUTPUT_PATH, 'w', encoding='utf-8') as outfile:
        classify_collection(collection, outfile, classify)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment