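# Imports the code below needs but the gist omits, plus notes on the names it
# assumes are defined elsewhere. This preamble is a sketch of the missing
# context, not part of the original file.
import hashlib
import traceback

import praw  # assumption: `reddit` below is an authenticated praw.Reddit client

# Also assumed to exist at module level (not defined in this gist):
#   args            - parsed CLI options (dump_name, min_score, min_depth,
#                     leaves_only, clean, discard_tgt_keys)
#   reddit          - a praw.Reddit instance
#   norm_sentence, get_submission_id, get_comment_id - text/id helpers
#   fields_subm, fields_comm - column orders for the submission/comment TSVs
#   bl_words        - keyword extractor for blacklisted words
#   keys_rm         - set of target hashes to discard
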
def extract_submissions(fld_bz2, fld_split, which, size=2e5):
    path_in = fld_bz2 + '/RS_%s.bz2' % args.dump_name  # unused here; leftover from the bz2-dump variant of this script
    path_out = fld_split + '/%s.tsv' % args.dump_name  # the gist referenced an undefined path_out; assuming the conversation file written below
    n = 0       # submissions seen
    m = 0       # submissions kept
    n2 = 0      # comments seen
    m2 = 0      # comments kept
    sub = 0     # index of the current split file
    sid2 = []   # submission ids collected for the current split
    sids = []   # one set of ids per completed split
    lines = []  # submission TSV rows for the current split
    try:
        submissions = dict()
        subreddit = reddit.subreddit(which)
        for submission2 in subreddit.top(limit=5000):
            try:
                n += 1
                try:
                    # copy the fields we need out of the praw submission object
                    submission = {
                        'id': submission2.id,
                        'score': submission2.score,
                        'domain': submission2.domain,
                        'permalink': submission2.permalink,
                        'title': submission2.title,
                        'num_comments': submission2.num_comments,
                    }
                    if int(submission['num_comments']) >= 2:  # filter 1: keep submissions with at least 2 comments
                        submission['title'] = norm_sentence(submission['title'], True)
                        submissions[get_submission_id(submission)] = submission
                        lines.append('\t'.join(str(submission[k]) for k in fields_subm))
                        m += 1
                        sid2.append(get_submission_id(submission))
                        if len(sid2) >= size:
                            # the current split is full: remember its ids and flush it
                            sids.append(set(sid2))
                            with open(fld_split + '/rs_sub%i.tsv' % sub, 'w', encoding='utf-8') as f:
                                f.write('\n'.join(lines))
                            sid2 = []
                            lines = []
                            sub += 1  # advance the split index (the gist only did this once, at the very end, so every full split overwrote rs_sub0)
                except Exception as e:
                    print(e)
                    traceback.print_exc()
                    continue

                # collect this submission's top-level comments
                lines2 = []
                comments = dict()
                for top_level_comment in submission2.comments:
                    try:
                        n2 += 1
                        comment = {'id': top_level_comment.id}
                        try:
                            if top_level_comment.author is not None:
                                comment['author'] = top_level_comment.author.name
                            else:
                                comment['author'] = 'None'
                        except Exception:
                            comment['author'] = 'None'
                        comment['parent_id'] = top_level_comment.parent_id
                        try:
                            comment['link_id'] = top_level_comment.link_id
                            comment['score'] = top_level_comment.score
                            comment['body'] = top_level_comment.body
                        except Exception:
                            comment['link_id'] = comment['parent_id']
                            comment['score'] = 0
                            comment['body'] = ''
                        # (an optional keep_keys filter from the upstream script is omitted here)
                        if comment['body'] != '[deleted]':  # filter 1: drop deleted comments
                            comment['n_char'] = len(comment['body'])
                            comment['body'] = norm_sentence(comment['body'], True)
                            if len(comment['body'].split()) >= 2:  # filter 2: keep comments with at least 2 tokens
                                comments[get_comment_id(comment)] = comment
                                lines2.append('\t'.join(str(comment[k]) for k in fields_comm))
                                m2 += 1
                    except Exception as e:
                        print(e)
                        traceback.print_exc()

                # sort (link_id, parent_id, comment_id) triples so replies follow their parents
                sorted_id = sorted([(
                    comments[cid]['link_id'],
                    comments[cid]['parent_id'],
                    cid
                ) for cid in comments])
                n_comments = len(comments)  # renamed: the gist reused n/m/lines here, clobbering the submission counters and buffer above
                i = 0
                m_convos = 0
                conv_lines = []
                sum_resp_len = 0
                skip_id = {}
                if args.leaves_only:
                    # any comment that appears as a parent is not a leaf
                    # (kept from the original; skip_id is not consulted below)
                    for _, pid, _ in sorted_id:
                        skip_id[pid] = 1
                for sid, pid, cid in sorted_id:
                    i += 1
                    if i % 1e5 == 0:
                        # periodically flush conversation lines to disk
                        if len(conv_lines) > 0:
                            with open(path_out, 'a', encoding='utf-8') as f:
                                f.write('\n'.join(conv_lines) + '\n')
                            conv_lines = []
                    subreddit_name = ''  # renamed: the gist assigned to `subreddit`, clobbering the praw object above
                    domain = ''
                    if sid in submissions:
                        subreddit_name = submissions[sid]['permalink'].split('/')[2].lower()
                        domain = submissions[sid]['domain'].lower()
                    info = subreddit_name + '\t' + domain
                    # (an optional subreddit-blacklist filter from the upstream script is omitted here)
                    comment = comments[cid]
                    score = 0 if comment['score'] == 'None' else int(comment['score'])
                    if score < args.min_score:  # filter 1: minimum score
                        continue
                    txts = [comments[c]['body'] for c in comments]  # the gist uses all comment bodies rather than walking one reply chain
                    if len(txts) < args.min_depth:  # filter 3: minimum conversation depth
                        continue
                    for j in range(len(txts)):  # renamed from i, which the gist reused and clobbered
                        txts[j] = norm_sentence(txts[j], False)
                        if args.leaves_only and args.clean:
                            # prefix each turn with a 1.0/0.0 quality score
                            sc = '1.0'
                            skip_target = False
                            if args.discard_tgt_keys:
                                tgt_h = hashlib.sha224(txts[j].encode('utf-8')).hexdigest()
                                if tgt_h in keys_rm:
                                    skip_target = True
                            if bl_words.extract_keywords(txts[j]) or skip_target:
                                sc = '0.0'
                            txts[j] = sc + ' ' + txts[j]
                    src = ' EOS '.join(txts[:-1])  # context turns
                    tgt = txts[-1]                 # response turn
                    header = ','.join([sid, pid, cid])
                    conv_lines.append(header + '\t' + src + '\t' + tgt)
                    sum_resp_len += len(tgt.split())
                    m_convos += 1
                with open(fld_split + '/%s.tsv' % args.dump_name, 'a', encoding='utf-8') as f:
                    f.write('\n'.join(conv_lines) + '\n')
                with open(fld_split + '/rc_sub%i.tsv' % sub, 'a', encoding='utf-8') as f:
                    f.write('\n'.join(lines2))
            except Exception as e:
                print(e)
                traceback.print_exc()

        # flush the final, partially filled split
        # (the upstream script follows this with extract_comments and a stat.tsv summary)
        sids.append(set(sid2))  # the gist had set(sid), i.e. the characters of a single id; sid2 is intended
        with open(fld_split + '/rs_sub%i.tsv' % sub, 'a', encoding='utf-8') as f:
            f.write('\n'.join(lines))
        lines = []
        sub += 1
    except Exception as e:
        print(e)
    print('extract_submissions done.\n')
    return
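
# A minimal driver sketch, assuming the module-level names noted in the
# preamble; the credentials, folders, and subreddit here are placeholders,
# not values from the gist.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dump_name', default='2020-10')
    parser.add_argument('--min_score', type=int, default=1)
    parser.add_argument('--min_depth', type=int, default=2)
    parser.add_argument('--leaves_only', action='store_true')
    parser.add_argument('--clean', action='store_true')
    parser.add_argument('--discard_tgt_keys', action='store_true')
    args = parser.parse_args()

    reddit = praw.Reddit(client_id='...',        # placeholder credentials
                         client_secret='...',
                         user_agent='extract_submissions script')

    extract_submissions('bz2', 'out', which='AskReddit', size=2e5)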