{
"responses": {
"Hi": ["Hi", "Hello"],
"Hello": ["Hello", "Sup"],
"how are you doing": ["pretty good, thank you"],
"what day is it": ["it is monday"]
}
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /Users/lobster/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n",
"[nltk_data] Downloading package wordnet to /Users/lobster/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to\n",
"[nltk_data] /Users/lobster/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import nltk\n",
"import os\n",
"import os.path\n",
"import json\n",
"import string\n",
"import random\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from nltk.corpus import stopwords\n",
"\n",
"nltk.download('punkt') # first-time use only\n",
"nltk.download('wordnet') # first-time use only\n",
"nltk.download('stopwords') # first-time use only"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"RESPONSE_TRESHOLD = 0.6"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"# it seems that nltk stopwords list is too big, lets use our own\n",
"class OurStopwords:\n",
" def words(self, _):\n",
" return ['a', 'the', 'is', 'are', 'will', 'was', 'were']\n",
" \n",
"stopwords = OurStopwords()"
]
},
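{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the custom stopword list, sketched for illustration: only the handful of function words in `OurStopwords` are dropped, and the `words()` argument is ignored, mirroring the class above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: filter a tokenized phrase with the custom list.\n",
"tokens = ['what', 'day', 'is', 'it']\n",
"print([t for t in tokens if t not in stopwords.words('english')])  # ['what', 'day', 'it']"
]
},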
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load database"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"def LemTokens(tokens):\n",
" return [lemmer.lemmatize(token) for token in tokens]\n",
"\n",
"remove_punct_dict = dict((ord(punct), None) for punct in string.punctuation)\n",
"\n",
"def LemNormalize(text):\n",
" return LemTokens(nltk.word_tokenize(text.lower().translate(remove_punct_dict)))"
]
},
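{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustrative check of `LemNormalize`: the text is lowercased, punctuation is stripped via `remove_punct_dict`, and each token is lemmatized, so plural nouns collapse to their singular form."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: 'days' should lemmatize to 'day'.\n",
"print(LemNormalize(\"What day is it? Good days!\"))"
]
},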
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DB created\n"
]
}
],
"source": [
"class ResponseDB:\n",
" def __init__(self):\n",
" print(\"DB created\")\n",
" self.data = {}\n",
" if os.path.isfile(\"responses.db\"):\n",
" with open('responses.db') as f:\n",
" data = json.load(f)\n",
" self.data = data['responses']\n",
" \n",
" self.recomputeTfIdf()\n",
" \n",
" \n",
" def recomputeTfIdf(self):\n",
" self.tfidf = TfidfVectorizer()\n",
" self.tfidf_matrix = self.tfidf.fit_transform(list(self.data.keys()))\n",
" \n",
" \n",
" def getAnswersProb(self, raw):\n",
" phrase = LemNormalize(raw)\n",
" phrase = [e for e in phrase if e not in stopwords.words('english')]\n",
" print(phrase)\n",
" \n",
" tfidf = self.tfidf.transform( [' '.join(phrase)] )\n",
" print(\"tfidf\", tfidf)\n",
" print(\"tfidf_matrix\", self.tfidf_matrix)\n",
" \n",
" similarity = cosine_similarity(tfidf, self.tfidf_matrix)[0]\n",
" similarity = [(i, sim) for i, sim in enumerate(similarity)]\n",
" similarity = sorted(similarity, key=lambda x: x[1], reverse=True)\n",
" print(\"similarity\", similarity)\n",
" print(\"Top 3 phrases:\")\n",
" \n",
" result = []\n",
" for el in similarity:\n",
" result.append((el[1], self.data[list(self.data.keys())[el[0]]] ))\n",
" \n",
" \n",
" for i in range(max(len(similarity), 3)):\n",
" print(\" \", similarity[i], self.data[list(self.data.keys())[similarity[i][0]]])\n",
" \n",
" return result\n",
" \n",
" def getAnswers(self, raw):\n",
" ans = self.getAnswersProb(raw)\n",
" \n",
" result = [ e for e in ans if e[0] > RESPONSE_TRESHOLD ]\n",
" if len(result) == 0:\n",
" return None\n",
" \n",
" return random.choice(random.choice(result)[1])\n",
"\n",
" \n",
"\n",
"db = ResponseDB()"
]
},
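{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of how new phrase/response pairs could be registered at runtime: extend `db.data`, refit TF-IDF, and write back the same JSON layout the constructor reads. `add_response` is a hypothetical helper, not an existing API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: register a new prompt and persist the database.\n",
"def add_response(db, phrase, answers):\n",
"    db.data[phrase] = answers  # new prompt -> list of canned answers\n",
"    db.recomputeTfIdf()        # refit over all known phrases\n",
"    with open('responses.db', 'w') as f:\n",
"        json.dump({'responses': db.data}, f, indent=4)\n",
"\n",
"# add_response(db, \"good morning\", [\"morning!\", \"hi there\"])"
]
},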
{
"cell_type": "code",
"execution_count": 157,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['hi']\n",
"tfidf (0, 4)\t1.0\n",
"tfidf_matrix (0, 4)\t1.0\n",
" (1, 3)\t1.0\n",
" (2, 5)\t0.5\n",
" (2, 0)\t0.5\n",
" (2, 9)\t0.5\n",
" (2, 2)\t0.5\n",
" (3, 8)\t0.5\n",
" (3, 1)\t0.5\n",
" (3, 6)\t0.5\n",
" (3, 7)\t0.5\n",
"similarity [(0, 1.0), (1, 0.0), (2, 0.0), (3, 0.0)]\n",
"Top 3 phrases:\n",
" (0, 1.0) ['Hi', 'Hello']\n",
" (1, 0.0) ['Hello', 'Sup']\n",
" (2, 0.0) ['pretty good, thank you']\n",
" (3, 0.0) ['it is monday']\n"
]
},
{
"data": {
"text/plain": [
"[(1.0, ['Hi', 'Hello']),\n",
" (0.0, ['Hello', 'Sup']),\n",
" (0.0, ['pretty good, thank you']),\n",
" (0.0, ['it is monday'])]"
]
},
"execution_count": 157,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.getAnswersProb(\"Hi\")"
]
},
{
"cell_type": "code",
"execution_count": 162,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['hello']\n",
"tfidf (0, 3)\t1.0\n",
"tfidf_matrix (0, 4)\t1.0\n",
" (1, 3)\t1.0\n",
" (2, 5)\t0.5\n",
" (2, 0)\t0.5\n",
" (2, 9)\t0.5\n",
" (2, 2)\t0.5\n",
" (3, 8)\t0.5\n",
" (3, 1)\t0.5\n",
" (3, 6)\t0.5\n",
" (3, 7)\t0.5\n",
"similarity [(1, 1.0), (0, 0.0), (2, 0.0), (3, 0.0)]\n",
"Top 3 phrases:\n",
" (1, 1.0) ['Hello', 'Sup']\n",
" (0, 0.0) ['Hi', 'Hello']\n",
" (2, 0.0) ['pretty good, thank you']\n",
" (3, 0.0) ['it is monday']\n"
]
},
{
"data": {
"text/plain": [
"'Hello'"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.getAnswers(\"Hello\")"
]
},
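{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal chat-loop sketch around `db.getAnswers`, assuming the cells above have run: read a line, reply when a match clears `RESPONSE_THRESHOLD`, and fall back to a fixed reply otherwise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of an interactive loop; type 'quit' or 'exit' to stop.\n",
"while True:\n",
"    line = input(\"> \")\n",
"    if line.lower() in ('quit', 'exit'):\n",
"        break\n",
"    answer = db.getAnswers(line)\n",
"    print(answer if answer is not None else \"Sorry, I don't understand.\")"
]
}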
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}