-
Notifications
You must be signed in to change notification settings - Fork 94
Add Spectrum Sense Dataset #350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
BrentK-ADI
wants to merge
2
commits into
analogdevicesinc:develop
Choose a base branch
from
BrentK-ADI:spectrum-sense
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,384 @@ | ||
################################################################################################### | ||
# | ||
# Copyright (C) 2025 Analog Devices, Inc. All Rights Reserved. | ||
# This software is proprietary to Analog Devices, Inc. and its licensors. | ||
# | ||
################################################################################################### | ||
""" | ||
Classes and functions used to create spectrumsense dataset. | ||
""" | ||
import errno | ||
import glob | ||
import os | ||
import shutil | ||
import subprocess | ||
import tarfile | ||
import tempfile | ||
import urllib | ||
import urllib.error | ||
import urllib.request | ||
|
||
import numpy as np | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms | ||
|
||
from PIL import Image | ||
|
||
import ai8x | ||
|
||
|
||
class SpectrumSense(Dataset):
    """
    Dataloader for the Matlab Spectrum Sensing 5G/LTE Data Set.

    The dataset is provided by Mathworks and accompanies their
    'Spectrum Sensing with Deep Learning to Identify 5G and LTE Signals'
    example which can be found at the following URL.
    https://www.mathworks.com/help/comm/ug/spectrum-sensing-with-deep-learning-to-identify-5g-and-lte-signals.html

    Possible class selections:
    0: None
    1: LTE
    2: NR
    """

    # Class names mapped to their nominal label indices
    class_dict = {'None': 0, 'LTE': 1, 'NR': 2}

    # Values in the label images mapped to names
    mask_map = {'None': 0, 'LTE': 255, 'NR': 127}

    # pylint: disable-next=line-too-long
    url_spectrumsense = 'https://www.mathworks.com/supportfiles/spc/SpectrumSensing/SpectrumSenseTrainingDataNetwork.tar.gz'  # noqa: E501

    def __init__(self, root_dir, d_type, classes=None, transform=None,
                 im_size=(256, 256), download=True):
        """
        Args:
            root_dir: Folder where the dataset is (or will be) stored.
            d_type: Either 'train' or 'test'.
            classes: Optional list of class names to keep (e.g. ['LTE', 'NR']).
            transform: Optional transform applied to each source image.
            im_size: (height, width). Images are 256x256 natively and are
                zero-padded when 352x352 is requested.
            download: When True, download and convert the dataset if it is
                not already present under root_dir.

        Raises:
            ValueError: If d_type is not 'train' or 'test'.
            RuntimeError: If the HDF4 tools needed for conversion are missing.
        """
        self.transform = transform
        self.classes = classes
        self.root = root_dir

        # Download the data set if necessary
        if download:
            self.__download()

        if d_type not in ('test', 'train'):
            raise ValueError('Data type name must be test or train.')

        src_folder = os.path.join(self.root, d_type)
        lbl_folder = os.path.join(self.root, d_type + '_labels')
        self.src_list = []
        self.lbl_list = []

        self.__create_mask_dict(im_size)

        for src_file in sorted(glob.glob('*.png', root_dir=src_folder, recursive=False)):
            lbl_file = os.path.splitext(src_file)[0] + '_lbl.png'

            src_img = Image.open(os.path.join(src_folder, src_file))
            lbl_img = Image.open(os.path.join(lbl_folder, lbl_file))

            # Dataset images are 256x256. Pad out when the image size is 352.
            # Compare as a tuple so both list and tuple im_size arguments
            # behave identically (the default value is a tuple).
            if tuple(im_size) == (352, 352):
                (src_img, lbl_img) = self.pad_image_and_label(src_img, lbl_img, im_size)

            src_data = SpectrumSense.normalize(np.asarray(src_img).astype(np.float32))

            # Convert the grayscale label image to 3 channel (RGB) so it can
            # be compared against the per-class RGB masks
            lbl_gray = np.asarray(lbl_img)
            lbl_rgb = np.stack([lbl_gray] * 3, axis=2)
            lbl_data = np.zeros((lbl_rgb.shape[0], lbl_rgb.shape[1]), dtype=np.uint8)

            # A pixel receives value (label_idx + 1) when all three channels
            # match the label_idx-th class mask. With class_dict ordering this
            # yields 1='None', 2='LTE', 3='NR' before __filter_classes runs.
            for label_idx, (_, mask) in enumerate(self.label_mask_dict.items()):
                res = lbl_rgb == mask
                res = (label_idx + 1) * res.all(axis=2)
                lbl_data += res.astype(np.uint8)

            self.src_list.append(src_data)
            self.lbl_list.append(lbl_data)

        if self.classes:
            self.__filter_classes()

    def __create_mask_dict(self, img_dims):
        """Builds a solid RGB mask image per class for per-pixel comparison."""
        self.label_mask_dict = {}
        for lbl in self.class_dict:
            val = self.mask_map[lbl]
            # Solid image filled with the class' grayscale value on all
            # three channels
            label_mask = np.full((img_dims[0], img_dims[1], 3), np.uint8(val),
                                 dtype=np.uint8)
            self.label_mask_dict[lbl] = label_mask

    def __filter_classes(self):
        """
        Relabels the per-pixel maps so the requested classes receive compact
        labels starting at 0 and every remaining pixel is folded into a single
        trailing label.

        NOTE(review): lbl_data values are class_dict values + 1 (see the +1
        in __init__), so the class_dict lookups below are offset by one from
        the mask that actually produced each pixel value. With the default
        classes=['LTE', 'NR'] the net result is 0='None', 1='LTE', 2='NR',
        matching the class docstring — confirm intent before using other
        class subsets.
        """
        for e in self.lbl_list:
            # Temporary labels start safely above any value already present
            # in the map (labels are at most len(class_dict)); the +5 is just
            # extra headroom to guarantee no collision.
            initial_new_class_label = len(self.class_dict) + 5
            new_class_label = initial_new_class_label
            for l_class in self.classes:
                if l_class not in self.class_dict:
                    print(f'Class is not in the data: {l_class}')
                    return

                # Move each requested class up to the next temporary label
                e[(e == self.class_dict[l_class])] = new_class_label
                new_class_label += 1

            # Everything not requested collapses into one label, then all
            # labels shift back down so the first requested class is 0
            e[(e < initial_new_class_label)] = new_class_label
            e -= initial_new_class_label

    @classmethod
    def pad_image_and_label(cls, img, lbl, im_size):
        """
        Pads the original images to the specified size and
        pastes the original content at the origin
        """

        # tuple() so list-valued im_size is accepted by PIL as well
        img_pad = Image.new(img.mode, tuple(im_size))
        img_pad.paste(img, (0, 0))

        lbl_pad = Image.new(lbl.mode, tuple(im_size))
        lbl_pad.paste(lbl, (0, 0))
        return (img_pad, lbl_pad)

    @staticmethod
    def normalize(data):
        """Scales 8-bit pixel data into the range [0, 1]."""
        return data / 255.

    def __len__(self):
        """Number of image/label pairs loaded."""
        return len(self.src_list)

    def __getitem__(self, idx):
        """Returns the (transformed) image and its int64 label map."""
        img = self.src_list[idx]
        # Bug fix: the original returned None implicitly when no transform
        # was supplied; now the raw image is returned instead.
        if self.transform is not None:
            img = self.transform(img)
        return img, self.lbl_list[idx].astype(np.int64)

    def __download(self):
        """Downloads, extracts and converts the dataset when not present."""
        train_dir = os.path.join(self.root, 'train')
        test_dir = os.path.join(self.root, 'test')
        train_lbl_dir = os.path.join(self.root, 'train_labels')
        test_lbl_dir = os.path.join(self.root, 'test_labels')

        self.__makedir_exist_ok(train_dir)
        self.__makedir_exist_ok(test_dir)
        self.__makedir_exist_ok(train_lbl_dir)
        self.__makedir_exist_ok(test_lbl_dir)

        if self.__check_data_exists():
            return  # skip download; a previous run completed successfully

        # Need to make sure the user has the hdftor8 utility to convert Matlab
        # HDF4 format to a PNG. The pyHDF library didn't seem to expose the
        # DFR8 API
        if shutil.which('hdftor8') is None:
            print('This data set requires the HDF4 Tools to be installed\n'
                  'Please use apt-get install hdf4-tools, or the respective '
                  'command for your platform')
            raise RuntimeError('HDF4 Tools Not installed.')

        # Do everything in a temporary directory; try/finally guarantees
        # cleanup even if the download or conversion fails part way.
        tmpdir = tempfile.mkdtemp()
        try:
            extract_dir = os.path.join(tmpdir, 'rawdata')
            filename = self.url_spectrumsense.rpartition('/')[2]
            self.__download_url(self.url_spectrumsense, tmpdir, filename)
            self.__extract_archive(os.path.join(tmpdir, filename), extract_dir)

            # Copy the Mathworks licenses over
            print('This dataset is subject the license terms as described by Mathworks.\n'
                  'Please review the license files extracted from the archive.')
            lic_files = list(glob.glob('*license*', root_dir=extract_dir, recursive=False))
            for lic in lic_files:
                shutil.copy(os.path.join(extract_dir, lic), self.root)

            # Process the data files. TrainingData is used for training
            processed_files = self.__process_data(
                os.path.join(extract_dir, 'TrainingData'), train_dir, train_lbl_dir)
            # The LTE_NR sub-folder is used for test
            processed_files += self.__process_data(
                os.path.join(extract_dir, 'TrainingData', 'LTE_NR'), test_dir, test_lbl_dir)

            # Write all the files processed to manifest.txt. This doesn't really add
            # any value, but we'll use it to confirm this entire process was
            # successful next time training is requested.
            with open(os.path.join(self.root, 'manifest.txt'), 'w', encoding='utf-8') as mf:
                mf.write('\n'.join(sorted(processed_files)))
        finally:
            # Clean up downloaded archive and files
            shutil.rmtree(tmpdir, ignore_errors=True)

    def __process_data(self, from_path, to_path, to_lbl_path):
        """
        Processes all the PNG files from the from_path and converts their
        respective HDF files to grayscale PNGs.

        Returns the list of file names (sources and labels) that were written.
        """
        # Find all the PNGs. Don't recurse as a sub folder may be used for
        # different train/test data sets
        img_files = list(glob.glob('*.png', root_dir=from_path, recursive=False))
        proc_files = []

        for img in img_files:
            # Grab the original source image
            shutil.copy(os.path.join(from_path, img), to_path)

            hdf_file = os.path.join(from_path, os.path.splitext(img)[0] + '.hdf')
            result_file = os.path.splitext(img)[0] + '_lbl.png'

            # Convert the HDF to a raw byte stream of 65536 (256x256) pixels
            # in a scratch directory. A TemporaryDirectory replaces the
            # deprecated, race-prone tempfile.mktemp() and cleans itself up.
            with tempfile.TemporaryDirectory() as scratch:
                rast_file = os.path.join(scratch, 'raster.bin')
                # Argument list instead of a shell string: handles paths with
                # spaces/special characters and avoids shell injection.
                subprocess.run(['hdftor8', hdf_file, '-r', rast_file],
                               check=False)

                # Now open the file as binary, and tell PIL to treat it like a
                # 256x256 8-bit single channel ('L')
                with open(rast_file, 'rb') as f:
                    im = Image.frombytes('L', (256, 256), f.read())
                im.save(os.path.join(to_lbl_path, result_file))

            proc_files.append(img)
            proc_files.append(result_file)
        return proc_files

    def __extract_archive(self, from_path, to_path=None, remove_finished=False):
        """Extracts a .tar.gz archive; optionally removes it afterwards."""
        if to_path is None:
            to_path = os.path.dirname(from_path)

        if from_path.endswith('.tar.gz'):
            with tarfile.open(from_path, 'r:gz') as tar:
                tar.extractall(path=to_path)
        else:
            raise ValueError(f"Extraction of {from_path} not supported")

        if remove_finished:
            os.remove(from_path)

    def __download_url(self, url, dl_dir, filename=None):
        """Downloads url into dl_dir, retrying over http if https fails."""
        if not filename:
            filename = os.path.basename(url)
        fpath = os.path.join(dl_dir, filename)
        self.__makedir_exist_ok(dl_dir)

        # downloads file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url, fpath)
        except (urllib.error.URLError, IOError):
            if url.startswith('https'):
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(url, fpath)
            else:
                raise

    def __check_data_exists(self):
        """True when manifest.txt exists, i.e. conversion already completed."""
        # Check for manifest.txt to determine if was already done
        return os.path.exists(os.path.join(self.root, 'manifest.txt'))

    def __makedir_exist_ok(self, dirpath):
        """Creates dirpath (and parents) if needed; no error when it exists."""
        # os.makedirs(exist_ok=True) replaces the manual errno.EEXIST dance
        os.makedirs(dirpath, exist_ok=True)
