# Source code for ucas_dm.prediction_algorithms.surprise_base_algo
from surprise import Dataset, Reader, dump
from .base_algo import BaseAlgo
import pandas as pd
class SurpriseBaseAlgo(BaseAlgo):
    """
    Base class for recommenders backed by an algorithm from the 'Surprise' package.

    Do not use this class directly. Sub-classes supply the concrete Surprise
    algorithm through :meth:`_init_surprise_model`; in return they inherit
    training on (user, item) logs, rate prediction, top-k recommendation and
    persistence of the Surprise model next to the object itself.
    """

    def __init__(self):
        super().__init__()
        # DataFrame with columns ['user_id', 'item_id']; filled in by train().
        self._user_log = None
        # Surprise algorithm instance; created lazily on the first train() call.
        self._surprise_model = None

    def train(self, train_set):
        """
        Fit the underlying Surprise model on a log of user/item interactions.

        :param train_set: Two-column data accepted by ``pandas.DataFrame``; the
            columns are relabelled as ``user_id`` and ``item_id`` in that order.
        :return: self
        """
        if self._surprise_model is None:
            # Initialize the prediction model provided by the sub-class.
            self._surprise_model = self._init_surprise_model()

        self._user_log = pd.DataFrame(train_set)
        self._user_log.columns = ['user_id', 'item_id']

        # There is no explicit rate in this situation, so every distinct
        # (user, item) pair is simply given the rate 1.
        rate_log = self._user_log.drop_duplicates().assign(rate=1)
        reader = Reader(rating_scale=(0, 1))
        train_s = Dataset.load_from_df(rate_log, reader)

        # Train the Surprise-based model on the complete data set.
        self._surprise_model.fit(train_s.build_full_trainset())
        return self

    def _init_surprise_model(self):
        """
        Sub-class should implement this method which return a prediction algorithm from package 'Surprise'.

        :return: A surprise-based recommend model
        """
        raise NotImplementedError()

    def top_k_recommend(self, u_id, k):
        """
        Recommend the k items with the highest predicted rate that the user has not seen yet.

        :param u_id: user id; must appear in the training log
        :param k: maximum number of items to return
        :return: ``(top_k_rate, top_k_item)`` -- two parallel lists ordered by
            descending predicted rate
        :raises AssertionError: if ``u_id`` has no record in the training log
        """
        seen_items = self._user_log.loc[self._user_log['user_id'] == u_id, 'item_id']
        assert (seen_items.shape[0] != 0), "User id doesn't exist"

        # Candidates are all known items minus those the user already interacted with.
        candidates = self._user_log[['item_id']].drop_duplicates()
        # Explicit copy so that adding the 'prate' column never writes into a
        # view of the training log.
        candidates = candidates[~candidates['item_id'].isin(seen_items)].copy()
        candidates['prate'] = [self.predict(u_id, i_id) for i_id in candidates['item_id']]

        top_k = candidates.sort_values(by=['prate'], ascending=False).iloc[:k]
        top_k_rate = top_k['prate'].values.tolist()
        top_k_item = top_k['item_id'].values.tolist()
        return top_k_rate, top_k_item

    def predict(self, u_id, i_id):
        """
        Predict the rate of user 'u_id' give to the item 'i_id'

        :param u_id: user id
        :param i_id: item id
        :return: rate value
        """
        # Surprise returns a Prediction tuple (uid, iid, r_ui, est, details);
        # only the estimated rate is of interest here.
        _, _, _, est, _ = self._surprise_model.predict(u_id, i_id)
        return est

    def to_dict(self):
        """
        Sub-classes must override this method; the base class only raises.

        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    @classmethod
    def load(cls, fname):
        """
        Restore an instance written by :meth:`save`.

        The object itself is restored by the parent class from ``fname``; the
        Surprise model is read from the companion file ``fname + '.surprise'``.

        :param fname: file name that was passed to :meth:`save`
        :return: the restored instance with its Surprise model attached
        :raises AssertionError: if the restored object has no ``_surprise_model`` attribute
        """
        res = super(SurpriseBaseAlgo, cls).load(fname)
        assert (hasattr(res, '_surprise_model')), 'Not a standard SurpriseBaseAlgo class.'
        # surprise.dump.load returns a (predictions, algo) tuple; storing the
        # whole tuple would break every later predict() call.
        _, surprise_model = dump.load(fname + '.surprise')
        res._surprise_model = surprise_model
        return res

    def save(self, fname, *args):
        """
        Persist this instance and its Surprise model.

        The Surprise model is dumped separately to ``fname + '.surprise'`` and
        its attribute name is added to the ignore list handed to the parent's
        ``save``.

        :param fname: target file name
        :param args: optionally, as first positional argument, an iterable of
            attribute names forwarded to the parent's ``save`` as ignore list
            (presumably attributes the parent should skip -- confirm against
            BaseAlgo). The caller's iterable is not modified.
        """
        # Build a fresh list: list.append returns None, so its result must not
        # be used as the ignore list, and the caller's list stays untouched.
        ignore = list(args[0]) if args else []
        if '_surprise_model' not in ignore:
            ignore.append('_surprise_model')
        dump.dump(fname + '.surprise', algo=self._surprise_model)
        super().save(fname, ignore)