Skip to content

Commit 6b75c71

Browse files
committed
naive bayes hyperparams
1 parent 4a2414c commit 6b75c71

File tree

4 files changed

+310
-22
lines changed

4 files changed

+310
-22
lines changed

app.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@
1818
## Imports
1919
##########################################################################
2020

21+
import json
22+
2123
from flask import Flask
2224
from flask import render_template, jsonify, request
2325

2426
from numpy import asarray
2527
from functools import partial
2628

2729
from sklearn.svm import SVC
28-
from sklearn.naive_bayes import MultinomialNB
2930
from sklearn.preprocessing import MinMaxScaler
3031
from sklearn.linear_model import LogisticRegression
3132
from sklearn.metrics import precision_recall_fscore_support as prfs
33+
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
3234
from sklearn.datasets import make_blobs, make_circles, make_moons, make_classification
3335

3436

@@ -65,12 +67,16 @@ def generate():
6567
# TODO: test content type and send 400 if not JSON
6668
data = request.get_json()
6769
generator = {
70+
'…': None, None: None, "": None,
6871
'binary': make_binary,
6972
'multiclass': make_multiclass,
7073
'blobs': make_blobs,
7174
'circles': make_circles,
7275
'moons': make_moons,
73-
}[data.get("generator", "binary")]
76+
}[data.get("generator", None)]
77+
78+
if generator is None:
79+
return "invalid generate request: unspecified data generator", 400
7480

7581
X, y = generator()
7682
X = MinMaxScaler().fit_transform(X)
@@ -89,14 +95,24 @@ def fit():
8995
params = data.get("model", {})
9096
dataset = data.get("dataset", [])
9197
model = {
92-
'bayes': MultinomialNB(),
98+
'gaussiannb': GaussianNB(),
99+
'multinomialnb': MultinomialNB(),
100+
'bernoullinb': BernoulliNB(),
101+
'complementnb': ComplementNB(),
93102
'svm': SVC(),
94103
'logit': LogisticRegression(),
95104
}.get(params.pop("model", None), None)
96105

97106
# Validate the request is correct and sane
98107
if model is None or len(dataset) == 0:
99-
return "invalid fit request", 400
108+
return "invalid fit request: please specify model and data", 400
109+
110+
try:
111+
params = {
112+
key: json.loads(val) for key, val in params.items()
113+
}
114+
except json.decoder.JSONDecodeError:
115+
return "invalid fit request: cannot parse json hyperparameters", 400
100116

101117
# Set the hyperparameters on the model
102118
model.set_params(**params)

hyparams.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# hparams
3+
# Prints out hyperparameters and their defaults for various models.
4+
#
5+
# Author: Benjamin Bengfort <benjamin@bengfort.com>
6+
# Created: Wed Nov 27 10:45:19 2019 -0500
7+
#
8+
# Copyright (C) 2019 Georgetown Data Analytics (CCPE)
9+
# For license information, see LICENSE.txt
10+
#
11+
# ID: hparams.py [] benjamin@bengfort.com $
12+
13+
"""
14+
Prints out hyperparameters and their defaults for various models.
15+
"""
16+
17+
##########################################################################
18+
## Imports
19+
##########################################################################
20+
21+
import pprint
22+
import argparse
23+
24+
from sklearn.svm import SVC
25+
from sklearn.linear_model import LogisticRegression
26+
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
27+
28+
29+
MODELS = {
30+
"svm": SVC,
31+
"logistic": LogisticRegression,
32+
"gaussiannb": GaussianNB,
33+
"multinomialnb": MultinomialNB,
34+
"bernoullinb": BernoulliNB,
35+
"complementnb": ComplementNB,
36+
}
37+
38+
39+
##########################################################################
40+
## Main Method
41+
##########################################################################
42+
43+
def main(args):
44+
for model in args.model:
45+
params = MODELS[model]().get_params()
46+
pprint.pprint(params)
47+
print("\n")
48+
49+
50+
if __name__ == "__main__":
51+
parser = argparse.ArgumentParser(
52+
description="prints out hyperparameters and their defaults"
53+
)
54+
55+
parser.add_argument(
56+
"model", choices=MODELS.keys(), nargs="+",
57+
help="the models for whom to print out the params and defaults"
58+
)
59+
60+
args = parser.parse_args()
61+
main(args)

static/js/dataspace.js

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,23 @@ const margin = {top: 10, right: 10, bottom: 10, left: 10};
44
var app = null;
55
var currentClass = 0;
66