|
||
|
||
def spectrumsense_get_datasets_s256(data, load_train=True, load_test=True):
    """
    Load the spectrumsense dataset in 48x64x64 format which are composed of 3x256x256 images
    folded with a fold_ratio of 4.
    """
    (data_dir, args) = data

    classes = ['LTE', 'NR']

    transform = transforms.Compose([transforms.ToTensor(),
                                    ai8x.normalize(args=args),
                                    ai8x.fold(fold_ratio=4)])

    if load_train:
        train_dataset = SpectrumSense(root_dir=os.path.join(data_dir, 'SpectrumSense'),
                                      d_type='train',
                                      im_size=[256, 256], classes=classes,
                                      transform=transform)
    else:
        train_dataset = None

    if load_test:
        test_dataset = SpectrumSense(root_dir=os.path.join(data_dir, 'SpectrumSense'),
                                     d_type='test',
                                     im_size=[256, 256], classes=classes,
                                     transform=transform)
        # Keep only one sample when the framework asks for a truncated test set
        if args.truncate_testset:
            test_dataset.src_list = test_dataset.src_list[:1]
    else:
        test_dataset = None

    return train_dataset, test_dataset
|
||
|
||
def spectrumsense_get_datasets_s352(data, load_train=True, load_test=True):
    """
    Load the spectrumsense dataset in 48x88x88 format which are composed of 3x352x352 images folded
    with a fold_ratio of 4.
    """
    (data_dir, args) = data

    classes = ['LTE', 'NR']

    # Shared pipeline: tensor conversion, device normalization, 4x folding
    xform = transforms.Compose([transforms.ToTensor(),
                                ai8x.normalize(args=args),
                                ai8x.fold(fold_ratio=4)])

    root = os.path.join(data_dir, 'SpectrumSense')

    train_dataset = None
    if load_train:
        train_dataset = SpectrumSense(root_dir=root, d_type='train',
                                      im_size=[352, 352], classes=classes,
                                      transform=xform)

    test_dataset = None
    if load_test:
        test_dataset = SpectrumSense(root_dir=root, d_type='test',
                                     im_size=[352, 352], classes=classes,
                                     transform=xform)
        if args.truncate_testset:
            test_dataset.src_list = test_dataset.src_list[:1]

    return train_dataset, test_dataset
