Skip to content

SVM classification added #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 16.18.0
- SVM classification added

## 16.17.3
- Log Likelihood Cost function:
- `dtype` passed
Expand Down
89 changes: 89 additions & 0 deletions e2e/svm_classifier/svm_classifier_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/classifier/svm_classifier/svm_classifier.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:test/test.dart';

/// Fits an [SVMClassifier] on a shuffled copy of the Pima Indians diabetes
/// dataset and returns its score for [metric].
///
/// The shuffled data is split with a `0.8` ratio: the first part is used for
/// training (target column `'Outcome'`, 10 iterations, the given [dtype]) and
/// the last part is used for assessment. The learned coefficients are printed
/// to stdout for debugging.
num evaluateSVMClassifier(MetricType metric, DType dtype) {
  final shuffled = getPimaIndiansDiabetesDataFrame().shuffle();
  final parts = splitData(shuffled, [0.8]);
  final trainPart = parts.first;
  final testPart = parts.last;

  final classifier = SVMClassifier(
    trainPart,
    'Outcome',
    dtype: dtype,
    iterationLimit: 10,
  );

  print('model: ${classifier.coefficientsByClasses}');

  final score = classifier.assess(testPart, metric);

  return score;
}

/// End-to-end checks for the SVM classifier on the Pima Indians diabetes
/// dataset, covering accuracy, precision and recall for both supported dtypes.
Future<void> main() async {
  group('SVM classifier', () {
    // Registers a single case: evaluates the classifier with the given metric
    // and dtype, logs the score and requires it to exceed `minScore`.
    // `metricName` / `dtypeName` are only used to build the test description
    // and the log line.
    void checkScore(
      MetricType metric,
      String metricName,
      DType dtype,
      String dtypeName,
      num minScore,
    ) {
      test(
          'should return adequate score on pima indians diabetes dataset using '
          '$metricName metric, dtype=DType.$dtypeName', () {
        final score = evaluateSVMClassifier(metric, dtype);

        print('$dtypeName, $metricName is $score');

        expect(score, greaterThan(minScore));
      });
    }

    checkScore(MetricType.accuracy, 'accuracy', DType.float32, 'float32', 0.7);
    checkScore(MetricType.accuracy, 'accuracy', DType.float64, 'float64', 0.7);

    checkScore(
        MetricType.precision, 'precision', DType.float32, 'float32', 0.65);
    checkScore(
        MetricType.precision, 'precision', DType.float64, 'float64', 0.65);

    checkScore(MetricType.recall, 'recall', DType.float32, 'float32', 0.65);
    checkScore(MetricType.recall, 'recall', DType.float64, 'float64', 0.65);
  });
}
55 changes: 37 additions & 18 deletions example/main.dart
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/matrix.dart';

/// A simple usage example using synthetic data.
///
/// Two things happen here: a [LinearRegressor] is trained on a tiny dataframe
/// and its coefficients are printed, then coefficients are fitted for a small
/// feature/label matrix pair with the hand-written [gradientDescent] helper
/// and printed as well.
///
/// For more complex examples, see the other directories in this folder.
///
/// NOTE(review): this body appears to contain both the previous regression
/// example and the new gradient-descent example — confirm that both are meant
/// to stay.
Future<void> main() async {
// Create a dataframe with the fitting data. The target column is the fifth
// column (index 4); since `headerExists` is false it is addressed below as
// 'col_4'.
final dataFrame = DataFrame(<Iterable<num>>[
[2, 3, 4, 5, 4.3],
[12, 32, 1, 3, 3.5],
[27, 3, 0, 59, 2.1],
], headerExists: false);

// Create the regressor; it is trained on construction.
final regressor = LinearRegressor(dataFrame, 'col_4',
iterationsLimit: 100,
initialLearningRate: 0.0005,
learningRateType: LearningRateType.constant);

// Print the fitted coefficients.
print('Regression coefficients: ${regressor.coefficients}');
// Synthetic data for the manual gradient descent: four samples with two
// (identical) features each.
final features = Matrix.fromList([
[2, 2],
[3, 3],
[4, 4],
[5, 5],
]);
// One label per sample, stored as a single column.
final labels = Matrix.column([12, 18, 24, 30]);
// Start the search from all-zero coefficients, one per feature.
final initialCoefficients = Matrix.column([0, 0]);

final coefficients = gradientDescent(features, labels, initialCoefficients);

print('Coefficients: $coefficients');
}

