Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions examples/ag_news.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from npc_gzip.compressors.base import BaseCompressor
from npc_gzip.compressors.gzip_compressor import GZipCompressor
from npc_gzip.compressors.bz2_compressor import Bz2Compressor
from npc_gzip.compressors.lzma_compressor import LzmaCompressor
from npc_gzip.knn_classifier import KnnClassifier


Expand Down Expand Up @@ -43,7 +45,7 @@ def get_data() -> tuple:
return (train, test)


def fit_model(
def fit_model_gzip(
train_text: np.ndarray, train_labels: np.ndarray, distance_metric: str = "ncd"
) -> KnnClassifier:
"""
Expand All @@ -68,15 +70,45 @@ def fit_model(

return model

def fit_model_bz2(
    train_text: np.ndarray, train_labels: np.ndarray, distance_metric: str = "ncd"
) -> KnnClassifier:
    """
    Build a `KnnClassifier` that is backed by a `Bz2Compressor`.

    Arguments:
        train_text: Training samples, forwarded to the classifier as
                    `training_inputs`.
        train_labels: Labels for the training samples, forwarded to the
                      classifier as `training_labels`.
        distance_metric: Name of the distance metric handed to the
                         classifier unchanged. Defaults to "ncd".

    Returns:
        KnnClassifier: The classifier constructed from the arguments above.
    """

    # A fresh compressor is created per call; nothing is shared between
    # classifiers built by this helper.
    return KnnClassifier(
        compressor=Bz2Compressor(),
        training_inputs=train_text,
        training_labels=train_labels,
        distance_metric=distance_metric,
    )

def fit_model_lzma(
    train_text: np.ndarray, train_labels: np.ndarray, distance_metric: str = "ncd"
) -> KnnClassifier:
    """
    Build a `KnnClassifier` that is backed by an `LzmaCompressor`.

    Arguments:
        train_text: Training samples, forwarded to the classifier as
                    `training_inputs`.
        train_labels: Labels for the training samples, forwarded to the
                      classifier as `training_labels`.
        distance_metric: Name of the distance metric handed to the
                         classifier unchanged. Defaults to "ncd".

    Returns:
        KnnClassifier: The classifier constructed from the arguments above.
    """

    # A fresh compressor is created per call; nothing is shared between
    # classifiers built by this helper.
    return KnnClassifier(
        compressor=LzmaCompressor(),
        training_inputs=train_text,
        training_labels=train_labels,
        distance_metric=distance_metric,
    )


def main() -> None:
print("Fetching data...")
((train_text, train_labels), (test_text, test_labels)) = get_data()

print("Fitting model...")
model = fit_model(train_text, train_labels)
random_indicies = np.random.choice(test_text.shape[0], 1000, replace=False)

print("Fitting model...")
print("gzip")
model = fit_model_gzip(train_text, train_labels)

sample_test_text = test_text[random_indicies]
sample_test_labels = test_labels[random_indicies]

Expand All @@ -94,6 +126,39 @@ def main() -> None:

print(classification_report(sample_test_labels, labels.reshape(-1)))


"""

print("Fitting model...")
print("bz2")
model = fit_model_bz2(train_text, train_labels)

print("Generating predictions...")

(distances, labels, similar_samples) = model.predict(
sample_test_text, top_k, sampling_percentage=0.01
)

print(classification_report(sample_test_labels, labels.reshape(-1)))\




print("Fitting model...")
print("lzma")
model = fit_model_lzma(train_text, train_labels)

print("Generating predictions...")

(distances, labels, similar_samples) = model.predict(
sample_test_text, top_k, sampling_percentage=0.01
)

print(classification_report(sample_test_labels, labels.reshape(-1)))


"""


if __name__ == "__main__":
main()
Expand Down