Skip to content

Commit 0aa8a73

Browse files
authored
Fix null convergence (#7)
* Add treatment for null convergence due precision limits. * add OS update to build. * update pip version on build.
1 parent 09215c7 commit 0aa8a73

File tree

9 files changed

+95
-76
lines changed

9 files changed

+95
-76
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
flake8:
44
pip install -U flake8
5-
flake8
5+
flake8 pyClickModels
66

77
isort:
88
pip install -U isort
@@ -30,4 +30,4 @@ publish:
3030
sh ./scripts/build_wheels.sh
3131
#twine upload --repository testpypi dist/*
3232
twine upload dist/*
33-
rm -fr build dist .egg *.egg-info
33+
#rm -fr build dist .egg *.egg-info

pyClickModels/DBN.pyx

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
# cython: linetrace=True
22

3-
import os
4-
from glob import glob
53
import gzip
4+
import os
65
import time
6+
from glob import glob
7+
78
import ujson
8-
from libcpp.vector cimport vector
9-
from libcpp.unordered_map cimport unordered_map
10-
from libcpp.string cimport string
11-
from libc.stdlib cimport rand, RAND_MAX, srand
12-
from libc.time cimport time as ctime
9+
1310
from cython.operator cimport dereference, postincrement
14-
from pyClickModels.jsonc cimport(json_object, json_tokener_parse,
15-
json_object_object_get_ex, json_object_get_string,
16-
lh_table, lh_entry, json_object_array_length,
17-
json_object_array_get_idx, json_object_get_int,
18-
json_object_put)
11+
from libc.stdlib cimport RAND_MAX, rand, srand
12+
from libc.time cimport time as ctime
13+
from libcpp.string cimport string
14+
from libcpp.unordered_map cimport unordered_map
15+
from libcpp.vector cimport vector
1916

17+
from pyClickModels.jsonc cimport (json_object, json_object_array_get_idx,
18+
json_object_array_length,
19+
json_object_get_int, json_object_get_string,
20+
json_object_object_get_ex, json_object_put,
21+
json_tokener_parse, lh_entry, lh_table)
2022

2123
# Start by setting the seed for the random values required for initalizing the DBN
2224
# parameters.
@@ -125,12 +127,14 @@ cdef class Factor:
125127
result *= (1 - self.gamma) * (1 - self.cr)
126128
else:
127129
result *= self.gamma * (1 - self.cr)
130+
# Compute P(C_{>r},P_{>r} | E_{r+1})
128131
if not z:
129132
if self.last_r >= self.r + 1:
130133
return 0
131134
else:
132135
if self.r < self.cp_vector_given_e[0].size():
133136
result *= self.cp_vector_given_e[0][self.r]
137+
# P(E_r=x | C<r, P<r)
134138
result *= (self.e_r_vector_given_CP[0][self.r] if x else
135139
1 - self.e_r_vector_given_CP[0][self.r])
136140
return result
@@ -414,6 +418,7 @@ cdef class DBNModel():
414418
# Probability of clicks at positions greater than the last document in results
415419
# page is zero.
416420
X_r_vector[total_docs] = 0
421+
gamma = self.get_param(b'gamma')
417422

418423
for r in range(total_docs - 1, -1, -1):
419424
json_object_object_get_ex(
@@ -423,7 +428,6 @@ cdef class DBNModel():
423428
)
424429
doc = json_object_get_string(tmp)
425430
alpha = self.get_param(b'alpha', query, &doc)
426-
gamma = self.get_param(b'gamma')
427431

428432
X_r_1 = X_r_vector[r + 1]
429433
X_r = alpha[0] + (1 - alpha[0]) * gamma[0] * X_r_1
@@ -438,6 +442,10 @@ cdef class DBNModel():
438442
439443
Mathematically: P(E_r = 1 | C_{<r}, P_{<r})
440444
445+
This is discussed in equation (24) in the blog post:
446+
447+
https://towardsdatascience.com/how-to-extract-relevance-from-clickstream-data-2a870df219fb
448+
441449
Args
442450
----
443451
clickstream: *json_object
@@ -465,8 +473,10 @@ cdef class DBNModel():
465473
# position r + 1 will be required later so add +1 in computation
466474
vector[float] e_r_vector_given_CP = vector[float](total_docs + 1 - idx, 0.0)
467475

468-
# First document has 100% of being Examined regardless of clicks or purchases.
476+
# First document has 100% chance of being Examined regardless of clicks or
477+
# purchases.
469478
e_r_vector_given_CP[0] = 1
479+
gamma = self.get_param(b'gamma')
470480

471481
for r in range(idx, total_docs):
472482
json_object_object_get_ex(
@@ -492,7 +502,6 @@ cdef class DBNModel():
492502

493503
alpha = self.get_param(b'alpha', query, &doc)
494504
sigma = self.get_param(b'sigma', query, &doc)
495-
gamma = self.get_param(b'gamma')
496505

497506
if purchase:
498507
return e_r_vector_given_CP
@@ -602,6 +611,10 @@ cdef class DBNModel():
602611
603612
P(C_{>r}, P_{>r} | E_{r+1})
604613
614+
This is equation (25) from blog post:
615+
616+
https://towardsdatascience.com/how-to-extract-relevance-from-clickstream-data-2a870df219fb
617+
605618
Args
606619
----
607620
clickstream: *json_object
@@ -624,13 +637,10 @@ cdef class DBNModel():
624637

625638
# Subtract 1 as E_{r+1} is defined up to r - 1 documents
626639
for r in range(total_docs - 1):
627-
628640
e_r_vector_given_CP = self.build_e_r_vector_given_CP(clickstream, r + 1,
629641
query)
630-
631642
cp_vector_given_e[r] = self.compute_cp_p(clickstream, r + 1, query,
632643
&e_r_vector_given_CP, cr_dict)
633-
634644
return cp_vector_given_e
635645

636646
cdef int get_last_r(self, json_object *clickstream, const char *event=b'click'):
@@ -868,7 +878,7 @@ cdef class DBNModel():
868878
json_object_object_get_ex(doc_data, b'purchase', &tmp)
869879
purchase = json_object_get_int(tmp)
870880

871-
alpha = self.get_param(b'gamma', query, &doc)[0]
881+
alpha = self.get_param(b'alpha', query, &doc)[0]
872882
sigma = self.get_param(b'sigma', query, &doc)[0]
873883
gamma = self.get_param(b'gamma')[0]
874884

@@ -887,15 +897,19 @@ cdef class DBNModel():
887897
e_r_vector_given_CP,
888898
cp_vector_given_e
889899
)
900+
890901
# Loop through all possible values of x, y and z, where each is an integer
891902
# boolean.
892903
for i in range(2):
893904
for j in range(2):
894905
for k in range(2):
895906
ESS_denominator += factor.compute_factor(i, j, k)
896907

897-
ESS_0 = factor.compute_factor(1, 0, 0) / ESS_denominator
898-
ESS_1 = factor.compute_factor(1, 0, 1) / ESS_denominator
908+
if not ESS_denominator:
909+
ESS_0, ESS_1 = 0, 0
910+
else:
911+
ESS_0 = factor.compute_factor(1, 0, 0) / ESS_denominator
912+
ESS_1 = factor.compute_factor(1, 0, 1) / ESS_denominator
899913

900914
tmp_gamma_param[0][0] += ESS_1
901915
tmp_gamma_param[0][1] += ESS_0 + ESS_1

pyClickModels/__version__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
VERSION = (0, 0, 1)
2-
3-
__version__ = '.'.join([str(e) for e in VERSION])
1+
__version__ = '0.0.2'

scripts/build_wheels.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
docker run -v $(pwd):/pyClickModels quay.io/pypa/manylinux1_x86_64 sh -c '''
2+
yum update
23
yum install -y json-c-devel
34
45
cd /pyClickModels
56
67
for PYVER in /opt/python/*/bin/; do
78
if [[ $PYVER != *"27"* ]]; then
9+
"${PYVER}/pip" install -U pip
810
"${PYVER}/pip" install -U setuptools
911
"${PYVER}/pip" install -r requirements.txt
1012
"${PYVER}/python" setup.py sdist bdist_wheel

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import os
44
import sys
5-
import Cython.Compiler.Options
65
from codecs import open
7-
from Cython.Distutils import build_ext
86
from setuptools import setup
9-
from Cython.Build import cythonize
10-
from distutils.extension import Extension
117
from setuptools.command.test import test as TestCommand
8+
from distutils.extension import Extension
9+
import Cython.Compiler.Options
10+
from Cython.Distutils import build_ext
11+
from Cython.Build import cythonize
1212

1313

1414
here = os.path.abspath(os.path.dirname(__file__))

tests/conftest.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,3 @@ def build_DBN_test_data(users=10, docs=10, queries=2):
8282
for row in final_result[half_results:]:
8383
f.write(json.dumps(row).encode() + '\n'.encode())
8484
return persistence, params, tmp_folder
85-
86-
87-
@pytest.fixture
88-
def sessions():
89-
sessions = [
90-
{
91-
'sessionID': [
92-
{"doc": "doc0", "click": 0, "purchase": 0},
93-
{"doc": "doc1", "click": 1, "purchase": 0},
94-
{"doc": "doc2", "click": 1, "purchase": 1}
95-
]
96-
},
97-
{
98-
'sessionID': [
99-
{"doc": "doc0", "click": 0, "purchase": 0},
100-
{"doc": "doc1", "click": 1, "purchase": 0}
101-
]
102-
},
103-
104-
]
105-
return sessions
328 Bytes
Binary file not shown.
197 Bytes
Binary file not shown.

tests/test_cy_DBN.pyx

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1+
import gzip
12
import tempfile
3+
24
import ujson
3-
import gzip
5+
6+
from cython.operator cimport dereference, postincrement
47
from libcpp.string cimport string
58
from libcpp.unordered_map cimport unordered_map
69
from libcpp.vector cimport vector
7-
from cython.operator cimport dereference, postincrement
10+
811
from pyClickModels.DBN cimport DBNModel, Factor
12+
913
from pyClickModels.DBN import DBN
10-
from pyClickModels.jsonc cimport(json_object, json_tokener_parse,
11-
json_object_get_object, lh_table, json_object_put)
14+
15+
from pyClickModels.jsonc cimport (json_object, json_object_get_object,
16+
json_object_put, json_tokener_parse,
17+
lh_table)
18+
1219
from conftest import build_DBN_test_data
13-
from numpy.testing import assert_almost_equal, assert_allclose
20+
from numpy.testing import assert_allclose, assert_almost_equal
1421

1522
ctypedef unordered_map[string, unordered_map[string, float]] dbn_param
1623

@@ -49,7 +56,6 @@ cdef bint test_fit():
4956

5057
# it = model.alpha_params.begin()
5158
while(it != model.alpha_params.end()):
52-
# prints keys
5359
# print(dereference(it).first)
5460
query = (dereference(it).first)
5561
dquery = extract_keys(query)
@@ -1886,6 +1892,22 @@ cdef bint test_export_judgments():
18861892
return True
18871893

18881894

1895+
cdef bint test_not_null_converence():
1896+
cdef:
1897+
DBNModel model = DBN()
1898+
1899+
model.fit('tests/fixtures/null_test', iters=1)
1900+
return True
1901+
1902+
1903+
cdef bint test_long_list_null_converence():
1904+
cdef:
1905+
DBNModel model = DBN()
1906+
1907+
model.fit('tests/fixtures/eighty_skus', iters=2)
1908+
return True
1909+
1910+
18891911
cpdef run_tests():
18901912
assert test_get_search_context_string()
18911913
assert test_compute_cr()
@@ -1906,25 +1928,29 @@ cpdef run_tests():
19061928
assert test_update_gamma_param()
19071929
assert test_fit()
19081930
assert test_export_judgments()
1931+
assert test_not_null_converence()
1932+
assert test_long_list_null_converence()
19091933

19101934

19111935
if __name__ == '__main__':
1912-
assert test_get_search_context_string()
1913-
assert test_compute_cr()
1914-
assert test_get_param()
1915-
assert test_build_e_r_vector(&alpha_params, &sigma_params, &gamma_param)
1916-
assert test_build_X_r_vector(&alpha_params, &sigma_params, &gamma_param)
1917-
assert test_build_e_r_vector_given_CP(&alpha_params, &sigma_params, &gamma_param)
1918-
assert test_build_cp_p(&alpha_params)
1919-
assert test_build_CP_vector_given_e(&alpha_params, &sigma_params, &gamma_param)
1920-
assert test_get_last_r()
1921-
assert test_update_tmp_alpha(&alpha_params, &sigma_params, &gamma_param)
1922-
assert test_update_tmp_sigma(&alpha_params, &sigma_params, &gamma_param)
1923-
assert test_compute_factor_last_click_lower_than_r()
1924-
assert test_compute_factor_last_click_higher_than_r()
1925-
assert test_update_tmp_gamma()
1926-
assert test_update_alpha_params()
1927-
assert test_update_sigma_params()
1928-
assert test_update_gamma_param()
1929-
assert test_fit()
1930-
assert test_export_judgments()
1936+
#assert test_get_search_context_string()
1937+
#assert test_compute_cr()
1938+
#assert test_get_param()
1939+
#assert test_build_e_r_vector(&alpha_params, &sigma_params, &gamma_param)
1940+
#assert test_build_X_r_vector(&alpha_params, &sigma_params, &gamma_param)
1941+
#assert test_build_e_r_vector_given_CP(&alpha_params, &sigma_params, &gamma_param)
1942+
#assert test_build_cp_p(&alpha_params)
1943+
#assert test_build_CP_vector_given_e(&alpha_params, &sigma_params, &gamma_param)
1944+
#assert test_get_last_r()
1945+
#assert test_update_tmp_alpha(&alpha_params, &sigma_params, &gamma_param)
1946+
#assert test_update_tmp_sigma(&alpha_params, &sigma_params, &gamma_param)
1947+
#assert test_compute_factor_last_click_lower_than_r()
1948+
#assert test_compute_factor_last_click_higher_than_r()
1949+
#assert test_update_tmp_gamma()
1950+
#assert test_update_alpha_params()
1951+
#assert test_update_sigma_params()
1952+
#assert test_update_gamma_param()
1953+
#assert test_fit()
1954+
#assert test_export_judgments()
1955+
#assert test_not_null_converence()
1956+
pass

0 commit comments

Comments
 (0)