Skip to content

Commit 22a5f90

Browse files
authored
Adds contours by fitting the a grid of values. (#12)
Fixes #1
1 parent 95a2671 commit 22a5f90

File tree

5 files changed

+156
-43
lines changed

5 files changed

+156
-43
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
In the tradition of Tkinter SVM GUI, the purpose of this app is to demonstrate how machine learning model forms are affected by the shape of the underlying dataset. By selecting a dataset or by creating one of your own, you can fit a model to the data and see how the model would make decisions based on the data it has been trained on. Although this is a toy example, hopefully it helps give you the intuition that the machine learning process is a model selection search for the best combination of features, algorithm, and hyperparameter that generalize well in a bounded feature space.
66

7+
![Screenshot](static/img/screenshot.png)
8+
79
## Getting Started
810

911
To run this app locally, first clone the repository and install the requirements:

app.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
##########################################################################
2020

2121
import json
22+
import numpy as np
2223

2324
from flask import Flask
2425
from flask import render_template, jsonify, request
@@ -94,6 +95,7 @@ def fit():
9495
data = request.get_json()
9596
params = data.get("model", {})
9697
dataset = data.get("dataset", [])
98+
grid = data.get("grid", [])
9799
model = {
98100
'gaussiannb': GaussianNB(),
99101
'multinomialnb': MultinomialNB(),
@@ -132,8 +134,22 @@ def fit():
132134
yhat = model.predict(X)
133135
metrics = prfs(y, yhat, average="macro")
134136

137+
# Make probability predictions on the grid to implement contours
138+
# The returned value is the class index + the probability
139+
# To get the selected class in JavaScript, use Math.floor(p)
140+
# Where p is the probability returned by the grid. Note that this
141+
# method guarantees that no P(c) == 1 to prevent class misidentification
142+
Xp = asarray([
143+
[point["x"], point["y"]] for point in grid
144+
])
145+
preds = []
146+
for proba in model.predict_proba(Xp):
147+
c = np.argmax(proba)
148+
preds.append(float(c+proba[c])-0.000001)
149+
135150
return jsonify({
136-
"metrics": dict(zip(["precision", "recall", "f1", "support"], metrics))
151+
"metrics": dict(zip(["precision", "recall", "f1", "support"], metrics)),
152+
"grid": preds,
137153
})
138154

139155
##########################################################################

static/img/screenshot.png

1.38 MB
Loading

static/js/dataspace.js

Lines changed: 135 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,48 +24,50 @@ function alertMessage(message) {
2424

2525
class Dataspace {
2626
constructor(selector) {
27-
this.svg = d3.select(selector);
28-
this.$svg = $(selector);
29-
this.dataset = [];
30-
31-
// drawing properties are hardcoded for now
32-
this.width = this.$svg.width();
33-
this.height = this.$svg.height();
34-
this.color = d3.scaleOrdinal(d3.schemeCategory10);
35-
36-
this.xScale = d3.scaleLinear()
37-
.domain([0, 1])
38-
.range([margin.left, this.width - margin.right]);
39-
40-
this.yScale = d3.scaleLinear()
41-
.domain([0, 1])
42-
.range([margin.top, this.height - margin.bottom])
27+
this.svg = d3.select(selector);
28+
this.$svg = $(selector);
29+
this.dataset = [];
30+
this.grid = null;
31+
32+
// drawing properties are hardcoded for now
33+
this.width = this.$svg.width();
34+
this.height = this.$svg.height();
35+
this.color = d3.scaleOrdinal(d3.schemeCategory10);
36+
37+
this.xScale = d3.scaleLinear()
38+
.domain([0, 1])
39+
.range([margin.left, this.width - margin.right]);
40+
41+
this.yScale = d3.scaleLinear()
42+
.domain([0, 1])
43+
.range([margin.top, this.height - margin.bottom])
4344
}
4445

4546
draw() {
4647
var self = this;
4748
self.svg.selectAll("circle")
48-
.data(self.dataset)
49-
.enter()
50-
.append("circle")
51-
.attr('cx', function (d) { return self.xScale(d.x); })
52-
.attr('cy', function (d) { return self.yScale(d.y); })
53-
.attr('fill', function (d) { return self.color(d.c); })
54-
.attr('r', radius);
49+
.data(self.dataset)
50+
.enter()
51+
.append("circle")
52+
.attr('cx', function (d) { return self.xScale(d.x); })
53+
.attr('cy', function (d) { return self.yScale(d.y); })
54+
.attr('fill', function (d) { return self.color(d.c); })
55+
.attr("stroke", "#FFFFFF")
56+
.attr('r', radius);
5557
}
5658

5759
// Add raw data point (e.g. where x and y are between 0 and 1)
5860
addPoint(point) {
59-
this.dataset.push(point);
60-
this.draw();
61+
this.dataset.push(point);
62+
this.draw();
6163
}
6264

6365
// Add coordinates data point (e.g. where x and y are in the svg)
6466
addCoords(coords) {
6567
var point = {
66-
x: this.xScale.invert(coords[0]),
67-
y: this.yScale.invert(coords[1]),
68-
c: currentClass
68+
x: this.xScale.invert(coords[0]),
69+
y: this.yScale.invert(coords[1]),
70+
c: currentClass
6971
};
7072
this.addPoint(point);
7173
}
@@ -74,14 +76,14 @@ class Dataspace {
7476
fetch(data) {
7577
this.reset();
7678
d3.json("/generate", {
77-
method: "POST",
78-
body: JSON.stringify(data),
79-
headers: {
80-
"Content-Type": "application/json; charset=UTF-8"
81-
}
79+
method: "POST",
80+
body: JSON.stringify(data),
81+
headers: {
82+
"Content-Type": "application/json; charset=UTF-8"
83+
}
8284
}).then(json => {
83-
this.dataset = json;
84-
this.draw();
85+
this.dataset = json;
86+
this.draw();
8587
}).catch(error => {
8688
console.log(error);
8789
alertMessage("Server could not generate dataset!");
@@ -92,13 +94,18 @@ class Dataspace {
9294
fit(model) {
9395
$("#metrics").removeClass("visible").addClass("invisible");
9496
if (this.dataset.length == 0) {
95-
console.log("cannot fit model to no data!");
96-
return
97+
console.log("cannot fit model to no data!");
98+
return
9799
}
98100

101+
// The contours grid determines what to make predictions on.
102+
// TODO: don't pass this to the server but allow the server to compute it.
103+
var self = this;
104+
self.grid = self.contoursGrid()
99105
var data = {
100-
model: model,
101-
dataset: this.dataset,
106+
model: model,
107+
dataset: self.dataset,
108+
grid: self.grid,
102109
}
103110

104111
d3.json("/fit", {
@@ -108,18 +115,64 @@ class Dataspace {
108115
"Content-Type": "application/json; charset=UTF-8"
109116
}
110117
}).then(json => {
118+
// Reset the old contours
119+
self.svg.selectAll("g").remove();
120+
121+
// Update the metrics
111122
$("#f1score").text(json.metrics.f1);
112123
$("#metrics").removeClass("invisible").addClass("visible");
124+
125+
// Update the grid with the predictions values.
126+
$.each(json.grid, function(i, val) {
127+
self.grid[i] = val;
128+
})
129+
130+
// Compute the thresholds from the classes, then compute the colors
131+
var thresholds = self.classes().map(i => d3.range(i, i + 1, 0.1)).flat().sort();
132+
var colorMap = {}
133+
$.each(self.classes(), c => {
134+
colorMap[c] = d3.scaleLinear().domain([c, c+1])
135+
.interpolate(d3.interpolateHcl)
136+
.range(["#FFFFFF", self.color(c)])
137+
});
138+
139+
var getColor = d => {
140+
console.log(d.value)
141+
return colorMap[Math.floor(d.value)](d.value)
142+
}
143+
144+
// Add the contours from the predictions for each class
145+
var contours = d3.contours()
146+
.size([self.grid.n, self.grid.m])
147+
.thresholds(thresholds)
148+
.smooth(true)
149+
(self.grid)
150+
.map(self.grid.transform)
151+
152+
// Draw the contours on the SVG
153+
self.svg.insert("g", ":first-child")
154+
.attr("fill", "none")
155+
.attr("stroke", "#FFFFFF")
156+
.attr("stroke-opacity", 0.65)
157+
.selectAll("path")
158+
.data(contours) // Here is where the contours gets added
159+
.join("path")
160+
.attr("fill", getColor) // Here is the color value!
161+
.style("opacity", 0.85)
162+
.attr("d", d3.geoPath());
163+
113164
}).catch(error => {
165+
console.log(error);
114166
alertMessage("Could not fit model, check JSON hyperparams and try again!");
115167
});
116168
}
117169

118170
// Reset the plotting area
119171
reset() {
120-
this.dataset = [];
121-
this.svg.selectAll("circle").remove();
122-
$("#metrics").removeClass("visible").addClass("invisible");
172+
this.dataset = [];
173+
this.svg.selectAll("circle").remove();
174+
this.svg.selectAll("g").remove();
175+
$("#metrics").removeClass("visible").addClass("invisible");
123176
}
124177

125178
// Count the number of classes in the dataset
@@ -132,6 +185,46 @@ class Dataspace {
132185
}, []);
133186
}
134187

188+
// Create the contours grid to pass to the predict function.
189+
contoursGrid() {
190+
var self = this;
191+
const q = 4;
192+
const x0 = -q / 2, x1 = this.width + margin.right + q;
193+
const y0 = -q / 2, y1 = this.height + q;
194+
const n = Math.ceil((x1-x0) / q);
195+
const m = Math.ceil((y1-y0) / q);
196+
const grid = new Array(n*m);
197+
grid.x = -q;
198+
grid.y = -q;
199+
grid.k = q;
200+
grid.n = n;
201+
grid.m = m;
202+
203+
// Converts from grid coordinates (indexes) to screen coordinates (pixels).
204+
grid.transform = ({ type, value, coordinates }) => {
205+
return {
206+
type, value, coordinates: coordinates.map(rings => {
207+
return rings.map(points => {
208+
return points.map(([x, y]) => ([
209+
grid.x + grid.k * x,
210+
grid.y + grid.k * y
211+
]));
212+
});
213+
})
214+
};
215+
}
216+
217+
// We just have to pass the x and y values to the server to predict them using the model, then the rest of the code is the sames?
218+
for (let j = 0; j < m; ++j) {
219+
for (let i = 0; i < n; ++i) {
220+
var obj = { x: this.xScale.invert(i * q + x0), y: this.yScale.invert(j * q + y0) };
221+
grid[j * grid.n + i] = obj;
222+
}
223+
}
224+
225+
return grid;
226+
}
227+
135228
}
136229

137230
$(document).ready(function() {

templates/index.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@
200200
<div class="tab-pane fade" id="svm" role="tabpanel">
201201
<form class="form">
202202
<input type="hidden" name="model" value="svm" />
203+
<input type="hidden" name="probability" value="true" />
203204
<div class="row">
204205
<div class="col-md-2">
205206
<div class="form-check">
@@ -445,6 +446,7 @@ <h5 class="modal-title" id="aboutModalLabel">About Data Space</h5>
445446
machine learning model forms are affected by the shape of the underlying dataset. By
446447
selecting a dataset or by creating one of your own, you can fit a model to the data
447448
and see how the model would make decisions based on the data it has been trained on.
449+
The fitted contours display the highest likelihoods of the class the model would select.
448450
Although this is a toy example, hopefully it helps give you the intuition that the
449451
machine learning process is a model selection search for the best combination of features,
450452
algorithm, and hyperparameter that generalize well in a bounded feature space.

0 commit comments

Comments
 (0)