|
||
|
||
def spectrumsense_get_datasets_s256_c3(data, load_train=True, load_test=True):
    """
    Load the spectrumsense dataset for 3 classes in 48x64x64 images.
    """
    # Thin wrapper around the s256 loader. NOTE(review): the '_c3' suffix
    # here vs the 'SpectrumSense_s256_c2' registry name is confusing; the
    # loader selects classes ['LTE', 'NR'] — confirm the intended naming.
    return spectrumsense_get_datasets_s256(data, load_train, load_test)
|
||
|
||
def spectrumsense_get_datasets_s352_c3(data, load_train=True, load_test=True):
    """
    Load the spectrumsense dataset for 3 classes in 48x88x88 images.
    """
    # Thin wrapper around the s352 loader. NOTE(review): the '_c3' suffix
    # here vs the 'SpectrumSense_s352_c2' registry name is confusing; the
    # loader selects classes ['LTE', 'NR'] — confirm the intended naming.
    return spectrumsense_get_datasets_s352(data, load_train, load_test)
|
||
|
||
# Dataset registry entries consumed by the training framework.
# NOTE(review): both names advertise 2 classes, yet 'output' lists three
# indices — two signal classes plus the collapsed remainder label produced
# by SpectrumSense.__filter_classes. Also, only the s352 entry declares
# 'fold_ratio' even though both loaders fold by 4 — confirm whether the
# s256 entry should declare it as well.
datasets = [
    {
        'name': 'SpectrumSense_s256_c2', # 2 classes
        'input': (48, 64, 64),
        'output': (0, 1, 2),
        'weight': (1, 1, 1),
        'loader': spectrumsense_get_datasets_s256_c3,
    },
    {
        'name': 'SpectrumSense_s352_c2', # 2 classes
        'input': (48, 88, 88),
        'output': (0, 1, 2),
        'weight': (1, 1, 1),
        'loader': spectrumsense_get_datasets_s352_c3,
        'fold_ratio': 4,
    }
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
--- | ||
start_epoch: 40 | ||
weight_bits: 8 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of the filter-classes method is not clear (especially the "+5" offset). If the function is being used, please add some explanatory comments here; if not, it would be better to remove it along with the related `classes` constructor parameter.