/// Fits linear coefficients for features [X] and labels [Y] with plain
/// gradient descent, starting from [initialCoefficients].
///
/// Each step moves the coefficients against the gradient
/// `X^T * -2 * (Y - X * coefficients)` scaled by a fixed learning rate of
/// `1e-3`. The search stops after 50 steps, or earlier once the norm of the
/// last coefficient update is no greater than `1e-5`.
Matrix gradientDescent(Matrix X, Matrix Y, Matrix initialCoefficients) {
  const stepSize = 1e-3;
  const maxSteps = 50;
  const tolerance = 1e-5;

  var current = initialCoefficients;
  // Large sentinel so that the very first step is always taken.
  var lastUpdateNorm = 1e10;

  // The guard is the negation of `lastUpdateNorm <= tolerance`, so a NaN norm
  // keeps iterating.
  for (var step = 0; step < maxSteps && !(lastUpdateNorm <= tolerance); step++) {
    final residual = Y - X * current;
    final gradient = X.transpose() * -2 * residual;
    final next = current - gradient * stepSize;

    lastUpdateNorm = (next - current).norm();
    current = next;
  }

  return current;
}
26 changes: 26 additions & 0 deletions lib/src/classifier/svm_classifier/svm_classifier.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import 'package:ml_algo/src/classifier/linear_classifier.dart';
import 'package:ml_algo/src/classifier/svm_classifier/svm_classifier_impl.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';

/// A linear classifier whose coefficients are fitted by an SVM optimizer.
///
/// The public constructor is a redirecting factory: construction (and with it
/// training) is delegated to [SVMClassifierImpl], which also supplies the
/// default values for all optional parameters.
abstract class SVMClassifier
implements
Assessable,
Serializable,
Retrainable<SVMClassifier>,
LinearClassifier {
/// Creates a classifier trained on [trainData], using the column named
/// [targetName] as the target.
///
/// - [dtype]: numeric type used when converting the data to matrices and
///   during optimization.
/// - [learningRate]: step size passed to the optimizer.
/// - [iterationLimit]: iteration bound passed to the optimizer.
/// - [fitIntercept] / [interceptScale]: forwarded to the intercept helper
///   that preprocesses the feature matrix — presumably whether an intercept
///   column is added and its value; confirm against `addInterceptIf`.
/// - [negativeLabel] / [positiveLabel]: label values emitted by `predict`
///   for the two classes.
factory SVMClassifier(
DataFrame trainData,
String targetName, {
DType dtype,
num learningRate,
int iterationLimit,
bool fitIntercept,
num interceptScale,
num negativeLabel,
num positiveLabel,
}) = SVMClassifierImpl;
}
115 changes: 115 additions & 0 deletions lib/src/classifier/svm_classifier/svm_classifier_impl.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import 'dart:io';

import 'package:ml_algo/src/classifier/_mixins/assessable_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/_mixins/linear_classifier_mixin.dart';
import 'package:ml_algo/src/classifier/svm_classifier/svm_classifier.dart';
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
import 'package:ml_algo/src/helpers/add_intercept_if.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/linear_optimizer/svm_optimizer.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_dataframe/src/data_frame/data_frame.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';

/// Default implementation of [SVMClassifier].
///
/// The model is trained eagerly in the constructor: the train data is split
/// into features and target, converted to matrices of the requested [dtype],
/// and the coefficients are fitted with [SVMOptimizer].
///
/// Serialization ([toJson], [saveAsJson], [schemaVersion]), [retrain] and
/// [predictProbabilities] are not implemented yet and throw
/// [UnimplementedError].
class SVMClassifierImpl
    with LinearClassifierMixin, AssessableClassifierMixin, SerializableMixin
    implements SVMClassifier {
  /// Trains a classifier on [trainData] using [targetName] as the target
  /// column.
  ///
  /// [learningRate] and [iterationLimit] are passed to [SVMOptimizer].
  /// [fitIntercept] and [interceptScale] are forwarded to `addInterceptIf`
  /// when preparing the feature matrix.
  SVMClassifierImpl(
    DataFrame trainData,
    String targetName, {
    this.dtype = DType.float32,
    this.learningRate = 1e-4,
    int iterationLimit = 100,
    this.fitIntercept = true,
    this.interceptScale = 1,
    this.negativeLabel = 0,
    this.positiveLabel = 1,
  }) : targetNames = [targetName] {
    final splits = featuresTargetSplit(trainData, targetNames: [targetName]);
    final features = splits.first.toMatrix(dtype);
    final labels = splits.last.toMatrix(dtype);

    coefficientsByClasses = SVMOptimizer(
      features: addInterceptIf(fitIntercept, features, interceptScale, dtype),
      labels: labels,
      learningRate: learningRate,
      iterationLimit: iterationLimit,
      dtype: dtype,
    ).findExtrema(isMinimizingObjective: true);
  }

  /// Coefficients fitted by [SVMOptimizer] during construction.
  @override
  late Matrix coefficientsByClasses;

  @override
  final DType dtype;

  @override
  final bool fitIntercept;

  @override
  final num interceptScale;

  // NOTE(review): this field is `late` and is never assigned in this class.
  // If `getProbabilitiesMatrix` (from the mixin) reads it, [predict] will
  // fail with a late-initialization error — confirm and assign a link
  // function in the constructor.
  @override
  late LinkFunction linkFunction;

  /// Label emitted by [predict] when the probability is below `.5`.
  @override
  final num negativeLabel;

  /// Label emitted by [predict] when the probability is at least `.5`.
  @override
  final num positiveLabel;

  /// Step size that was passed to the optimizer.
  final num learningRate;

  /// Names of the target columns; holds the single target name the model was
  /// trained with and is used as the header of the [predict] result.
  @override
  final Iterable<String> targetNames;

  /// Predicts a label for every row of [testFeatures].
  ///
  /// Each value of the probabilities matrix is mapped to [positiveLabel] when
  /// it is `>= .5` and to [negativeLabel] otherwise. The resulting dataframe
  /// uses [targetNames] as its header.
  @override
  DataFrame predict(DataFrame testFeatures) {
    final predictedLabels = getProbabilitiesMatrix(testFeatures).mapColumns(
      (column) => column.mapToVector((probability) => probability >= .5
          ? positiveLabel.toDouble()
          : negativeLabel.toDouble()),
    );

    return DataFrame.fromMatrix(
      predictedLabels,
      header: targetNames,
    );
  }

  @override
  DataFrame predictProbabilities(DataFrame testFeatures) {
    // TODO: implement predictProbabilities
    throw UnimplementedError();
  }

  @override
  SVMClassifier retrain(DataFrame data) {
    // TODO: implement retrain
    throw UnimplementedError();
  }

  @override
  Future<File> saveAsJson(String filePath) {
    // TODO: implement saveAsJson
    throw UnimplementedError();
  }

  @override
  // TODO: implement schemaVersion
  int? get schemaVersion => throw UnimplementedError();

  @override
  Map<String, dynamic> toJson() {
    // TODO: implement toJson
    throw UnimplementedError();
  }
}
59 changes: 59 additions & 0 deletions lib/src/linear_optimizer/svm_optimizer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import 'package:ml_algo/src/linear_optimizer/linear_optimizer.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:quiver/iterables.dart';
import 'package:xrange/xrange.dart';

