Step 2. Building the dataset
In part 1 of this tutorial we laid out the project in detail, and decided that in order to calculate the similarity between two subreddits, we just need to find the list of users on each one.
However, this methodology becomes a computational problem when you consider the magnitude of Reddit. Reddit has more than 300,000 subreddits and more than 100 million users (6% of the US population), so we need to narrow down the data a bit.
One way to narrow it down is to analyze only the subreddits that are big enough. Following the long tail principle, there are many subreddits with just a couple of redditors, and we are not interested in them (not yet, at least).
To do so, I built a list of all subreddits on Reddit and their subscriber counts. Here are the Python scripts I used.
Note: For this project, I used some web scraping methods that are not friendly to the target sites; these scripts request a lot of information from the servers that host them. Reddit has an awesome API, which should be used whenever possible.
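If you do hit the JSON listing endpoints directly, as the scripts below do, at least send a descriptive User-Agent and pace your requests. Here is a minimal sketch of such a helper (the function name and header value are illustrative, not part of the original scripts):
```prettyprint lang-python
import time

import requests

# A descriptive User-Agent, as Reddit's API rules ask for (illustrative value)
HEADERS = {'User-Agent': 'subreddit recommender tutorial script'}


def polite_get(url, delay=2):
    '''Fetch a JSON page with a custom User-Agent and a pause between requests.'''
    response = requests.get(url, headers=HEADERS, timeout=5)
    time.sleep(delay)  # keep the request rate low
    return response.json()
```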
This script gets a list of available subreddits and saves them to a JSON file:
```prettyprint lang-python
'''Gets a list of all subreddits'''
import requests
import json
import time


def add_subs(content, subs):
    '''Appends the subreddit names in one listing page to subs (and subs.txt).
    Returns the number of subreddits found on that page.'''
    children = content.get('data', {}).get('children', []) if content else []
    for sub in children:
        name = sub.get('data').get('display_name')
        subs.append(name)
        file.write(name)
        file.write('\n')
    return len(children)


def get_url(url):
    '''Requests a listing page, retrying every 60 seconds on failure.'''
    while 1:
        try:
            response = requests.get(url, timeout=5)
            return json.loads(response.content)
        except (requests.RequestException, ValueError):
            time.sleep(60)

subs = []
url = 'http://www.reddit.com/reddits.json'
file = open('subs.txt', 'w')

content = get_url(url)
count = add_subs(content, subs)

# Paginate through the rest of the listing with the 'count' and 'after' parameters
base_url = 'http://www.reddit.com/reddits.json?count={0}&after={1}'
while 1:
    after = content.get('data', {}).get('after') if content else None
    if after:
        url = base_url.format(count, after)
        print url
        content = get_url(url)
        count += add_subs(content, subs)
    else:
        break

file.close()

with open('subs.json', 'w') as file:
    json.dump(subs, file)
```
Now that we have a list of all subreddits, we need to find the subscriber count for each one.
I created a Postgres database to store all the Reddit information, and used the following script to read the subs.json file we just created and, for each sub, find the number of subscribers and write it to a table. I used the awesome [PRAW](https://praw.readthedocs.org/en/v2.1.19/) library to interact with Reddit's API:
```prettyprint lang-python
import json
import time

import numpy as np
import praw
import psycopg2
from praw.handlers import MultiprocessHandler

# Fill in your own database settings
DB_NAME = '<YOUR_DB_NAME>'
DB_USER = '<YOUR_DB_USER>'
DB_HOST = '<YOUR_DB_HOST>'
DB_PWD = '<YOUR_DB_PWD>'


class Save_to_db():

    connection = None
    cursor = None

    def __init__(self, host, dbname, user_name, password):
        # Connect to an existing database
        self.connection = psycopg2.connect(dbname=dbname, user=user_name,
                                           host=host, password=password)
        self.cursor = self.connection.cursor()

    def __del__(self):
        if self.cursor is not None:
            self.cursor.close()
        if self.connection is not None:
            self.connection.close()

    def insert(self, table_name, entry_list):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        # Fill in the new values
        subquery = ", ".join(repr(point) for point in entry_list)
        query = 'INSERT INTO %s VALUES (%s);' % (table_name, subquery)
        self.cursor.execute(query)
        # Write the new info to db
        self.connection.commit()

    def execute_sql(self, sql_query):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        self.cursor.execute(sql_query)
        self.connection.commit()

    def retrieve_sql(self, sql_query):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        self.cursor.execute(sql_query)
        return self.cursor.fetchall()

    def create_table(self, table_name, col_dict):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        # Build the 'col_name col_type, ...' part of the CREATE TABLE statement
        columns = ', '.join('{0} {1}'.format(name, col_type)
                            for name, col_type in col_dict.items())
        sql_query = 'CREATE TABLE {0} ({1});'.format(table_name, columns)
        self.cursor.execute(sql_query)
        self.connection.commit()


handler = MultiprocessHandler()
r = praw.Reddit(user_agent='Manugarri Reddit Recommendation Engine', handler=handler)

with open('subs.json', 'r') as file:
    subs = json.load(file)
subs = np.array(subs)

# We skip the subs already stored, in case our script fails and we need to start again.
# The subscriber_count table (name, subscribers) is assumed to already exist.
saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
existing_subs = np.array(saver.retrieve_sql('SELECT name from subscriber_count;'))
existing_subs = existing_subs.flatten()
subs = np.setdiff1d(subs, existing_subs)
print('{0} subs left'.format(len(subs)))

for sub in subs:
    try:
        data = r.get_subreddit(sub, fetch=True)
        subscribers = data.subscribers
        # Insert the (name, subscribers) pair only if the sub is not already stored
        sql_query = '''INSERT INTO subscriber_count
            (name, subscribers)
        SELECT '{0}', {1}
        WHERE
            NOT EXISTS (
                SELECT name FROM subscriber_count WHERE name = '{0}'
            );'''.format(sub, subscribers)
        saver.execute_sql(sql_query)
        print sub, ':', subscribers
    except Exception as e:
        print e, sub
        time.sleep(10)

del saver
```
At this point, we know the subscriber information for all the subreddits.
Let's see what the subreddit distribution looks like:
(xkcd style, because why not)
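Here is a minimal sketch of how such a plot could be drawn from the subscriber_count table, using matplotlib's xkcd mode (the exact plotting code is an assumption, not one of the scripts above):
```prettyprint lang-python
import psycopg2
import matplotlib.pyplot as plt

# Assumes the subscriber_count table and the DB_* constants defined earlier
connection = psycopg2.connect(dbname=DB_NAME, user=DB_USER,
                              host=DB_HOST, password=DB_PWD)
cursor = connection.cursor()
cursor.execute('SELECT subscribers FROM subscriber_count;')
subscribers = [row[0] for row in cursor.fetchall()]
connection.close()

with plt.xkcd():  # hand-drawn style, because why not
    plt.hist(subscribers, bins=100)
    plt.xlabel('subscribers')
    plt.ylabel('number of subreddits')
    plt.title('Subreddit subscriber counts')
    plt.show()
```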
As you can see, the vast majority of subs (around 200,000 of them) have fewer than 200 subscribers. These are probably subreddits that are either new or too niche to be relevant for our recommendation engine.
To narrow the dataset down a little, let's focus on the subreddits with 10,000 subscribers or more.
(kudos to yhat for the ggplot port, btw)
That will do. The distribution looks much more balanced now.
That leaves us with 1,400 subreddits that will be stored and recommended to the recommendation engine's users.
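The scripts below read this filtered list from a file called subs_db.json. A minimal sketch of how that file could be produced from the subscriber_count table (the query and filename are assumptions based on what the later scripts expect; the 10,000 threshold is the one chosen above):
```prettyprint lang-python
import json

import psycopg2

# Assumes the subscriber_count table and the DB_* constants defined earlier
connection = psycopg2.connect(dbname=DB_NAME, user=DB_USER,
                              host=DB_HOST, password=DB_PWD)
cursor = connection.cursor()
cursor.execute('SELECT name FROM subscriber_count WHERE subscribers >= 10000;')
subs_db = [row[0] for row in cursor.fetchall()]
connection.close()

# Save the filtered subreddit list for the scraping scripts below
with open('subs_db.json', 'w') as f:
    json.dump(subs_db, f)
```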
So we need to get a list of redditors for each of the subreddits with 10,000+ subscribers.
The process will look like:
1. First we instantiate the redditors table (we will be using the Save_to_db class from the previous block):
```prettyprint lang-python
from collections import OrderedDict
import json

# create the table (just once); Save_to_db and the DB_* constants
# come from the previous script
with open('subs_db.json') as file:
    subs = json.load(file)

col_dict = OrderedDict()
col_dict['redditor'] = 'varchar(20)'
for sub in subs:
    # quote the column names so mixed-case subreddit names are preserved in Postgres
    col_dict['"{}"'.format(sub)] = 'smallint'

saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
saver.create_table('redditors', col_dict)
```
Side Note: One mistake I made is that I initially created the redditors table as a very wide table, with a row per redditor and a column per subreddit. This approach turned out to be very problematic; I will explain later how to deal with it.
2. Open the file with the final 1,400 subreddits.
3. For each sub:
    3a. Get the latest comments on that sub.
    3b. Get the user who wrote each comment.
4. Store that user's subreddit in the redditors table.
Here is the code I used for steps 2 to 4 (it also relies on the Save_to_db class I included before). To accelerate the process, I used the multiprocessing library to spin up multiple workers.
```prettyprint lang-python
from collections import OrderedDict
import json
from datetime import datetime
import multiprocessing
from random import shuffle

import numpy as np
import psycopg2
import praw
from praw.handlers import MultiprocessHandler

# Save_to_db and the DB_* constants are the ones defined in the earlier script


def store_redditor(sub, redditor, saver):
    # Set the sub's column to 1 for this redditor, inserting the row if it is new
    query = '''
    DO
    $BODY$
    BEGIN
        IF EXISTS (SELECT "{sub}" FROM redditors WHERE redditor = '{redditor}') THEN
            UPDATE redditors SET "{sub}" = 1 WHERE redditor = '{redditor}';
        ELSE
            INSERT INTO redditors (redditor, "{sub}") VALUES ('{redditor}', 1);
        END IF;
    END;
    $BODY$
    '''.format(redditor=redditor, sub=sub)
    saver.execute_sql(query)


def sub_in_db(sub, saver):
    # A sub counts as already processed if any redditor has a 1 in its column
    sub_col = np.array(saver.retrieve_sql('SELECT "{}" FROM redditors'.format(sub)))
    if 1 in sub_col:
        print 'Sub {0} in Database'.format(sub)
        return True
    else:
        print 'Sub {0} NOT in Database'.format(sub)
        return False


def get_redditors(sub, r, saver):
    # Store the authors of the sub's latest activity (up to ~180 days old)
    sub_redditors = []
    sub_info = r.get_subreddit(sub)
    comments = sub_info.get_new(limit=None)
    for c in comments:
        time_old = datetime.now() - datetime.fromtimestamp(c.created)
        if time_old.total_seconds() / (3600 * 24) >= 180:
            break
        if c.author is None:  # skip deleted authors
            continue
        redditor = c.author.name
        if redditor not in sub_redditors:
            store_redditor(sub, redditor, saver)
            print sub, redditor
            sub_redditors.append(redditor)


with open('subs_db.json') as file:
    subs = json.load(file)

handler = MultiprocessHandler()


def praw_process(handler):
    # Each worker opens its own database connection and PRAW session
    saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                       password=DB_PWD)
    r = praw.Reddit(user_agent='Manugarri Recommendation Engine', handler=handler)
    for sub in subs:
        if not sub_in_db(sub, saver):
            get_redditors(sub, r, saver)


if __name__ == '__main__':
    jobs = []
    for i in range(10):  # we spawn 10 workers
        p = multiprocessing.Process(target=praw_process, args=(handler,))
        jobs.append(p)
        p.start()
```
5. Once we have enough redditors, we will directly fetch those redditors' comments and update the subreddits they are commenting on.
Here is the code I used to do so. It is very similar to the one before, but now the input is a list of redditors and we parse their profile pages.
```prettyprint lang-python
from collections import OrderedDict
import json
import logging
import multiprocessing

import psycopg2
import praw
from praw.handlers import MultiprocessHandler

# Save_to_db and the DB_* constants are the ones defined in the earlier scripts


def get_comments_subs(username, r):
    # Get the list of subs the user has commented on
    comlist = []
    try:
        comments = r.get_redditor(username).get_comments(limit=1000)
        comlist.extend(comments)
        comlist = list(set([str(i.subreddit) for i in comlist]))
    except Exception:
        logging.error('Redditor {} not found'.format(username))
    return comlist


def get_submissions_subs(username, r):
    # Get the list of subs the user has submitted to
    postlist = []
    try:
        submissions = r.get_redditor(username).get_submitted(limit=1000)
        postlist.extend(submissions)
        postlist = list(set([str(i.subreddit) for i in postlist]))
    except Exception:
        logging.error('Redditor {} not found'.format(username))
    return postlist


def get_redditor_subs_praw(redditor, reddit_session):
    print redditor
    subs = []
    try:
        subs.extend(get_comments_subs(str(redditor), reddit_session))
        subs.extend(get_submissions_subs(str(redditor), reddit_session))
    except Exception:
        logging.exception('')
    # Only keep the subs we are tracking
    subs = [sub for sub in subs if sub in SUBS_DB]
    return subs


def update_redditor(redditor, subs, lock):
    values = [1 for i in range(len(subs))]
    insert_dict = OrderedDict(zip(subs, values))
    insert_dict['redditor'] = "'" + redditor + "'"
    connection = psycopg2.connect(dbname=DB_NAME, user=DB_USER, host=DB_HOST,
                                  password=DB_PWD, sslmode='require')
    cursor = connection.cursor()
    columns = ', '.join(['"' + key + '"' for key in insert_dict.keys()])
    values = ', '.join([str(value) for value in insert_dict.values()])
    insert_query = 'INSERT INTO {} ({}) VALUES ({});'.format('redditors2',
                                                             columns, values)
    # Only the sub columns get set to 1 on update, not the redditor name itself
    update_values = ', '.join(['"' + key + '" = 1' for key in insert_dict
                               if key != 'redditor'])
    update_query = '''UPDATE redditors2 SET {} WHERE redditor = '{}';'''.format(
        update_values, redditor)
    # Update the row if the redditor already exists, insert it otherwise
    query = '''
    DO
    $BODY$
    BEGIN
        IF EXISTS (SELECT redditor FROM redditors2 WHERE redditor = '{}') THEN
            {}
        ELSE
            {}
        END IF;
    END;
    $BODY$
    '''.format(redditor, update_query, insert_query)
    cursor.execute(query)
    connection.commit()
    print redditor, len(subs)


def worker_get_redditor_subs(job_id, redditor_list, lock, handler):
    reddit_session = praw.Reddit(user_agent='Manugarri Recommendation Engine',
                                 handler=handler)
    for redditor in redditor_list:
        subs = get_redditor_subs_praw(redditor, reddit_session)
        if len(subs) > 0:
            update_redditor(redditor, subs, lock)


def dispatch_jobs(data, job_number, function):
    def chunks(l, n):
        return [l[i:i + n] for i in range(0, len(l), n)]

    total = len(data)
    chunk_size = max(1, total / job_number)
    slices = chunks(data, chunk_size)
    jobs = []
    lock = multiprocessing.Lock()
    handler = MultiprocessHandler()
    for i, s in enumerate(slices):
        j = multiprocessing.Process(target=function, args=(i, s, lock, handler))
        jobs.append(j)
    for j in jobs:
        j.start()


if __name__ == '__main__':
    with open('subs_db.json') as file:
        SUBS_DB = json.load(file)
    saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                       password=DB_PWD)
    # We use 2 different tables (redditors and redditors2) in case the process is interrupted
    redditors = set(row[0] for row in saver.retrieve_sql(
        'SELECT DISTINCT redditor FROM redditors;'))
    existing_redditors = set(row[0] for row in saver.retrieve_sql(
        'SELECT DISTINCT redditor FROM redditors2;'))
    redditors = list(redditors - existing_redditors)
    redditors.sort()
    n_workers = 20
    dispatch_jobs(redditors, n_workers, worker_get_redditor_subs)
```
So by now we have a set of subreddits, and for each redditor a list of the subreddits they have commented on.
The next step will be about how to compute the similarity; I will explain that in part 3.
Side note.
At this point in the project, I realized that it would be much easier if, instead of having the redditors table in the following format:
```
redditor  | sub1 | sub2 |     --> a column per subreddit
redditor1 |  0   |  0   |     --> 1 when the redditor has commented on the sub
redditor2 |  1   |  0   |     --> 0 if the redditor has not commented on the sub
...
```
we gave it a long structure (that is, fewer columns and more rows):
```
redditor  | sub
redditor1 | sub1
redditor1 | sub234
redditor2 | sub456
...
```
So I used the following script to read the table as a CSV, change the format, and save it to another CSV, which I then imported back into the database:
```prettyprint lang-python
redditors = open('redditors_wide.csv')
redditors_long = open('redditors_long.csv', 'w')

# The header of the wide csv holds the column names (redditor, sub1, sub2, ...)
columns = redditors.readline()
columns = columns.replace('\n', '')
columns = columns.split(',')

redditors_long.write('redditor,sub\n')
for line in redditors:
    line = line.replace('\n', '')
    line = line.split(',')
    redditor = line[0]
    # Columns holding a 1 are the subs this redditor has commented on
    indices = [i for i, j in enumerate(line) if j == '1']
    subs = [columns[i] for i in indices]
    for sub in subs:
        redditors_long.write('{},{}\n'.format(redditor, sub))

redditors.close()
redditors_long.close()
```
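The reimport step itself isn't shown above; here is a minimal sketch of how the long CSV could be loaded back into Postgres with psycopg2 (the redditors_long table name and column sizes are assumptions):
```prettyprint lang-python
import psycopg2

# Assumes the DB_* constants defined earlier in this post
connection = psycopg2.connect(dbname=DB_NAME, user=DB_USER,
                              host=DB_HOST, password=DB_PWD)
cursor = connection.cursor()

# Hypothetical target table: one row per (redditor, sub) pair
cursor.execute('CREATE TABLE IF NOT EXISTS redditors_long '
               '(redditor varchar(20), sub varchar(50));')

# Bulk-load the csv we just wrote, skipping its header row
with open('redditors_long.csv') as f:
    cursor.copy_expert("COPY redditors_long FROM STDIN WITH CSV HEADER", f)

connection.commit()
connection.close()
```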