7+
// Displays a danger alert message in the top of the screen.
8+
function alertMessage(message) {
9+
var alert = $('<div class="alert alert-danger alert-dismissible fade show mb-0 mt-0" role="alert">');
10+
var btnClose = $('<button type="button" class="close" data-dismiss="alert" aria-label="Close">');
11+
var x = $.parseHTML('<span aria-hidden="true">&times;</span>');
12+
13+
alert.text(message);
14+
btnClose.append(x);
15+
alert.append(btnClose);
16+
17+
$("#alerts").append(alert);
18+
19+
setTimeout(function() {
20+
alert.alert('close');
21+
alert.alert('dispose');
22+
}, 2000);
23+
}
724

825
class Dataspace {
926
constructor(selector) {
@@ -65,11 +82,15 @@ class Dataspace {
6582
}).then(json => {
6683
this.dataset = json;
6784
this.draw();
85+
}).catch(error => {
86+
console.log(error);
87+
alertMessage("Server could not generate dataset!");
6888
});
6989
}
7090

7191
// Fit the model specified in the data fields to the data in the plot
7292
fit(model) {
93+
$("#metrics").removeClass("visible").addClass("invisible");
7394
if (this.dataset.length == 0) {
7495
console.log("cannot fit model to no data!");
7596
return
@@ -89,6 +110,8 @@ class Dataspace {
89110
}).then(json => {
90111
$("#f1score").text(json.metrics.f1);
91112
$("#metrics").removeClass("invisible").addClass("visible");
113+
}).catch(error => {
114+
alertMessage("Could not fit model, check JSON hyperparams and try again!");
92115
});
93116
}
94117

@@ -137,17 +160,24 @@ $(document).ready(function() {
137160
return false;
138161
})
139162

140-
// Change the model hyperparameter tabs on select
163+
// Change the model hyperparameter tabs on select and set the current model family.
141164
$("select#modelSelect").change(function(e) {
142165
e.preventDefault();
166+
// Deactivate current model control form tab
143167
$('#modelTabs [class*="active"]').removeClass("show active");
144168

169+
// Activate the selected model control form tab
145170
var model = $(e.target).val();
146171
$("#"+model).addClass("show active");
172+
173+
// Ensure that the info button points to the currect model
174+
$("#infoBtn").attr("data-target", "#" + model + "InfoModal");
175+
147176
return false;
148177
})
149178

150-
// Display the model when the fit button is clicked
179+
// POST the active model control form when the fit button is clicked then render
180+
// the model contours and score (along with any other model-visualizations).
151181
$("button#fitBtn").click(function(e) {
152182
e.preventDefault();
153183

@@ -157,8 +187,53 @@ $(document).ready(function() {
157187
return obj;
158188
}, {});
159189

190+
// Add unchecked checkboxes and change checkboxes to true/false
191+
form.find('input[type="checkbox"]').each(function() {
192+
var cb = $(this);
193+
if (!cb.prop("disabled")) {
194+
data[cb.attr("name")] = cb.prop("checked").toString();
195+
}
196+
});
197+
198+
console.log(data);
160199
app.fit(data);
161200
return false;
162201
});
163202

203+
// Enable the correct hyperparameters based on the selected naive bayes model
204+
$('#bayes input[name="model"]').change(function(e) {
205+
206+
// Disable all of the form controls except for the radios
207+
$('#bayes input[type="text"').prop("disabled", true);
208+
$('#bayes input[type="checkbox"').prop("disabled", true);
209+
210+
// Enable based on the model ID
211+
switch($(this).val()) {
212+
case "gaussiannb":
213+
$('#bayes input[name="priors"]').prop("disabled", false);
214+
$('#bayes input[name="var_smoothing"]').prop("disabled", false);
215+
break
216+
case "multinomialnb":
217+
$('#bayes input[name="alpha"]').prop("disabled", false);
218+
$('#bayes input[name="class_prior"]').prop("disabled", false);
219+
$('#bayes input[name="fit_prior"]').prop("disabled", false);
220+
break;
221+
case "bernoullinb":
222+
$('#bayes input[name="alpha"]').prop("disabled", false);
223+
$('#bayes input[name="binarize"]').prop("disabled", false);
224+
$('#bayes input[name="class_prior"]').prop("disabled", false);
225+
$('#bayes input[name="fit_prior"]').prop("disabled", false);
226+
break;
227+
case "complementnb":
228+
$('#bayes input[name="alpha"]').prop("disabled", false);
229+
$('#bayes input[name="class_prior"]').prop("disabled", false);
230+
$('#bayes input[name="fit_prior"]').prop("disabled", false);
231+
$('#bayes input[name="norm"]').prop("disabled", false);
232+
break;
233+
default:
234+
console.log("unknown bayesian model selected, cannot enable form!");
235+
}
236+
237+
});
238+
164239
});

0 commit comments

Comments
 (0)