/// A [LinearOptimizer] that fits linear SVM coefficients with per-sample
/// sub-gradient updates.
///
/// NOTE(review): the margin test `label * prediction >= 1` is the usual
/// hinge-loss condition, which presumes labels encoded as -1 / +1 — confirm
/// that callers encode the labels accordingly.
class SVMOptimizer implements LinearOptimizer {
  /// Creates an optimizer for the given [features] and [labels].
  ///
  /// [learningRate] scales every coefficient update, [iterationLimit] bounds
  /// the number of epochs, and [dtype] is used for matrices created during
  /// optimization.
  SVMOptimizer({
    required Matrix features,
    required Matrix labels,
    required num learningRate,
    required int iterationLimit,
    required DType dtype,
  })  : _features = features,
        _labels = labels,
        _learningRate = learningRate,
        _iterations = integers(0, iterationLimit),
        _dtype = dtype;

  final Matrix _features;
  final Matrix _labels;
  final num _learningRate;
  final Iterable<int> _iterations;
  final DType _dtype;

  @override
  // TODO: implement costPerIteration
  List<num> get costPerIteration => throw UnimplementedError();

  /// Runs the optimization and returns the fitted coefficients as a column
  /// matrix.
  ///
  /// Starts from [initialCoefficients] or, when omitted, from a zero column
  /// with one entry per feature column. [isMinimizingObjective] and
  /// [collectLearningData] are currently ignored.
  @override
  Matrix findExtrema(
      {Matrix? initialCoefficients,
      bool isMinimizingObjective = true,
      bool collectLearningData = false}) {
    var coefficients = initialCoefficients ??
        Matrix.column(List.filled(_features.first.length, 0), dtype: _dtype);

    for (final epoch in _iterations) {
      // The regularization weight decays as 2 / epoch number. The iteration
      // values start at 0, so a 1-based epoch number is used: dividing by the
      // raw value would divide by zero on the first epoch and turn the
      // coefficients into NaN.
      final shrinkage = 2 / (epoch + 1);

      // Margins are computed once per epoch from the coefficients at the
      // start of the epoch; the coefficients themselves are updated per
      // sample below.
      final predicted = _features * coefficients;
      final margins = predicted.multiply(_labels);

      for (final sample in enumerate(margins.columns.first)) {
        if (sample.value >= 1) {
          // Margin satisfied: only the regularization term is applied.
          coefficients =
              coefficients - (coefficients * shrinkage) * _learningRate;
        } else {
          // Margin violated: move towards `label * features` of this sample,
          // in addition to the regularization term.
          final correction = Matrix.fromColumns(
              [_features[sample.index] * _labels[sample.index][0]],
              dtype: _dtype);

          coefficients = coefficients +
              (correction - coefficients * shrinkage) * _learningRate;
        }
      }
    }

    return coefficients;
  }
}
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 16.17.3
version: 16.18.0
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down