November 12, 2014

Building a Recommendation Engine for Reddit. Part 2

Step 2. Building the dataset

In part 1 of this tutorial we laid out the project in detail and decided that, in order to calculate the similarity between two subs, we just need to find the list of users on each one.
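To make that idea concrete, here is a minimal sketch (with made-up user lists; the actual similarity metric is the topic of part 3) of comparing two subs by the overlap of their user sets:

```prettyprint lang-python
# Illustration only: two hypothetical subs and the users who comment on them
users_a = {'alice', 'bob', 'carol'}
users_b = {'bob', 'carol', 'dave', 'eve'}

# One simple overlap measure is the Jaccard ratio: shared users over total users
similarity = len(users_a & users_b) / float(len(users_a | users_b))
print similarity  # 0.4
```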

However, this methodology becomes a computational problem when you consider the magnitude of Reddit. Reddit has more than 300,000 subreddits and more than 100 million users (6% of the US population), so we need to narrow down the data we collect.

One way to narrow it down is to analyze only the subreddits that are big enough. Following the long tail principle, there are many subreddits with just a couple of redditors, and we are not interested in those (not yet, at least).

To do so, I built a list of all subreddits on Reddit and their subscriber counts. Here are the Python scripts I used.

Note: For this project, I used some web scraping methods that are not friendly to the target sites. These scripts request a lot of information from the servers that host them. Reddit has an awesome API, which should be used whenever possible.

This script gets a list of available subreddits and saves them to a JSON file:

```prettyprint lang-python
'''Gets a list of all subreddits'''
import requests
import json
import time

def add_subs(content, subs, out_file):
    '''Appends every subreddit name in a results page to the subs list,
    writes it to the output file, and returns how many were added.'''
    children = content.get('data').get('children')
    for sub in children:
        name = sub.get('data').get('display_name')
        out_file.write(name + '\n')
        subs.append(name)
    return len(children)

def get_url(url):
    '''Fetches a url and parses it as JSON, retrying after a minute on errors.'''
    while 1:
        try:
            content = requests.get(url, timeout=5)
            return json.loads(content.content)
        except:
            time.sleep(60)


subs = []
out_file = open('subs.txt', 'w')
url = 'http://www.reddit.com/reddits.json'
base_url = 'http://www.reddit.com/reddits.json?count={0}&after={1}'

# The listing is paginated: keep following the 'after' token until it runs out
content = get_url(url)
count = add_subs(content, subs, out_file)
while 1:
    after = content.get('data').get('after')
    if not after:
        break
    url = base_url.format(count, after)
    print url
    content = get_url(url)
    count += add_subs(content, subs, out_file)
out_file.close()

with open('subs.json', 'w') as f:
    json.dump(subs, f)
```

Now that we have a list of all subreddits, we need to find the subscriber count for each one.

I created a Postgres database to store all the Reddit information and used the following script to read the subs.json file we just created and, for each sub, find the number of subscribers and write it to a table. I used the awesome [PRAW](https://praw.readthedocs.org/en/v2.1.19/) library to interact with Reddit's API:

```prettyprint lang-python
import json
import time
import psycopg2
import numpy as np

import praw
from praw.handlers import MultiprocessHandler


DB_NAME = '<YOUR_DB_NAME>'
DB_USER = '<YOUR_DB_USER>'
DB_HOST = '<YOUR_DB_HOST>'
DB_PWD = '<YOUR_DB_PWD>'

class Save_to_db():
    connection = None
    cursor = None

    def __init__(self, host, dbname, user_name, password):
        # Connect to an existing database
        self.connection = psycopg2.connect(dbname=dbname, user=user_name,
                host=host, password=password)
        self.cursor = self.connection.cursor()

    def __del__(self):
        if self.cursor is not None:
            self.cursor.close()

        if self.connection is not None:
            self.connection.close()

    def insert(self, table_name, entry_list):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        # Fill in the new values
        subquery = ", ".join(repr(point) for point in entry_list)
        query = 'INSERT INTO %s VALUES (%s);' % (table_name, subquery)
        self.cursor.execute(query)
        # Write the new info to db
        self.connection.commit()

    def execute_sql(self, sql_query):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        self.cursor.execute(sql_query)
        self.connection.commit()

    def retrieve_sql(self, sql_query):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        self.cursor.execute(sql_query)
        return self.cursor.fetchall()

    def create_table(self, table_name, col_dict):
        if self.cursor is None:
            raise Exception("Invalid connection to database.")
        string = ''
        for item in col_dict.items():
            string = string + item[0] + ' ' + item[1] + ', '
        string = string[:-2]
        sql_query = 'CREATE TABLE {0} ({1});'.format(table_name, string)
        self.cursor.execute(sql_query)
        self.connection.commit()

handler = MultiprocessHandler()
r = praw.Reddit(user_agent='Manugarri Reddit Recommendation Engine', handler=handler)

with open('subs.json', 'r') as file:
    subs = json.load(file)
    subs = np.array(subs)

# We add this in case our script fails and we need to start again
saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
password=DB_PWD)
existing_subs = np.array(saver.retrieve_sql('SELECT name from subscriber_count;'))
existing_subs = existing_subs.flatten()
subs = np.setdiff1d(subs, existing_subs)

print('{0} subs left'.format(len(subs)))

for sub in subs:
    try:
        data = r.get_subreddit(sub, fetch=True)
        subscribers = data.subscribers

        sql_query = '''INSERT INTO subscriber_count
            (name, subscribers)
        SELECT '{0}', {1}
        WHERE
        NOT EXISTS (
            SELECT name FROM subscriber_count WHERE name = '{0}'
        );'''.format(sub, subscribers)
        saver.execute_sql(sql_query)
        print sub, ':', subscribers
    except Exception as e:
        print e, sub
        time.sleep(10)

del saver
```
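Note that the script expects the subscriber_count table to already exist. The post doesn't show how it was created, but a minimal sketch using the Save_to_db class above (the column types are my assumption) could look like this:

```prettyprint lang-python
# Sketch: create the subscriber_count table the script above writes to.
# Column names match the queries above; the types are an assumption.
from collections import OrderedDict

saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
col_dict = OrderedDict()
col_dict['name'] = 'varchar(50)'      # subreddit name
col_dict['subscribers'] = 'integer'   # subscriber count
saver.create_table('subscriber_count', col_dict)
```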

At this point, we know the subscriber information for all the subreddits.

Let's see what the subreddit distribution looks like:

[Figure: subreddit subscriber distribution (xkcd style, because why not)]

As you can see, the vast majority of subs (around 200,000 of them) have fewer than 200 subscribers. These are probably subreddits that are either new or too niche to be relevant for our recommendation engine.
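For reference, here is a rough sketch of how a plot like this could be reproduced, reading the counts from the subscriber_count table and using matplotlib's xkcd mode (the original figure may have been produced differently):

```prettyprint lang-python
# Sketch: histogram of subscriber counts, assuming the subscriber_count table above
import matplotlib.pyplot as plt

saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
rows = saver.retrieve_sql('SELECT subscribers FROM subscriber_count;')
counts = [row[0] for row in rows]

plt.xkcd()  # hand-drawn style, because why not
plt.hist(counts, bins=100)
plt.xlabel('subscribers')
plt.ylabel('number of subreddits')
plt.show()
```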

To narrow the dataset down a little, let’s focus on the subreddits with 10,000 subscribers or more.

[Figure: subscriber count distribution for subs with 10,000+ subscribers (kudos to yhat for the ggplot port, btw)]

That will do. The distribution looks much more even now.

That leaves us with 1,400 subreddits that will be stored and recommended to the recommendation engine's users.

So we need to get a list of redditors for each of the subreddits with 10,000+ subscribers.
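The scripts below read that filtered list from a subs_db.json file. The post doesn't show how the file was produced, but a sketch like the following (reusing the Save_to_db class) would generate it:

```prettyprint lang-python
# Sketch: save the names of all subs with 10,000+ subscribers to subs_db.json
import json

saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
rows = saver.retrieve_sql(
    'SELECT name FROM subscriber_count WHERE subscribers >= 10000;')
subs_db = [row[0] for row in rows]

with open('subs_db.json', 'w') as f:
    json.dump(subs_db, f)
```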

The process looks like this:

1. First we create the redditors table (using the Save_to_db class from the previous block):

```prettyprint lang-python
# create the table (just once)
from collections import OrderedDict
import json

with open('subs_db.json') as file:
    subs = json.load(file)

col_dict = OrderedDict()
col_dict['redditor'] = 'varchar(20)'
for sub in subs:
    col_dict[sub] = 'smallint'

saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
saver.create_table('redditors', col_dict)
```

Side Note: One mistake I made is that I initially created the redditors table as a very wide table, with a row per redditor and a column per subreddit. This approach turned out to be very problematic. I will explain later how to deal with it.

2. Open the file with the final 1,400 subreddits.

3. For each sub:

3a. Get the latest comments on that sub.

3b. Get the user that wrote each comment.

3c. Store that user's subreddit in the redditors table.

Here is the code I used (it also relies on the Save_to_db class included before). To speed up the process I used the multiprocessing library to spin up multiple workers.

```prettyprint lang-python
from collections import OrderedDict
import json
from datetime import datetime
import multiprocessing
from random import shuffle

import numpy as np
import psycopg2
import praw
from praw.handlers import MultiprocessHandler


def store_redditor(sub, redditor, saver):
    query = '''
    DO
    $BODY$
    BEGIN
    IF EXISTS (SELECT "{sub}" from redditors where redditor = '{redditor}' ) THEN
        UPDATE redditors SET "{sub}" = 1 WHERE redditor = '{redditor}';
    ELSE
        INSERT INTO redditors (redditor, "{sub}") VALUES ('{redditor}',1);
    END IF;
    END;
    $BODY$
    '''.format(redditor=redditor, sub=sub)
    saver.execute_sql(query)


def sub_in_db(sub, saver):
    sub_col = np.array(saver.retrieve_sql('SELECT "{}" from\
        redditors'.format(sub)))
    print sub_col
    if 1 in sub_col:
        print 'Sub {0} in Database'.format(sub)
        return True
    else:
        print 'Sub {0} NOT in Database'.format(sub)
        return False

def get_redditors(sub, r, saver):
    '''Walks the newest items on a sub and stores every author seen in
    the last 180 days.'''
    sub_redditors = []
    sub_info = r.get_subreddit(sub)
    comments = sub_info.get_new(limit=None)
    while 1:
        try:
            c = comments.next()
        except StopIteration:
            break
        time_old = (datetime.now() - datetime.fromtimestamp(c.created))
        if time_old.total_seconds()/(3600*24) < 180:
            redditor = c.author.name
            if redditor not in sub_redditors:
                store_redditor(sub, redditor, saver)
                print sub, redditor
                sub_redditors.append(redditor)
        else:
            break

with open('subs_db.json') as file:
    subs = json.load(file)

handler = MultiprocessHandler()
def praw_process(handler):
    saver = Save_to_db(dbname='reddit', user_name='reddit', host='localhost',
    password='reddit')
    r = praw.Reddit(user_agent='Manugarri Recommendation Engine', handler=handler)
    for sub in subs:
        if not sub_in_db(sub, saver):
            get_redditors(sub, r, saver)

if __name__ == '__main__':
    saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
    password=DB_PWD)
    jobs = []
    for i in range(10):  # we spawn 10 workers
        p = multiprocessing.Process(target=praw_process, args=(handler,))
        jobs.append(p)
        p.start()
```

4. Once we have enough redditors, we directly fetch those redditors' comments and update the subreddits they are commenting on.

Here is the code I used to do so. It is very similar to the one before, but now the input is a list of redditors and we parse their profile pages.

```prettyprint lang-python
# Note: this script also uses the Save_to_db class defined earlier in the post
from collections import OrderedDict
import json
import logging
import multiprocessing

import psycopg2
import praw
from praw.handlers import MultiprocessHandler


def get_comments_subs(username, r):
    # Get the list of comments by the user
    comlist = []
    try:
        comments = r.get_redditor(username).get_comments(limit=1000)
        comlist.extend(comments)
        comlist = list(set([str(i.subreddit) for i in comlist]))
    except:
        logging.error('Redditor {} not found'.format(username))
    return comlist

def get_submissions_subs(username, r):
    # Get the list of submissions by the user
    postlist = []
    try:
        submissions = r.get_redditor(username).get_submitted(limit=1000)
        postlist.extend(submissions)
        postlist = list(set([str(i.subreddit) for i in postlist]))
    except:
        logging.error('Redditor {} not found'.format(username))
    return postlist

def get_redditor_subs_praw(redditor, reddit_session):
    print redditor
    subs = []
    try:
        subs.extend(get_comments_subs(str(redditor), reddit_session))
        subs.extend(get_submissions_subs(str(redditor), reddit_session))
    except:
        logging.exception('')
    subs = [sub for sub in subs if sub in SUBS_DB]
    return subs

def update_redditor(redditor, subs, lock):
    values = [1 for i in range(len(subs))]
    insert_dict = OrderedDict(zip(subs, values))
    insert_dict['redditor'] = "'"+redditor+"'"
    # worker_saver = Save_to_db(dbname='reddit', user_name='reddit', host='localhost',
    #               password='reddit')
    connection = psycopg2.connect(dbname='reddit', user='reddit', host='localhost',
                   password='reddit', sslmode='require')
    cursor = connection.cursor()
    columns = ', '.join(['"'+key+'"' for key in insert_dict.keys()])
    values = ', '.join([str(value) for value in insert_dict.values()])
    insert_query = 'insert into {} ({}) values ({});'.format('redditors2', columns, values)

    update_values = ', '.join(['"' + key + '" = '+ str(1) for key in insert_dict])
    update_query = '''UPDATE redditors2 SET {} WHERE redditor = '{}';'''.format(update_values, redditor)

    query = '''
    DO
    $BODY$
    BEGIN
    IF EXISTS (SELECT redditor from redditors2 where redditor = '{}' ) THEN
        {}
    ELSE
        {}
    END IF;
    END;
    $BODY$
    '''.format(redditor, update_query, insert_query)
    cursor.execute(query)
    connection.commit()
    print redditor, len(subs)

def worker_get_redditor_subs(job_id, redditor_list, lock, handler):
    reddit_session = praw.Reddit(user_agent='Manugarri Recommendation Engine',
            handler = handler)
    for redditor in redditor_list:
        subs = get_redditor_subs_praw(redditor, reddit_session)
        if len(subs) > 0:
            update_redditor(redditor, subs, lock)

def dispatch_jobs(data, job_number, function):

    def chunks(l, n):
        return [l[i:i+n] for i in range(0, len(l), n)]

    total = len(data)
    chunk_size = max(1, total / job_number)
    slices = chunks(data, chunk_size)
    jobs = []
    lock = multiprocessing.Lock()
    handler = MultiprocessHandler()
    for i, s in enumerate(slices):
        j = multiprocessing.Process(target=function, args=(i, s, lock, handler))
        jobs.append(j)
    for j in jobs:
        j.start()

if __name__ == '__main__':
    with open('subs_db.json') as file:
        SUBS_DB = json.load(file)
    saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
    password=DB_PWD)
    
    # we use 2 different tables so we can resume if the process is interrupted
    redditors = set(row[0] for row in saver.retrieve_sql(
        'SELECT DISTINCT redditor from redditors;'))
    existing_redditors = set(row[0] for row in saver.retrieve_sql(
        'SELECT DISTINCT redditor from redditors2;'))
    redditors = sorted(redditors - existing_redditors)

    n_workers = 20
    dispatch_jobs(redditors, n_workers, worker_get_redditor_subs)
```

By now we have our set of subreddits and, for each redditor, the list of subreddits they have commented on.
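As a quick sanity check (not part of the original scripts), we can count how many redditors have been collected so far:

```prettyprint lang-python
# Sketch: count the redditors gathered in the (wide) redditors2 table
saver = Save_to_db(dbname=DB_NAME, user_name=DB_USER, host=DB_HOST,
                   password=DB_PWD)
n_redditors = saver.retrieve_sql('SELECT COUNT(*) FROM redditors2;')[0][0]
print '{} redditors collected'.format(n_redditors)
```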

The next step is to compute the similarity between subreddits. I will explain that in part 3.


Side note.

At this point in the project I realized that it would be much easier if, instead of keeping the redditors table in the following wide format (a column per subreddit):

```
 redditor  | sub1 | sub2 | ...
 redditor1 |  0   |  0   | ...    --> 1 if the redditor has commented on the sub
 redditor2 |  1   |  0   | ...    --> 0 if the redditor has not
 ...
```

we gave it a long structure (that is, fewer columns and more rows):

```
 redditor  | sub
 redditor1 | sub1
 redditor1 | sub234
 redditor2 | sub456
 ...
```

So I used the following script to export the table as a CSV, change the format, and save it to another CSV, which I then imported into the database again:

```prettyprint lang-python
# Convert the wide redditors table (exported as csv) into long format
redditors = open('redditors_wide.csv')
redditors_long = open('redditors_long.csv', 'w')

# The first line of the wide csv holds the column names (redditor, sub1, sub2, ...)
columns = redditors.readline()
columns = columns.replace('\n', '')
columns = columns.split(',')

redditors_long.write('redditor,sub\n')
for line in redditors:
    line = line.replace('\n', '')
    line = line.split(',')
    redditor = line[0]
    # columns marked with a 1 are the subs this redditor has commented on
    indices = [i for i, j in enumerate(line) if j == '1']
    subs = [columns[i] for i in indices]
    for sub in subs:
        redditors_long.write('{},{}\n'.format(redditor, sub))
redditors.close()
redditors_long.close()
```
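To import the long CSV back into Postgres, one option is the COPY command via psycopg2. This is just a sketch; the target table name (redditors_long) and its column types are my assumptions:

```prettyprint lang-python
# Sketch: load redditors_long.csv into a new long-format table
import psycopg2

connection = psycopg2.connect(dbname=DB_NAME, user=DB_USER, host=DB_HOST,
                              password=DB_PWD)
cursor = connection.cursor()
cursor.execute('CREATE TABLE IF NOT EXISTS redditors_long '
               '(redditor varchar(20), sub varchar(50));')
with open('redditors_long.csv') as f:
    cursor.copy_expert('COPY redditors_long FROM STDIN WITH CSV HEADER', f)
connection.commit()
connection.close()
```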
