Skip to content

Instantly share code, notes, and snippets.

@staccDOTsol
Created October 27, 2020 05:45
Show Gist options
  • Save staccDOTsol/361bed7fb513c255d3864e4e5efab525 to your computer and use it in GitHub Desktop.
def extract_submissions(fld_bz2, fld_split, which, size=2e5):
    """Fetch top submissions and their top-level comments for one subreddit
    via PRAW, shard them to TSV files, and append (context, response)
    training lines to the dump TSV.

    Args:
        fld_bz2: folder containing the RS_*.bz2 dumps (only used to build
            ``path_in``, which the function never actually reads).
        fld_split: output folder for the rs_sub*/rc_sub* shards and the
            per-dump TSV.
        which: subreddit name passed to ``reddit.subreddit()``.
        size: number of submission ids per shard before flushing to disk.

    Returns:
        None.  Output is written to files under ``fld_split``.

    NOTE(review): relies on module-level globals defined elsewhere in the
    script: ``reddit``, ``args``, ``norm_sentence``, ``get_submission_id``,
    ``get_comment_id``, ``fields_subm``, ``fields_comm``, ``bl_words``,
    ``keys_rm`` — confirm they exist before running in isolation.
    """
    path_in = fld_bz2 + '/RS_%s.bz2' % args.dump_name  # kept for parity; never read
    # The original referenced an undefined `path_out` for the periodic flush
    # (a silent NameError); flush to the same dump TSV the final write appends to.
    path_out = fld_split + '/%s.tsv' % args.dump_name
    n = 0        # submissions seen
    m = 0        # submissions kept
    n2 = 0       # comments seen
    m2 = 0       # comments kept
    sub = 0      # shard index
    sid2 = []    # submission ids accumulated for the current shard
    sids = []    # one set of submission ids per flushed shard
    lines = []   # pending submission TSV lines for the current shard
    try:
        submissions = dict()
        # Distinct name: the original reused `subreddit` for both the PRAW
        # object and a plain string later on.
        subreddit_obj = reddit.subreddit(which)
        for submission2 in subreddit_obj.top(limit=5000):
            try:
                n += 1
                try:
                    submission = {
                        "id": submission2.id,
                        "score": submission2.score,
                        "domain": submission2.domain,
                        "permalink": submission2.permalink,
                        "title": submission2.title,
                        "num_comments": submission2.num_comments,
                    }
                    if int(submission['num_comments']) >= 2:  # filter 1: needs replies
                        submission['title'] = norm_sentence(submission['title'], True)
                        submissions[get_submission_id(submission)] = submission
                        lines.append('\t'.join(str(submission[k]) for k in fields_subm))
                        m += 1
                        sid2.append(get_submission_id(submission))
                        # >= (not ==): a float `size` or any overshoot still flushes.
                        if len(sid2) >= size:
                            sids.append(set(sid2))
                            with open(fld_split + '/rs_sub%i.tsv' % sub, 'w', encoding='utf-8') as f:
                                f.write('\n'.join(lines))
                            sid2 = []
                            lines = []
                except Exception as e:
                    print(e)
                    traceback.print_exc()
                    continue  # can't record this submission; skip its comments too

                lines2 = []       # comment TSV lines for this submission
                comments = dict()
                for top_level_comment in submission2.comments:
                    try:
                        n2 += 1
                        comment = {"id": top_level_comment.id}
                        try:
                            if top_level_comment.author is not None:
                                comment["author"] = top_level_comment.author.name
                            else:
                                comment["author"] = "None"
                        except Exception:
                            # Author lookup can fail on deleted/suspended accounts.
                            comment["author"] = "None"
                        comment["parent_id"] = top_level_comment.parent_id
                        try:
                            comment["link_id"] = top_level_comment.link_id
                            comment["score"] = top_level_comment.score
                            comment["body"] = top_level_comment.body
                        except Exception:
                            comment["link_id"] = comment["parent_id"]
                            comment["score"] = 0
                            comment["body"] = ""
                        if comment['body'] != '[deleted]':  # filter 1: drop deleted bodies
                            comment['n_char'] = len(comment['body'])
                            comment['body'] = norm_sentence(comment['body'], True)
                            if len(comment['body'].split()) >= 2:  # filter 2: non-trivial body
                                comments[get_comment_id(comment)] = comment
                                lines2.append('\t'.join(str(comment[k]) for k in fields_comm))
                                m2 += 1
                    except Exception as e:
                        print(e)
                        traceback.print_exc()

                # Sort so comments group by submission, then by parent.
                sorted_id = sorted(
                    (comments[cid]['link_id'], comments[cid]['parent_id'], cid)
                    for cid in comments
                )
                # Fresh locals: the original reassigned the outer n / m / lines
                # here, destroying the pending submission-shard lines so the
                # final shard write emitted conversation lines instead.
                n_comments = len(comments)
                i = 0
                n_convos = 0
                conv_lines = []
                sum_resp_len = 0   # total response tokens (avg-length stat, currently unused)
                skip_id = {}       # parent ids; marks non-leaf comments (used only by commented-out logic)
                if args.leaves_only:
                    for _, pid, _ in sorted_id:
                        skip_id[pid] = 1
                for sid, pid, cid in sorted_id:
                    i += 1
                    if i % 1e5 == 0:
                        # Periodic flush keeps memory bounded on huge threads.
                        if len(conv_lines) > 0:
                            with open(path_out, 'a', encoding="utf-8") as f:
                                f.write('\n'.join(conv_lines) + '\n')
                            conv_lines = []
                    subreddit_name = ''
                    domain = ''
                    if sid in submissions:
                        # permalink looks like /r/<name>/..., so index 2 is the name.
                        subreddit_name = submissions[sid]['permalink'].split('/')[2].lower()
                        domain = submissions[sid]['domain'].lower()
                    info = subreddit_name + '\t' + domain  # only used by commented-out logging
                    comment = comments[cid]
                    score = 0 if comment['score'] == 'None' else int(comment['score'])
                    if score < args.min_score:  # filter 1: low-score responses
                        continue
                    txts = [comments[c]['body'] for c in comments]
                    if len(txts) < args.min_depth:  # filter 3: conversation too shallow
                        continue
                    # Use j, not i: the original reused i here, corrupting the
                    # periodic-flush counter above.
                    for j in range(len(txts)):
                        txts[j] = norm_sentence(txts[j], False)
                        if args.leaves_only and args.clean:
                            sc = '1.0'
                            skip_target = False
                            if args.discard_tgt_keys:
                                tgt_h = hashlib.sha224(txts[j].encode("utf-8")).hexdigest()
                                if tgt_h in keys_rm.keys():
                                    skip_target = True
                            if bl_words.extract_keywords(txts[j]) or skip_target:
                                sc = '0.0'  # flag blacklisted/discarded targets
                            txts[j] = sc + ' ' + txts[j]
                    src = ' EOS '.join(txts[:-1])  # context turns
                    tgt = txts[-1]                 # response turn
                    header = ','.join([sid, pid, cid])
                    conv_lines.append(header + '\t' + src + '\t' + tgt)
                    sum_resp_len += len(tgt.split())
                    n_convos += 1
                with open(fld_split + '/%s.tsv' % args.dump_name, 'a', encoding="utf-8") as f:
                    f.write('\n'.join(conv_lines) + '\n')
                with open(fld_split + '/rc_sub%i.tsv' % sub, 'a', encoding='utf-8') as f:
                    f.write('\n'.join(lines2))
            except Exception as e:
                print(e)
                traceback.print_exc()

        # Flush the final, partial shard.  (The original appended set(sid) —
        # the characters of a single id string — instead of the id list sid2.)
        sids.append(set(sid2))
        with open(fld_split + '/rs_sub%i.tsv' % sub, 'a', encoding='utf-8') as f:
            f.write('\n'.join(lines))
        lines = []
        sub += 1
    except Exception as e:
        print(e)
    print('extract_submissions done.\n')
    return
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.