|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2019 The TensorFlow Datasets Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""So2SAT remote sensing dataset.""" |
| 17 | + |
| 18 | +from __future__ import absolute_import |
| 19 | +from __future__ import division |
| 20 | +from __future__ import print_function |
| 21 | + |
| 22 | +import h5py |
| 23 | +import numpy as np |
| 24 | +import tensorflow as tf |
| 25 | +import tensorflow_datasets.public_api as tfds |
| 26 | + |
| 27 | +_DESCRIPTION = """\ |
| 28 | +So2Sat LCZ42 is a dataset consisting of co-registered synthetic aperture radar |
| 29 | +and multispectral optical image patches acquired by the Sentinel-1 and |
| 30 | +Sentinel-2 remote sensing satellites, and the corresponding local climate zones |
| 31 | +(LCZ) label. The dataset is distributed over 42 cities across different |
| 32 | +continents and cultural regions of the world. |
| 33 | +
|
| 34 | +The full dataset (`all`) consists of 8 Sentinel-1 and 10 Sentinel-2 channels. |
| 35 | +Alternatively, one can select the `rgb` subset, which contains only the optical |
| 36 | +frequency bands of Sentinel-2, rescaled and encoded as JPEG. |
| 37 | +
|
| 38 | +Dataset URL: http://doi.org/10.14459/2018MP1454690 |
| 39 | +License: http://creativecommons.org/licenses/by/4.0 |
| 40 | +""" |
| 41 | + |
| 42 | +_LABELS = [ |
| 43 | + 'Compact high-rise', 'Compact mid-rise', 'Compact low-rise', |
| 44 | + 'Open high-rise', 'Open mid-rise', 'Open low-rise', 'Lightweight low-rise', |
| 45 | + 'Large low-rise', 'Sparsely built', 'Heavy industry', 'Dense trees', |
| 46 | + 'Scattered trees', 'Bush or scrub', 'Low plants', 'Bare rock or paved', |
| 47 | + 'Bare soil or sand', 'Water' |
| 48 | +] |
| 49 | + |
| 50 | +_DATA_OPTIONS = ['rgb', 'all'] |
| 51 | + |
| 52 | +# Maximal value observed for the optical channels of Sentinel-2 in this dataset. |
| 53 | +_OPTICAL_MAX_VALUE = 2.8 |
| 54 | + |
| 55 | + |
| 56 | +class So2satConfig(tfds.core.BuilderConfig): |
| 57 | + """BuilderConfig for so2sat.""" |
| 58 | + |
| 59 | + def __init__(self, selection=None, **kwargs): |
| 60 | + """Constructs a So2satConfig. |
| 61 | +
|
| 62 | + Args: |
| 63 | + selection: `str`, one of `_DATA_OPTIONS`. |
| 64 | + **kwargs: keyword arguments forwarded to super. |
| 65 | + """ |
| 66 | + if selection not in _DATA_OPTIONS: |
| 67 | + raise ValueError('selection must be one of %s' % _DATA_OPTIONS) |
| 68 | + |
| 69 | + super(So2satConfig, self).__init__(**kwargs) |
| 70 | + self.selection = selection |
| 71 | + |
| 72 | + |
| 73 | +class So2sat(tfds.core.GeneratorBasedBuilder): |
| 74 | + """So2SAT remote sensing dataset.""" |
| 75 | + |
| 76 | + BUILDER_CONFIGS = [ |
| 77 | + So2satConfig( |
| 78 | + selection='rgb', |
| 79 | + name='rgb', |
| 80 | + version='0.0.1', |
| 81 | + description='Sentinel-2 RGB channels'), |
| 82 | + So2satConfig( |
| 83 | + selection='all', |
| 84 | + name='all', |
| 85 | + version='0.0.1', |
| 86 | + description='8 Sentinel-1 and 10 Sentinel-2 channels'), |
| 87 | + ] |
| 88 | + |
| 89 | + def _info(self): |
| 90 | + if self.builder_config.selection == 'rgb': |
| 91 | + features = tfds.features.FeaturesDict({ |
| 92 | + 'image': tfds.features.Image(shape=[32, 32, 3]), |
| 93 | + 'label': tfds.features.ClassLabel(names=_LABELS), |
| 94 | + 'sample_id': tfds.features.Tensor(shape=(), dtype=tf.int64), |
| 95 | + }) |
| 96 | + supervised_keys = ('image', 'label') |
| 97 | + elif self.builder_config.selection == 'all': |
| 98 | + features = tfds.features.FeaturesDict({ |
| 99 | + 'sentinel1': |
| 100 | + tfds.features.Tensor(shape=[32, 32, 8], dtype=tf.float32), |
| 101 | + 'sentinel2': |
| 102 | + tfds.features.Tensor(shape=[32, 32, 10], dtype=tf.float32), |
| 103 | + 'label': |
| 104 | + tfds.features.ClassLabel(names=_LABELS), |
| 105 | + 'sample_id': |
| 106 | + tfds.features.Tensor(shape=(), dtype=tf.int64), |
| 107 | + }) |
| 108 | + supervised_keys = None |
| 109 | + return tfds.core.DatasetInfo( |
| 110 | + builder=self, |
| 111 | + description=_DESCRIPTION, |
| 112 | + features=features, |
| 113 | + supervised_keys=supervised_keys, |
| 114 | + urls=['http://doi.org/10.14459/2018MP1454690'], |
| 115 | + ) |
| 116 | + |
| 117 | + def _split_generators(self, dl_manager): |
| 118 | + """Returns SplitGenerators.""" |
| 119 | + paths = dl_manager.download({ |
| 120 | + 'train': 'ftp://m1454690:m1454690@dataserv.ub.tum.de/training.h5', |
| 121 | + 'val': 'ftp://m1454690:m1454690@dataserv.ub.tum.de/validation.h5' |
| 122 | + }) |
| 123 | + return [ |
| 124 | + tfds.core.SplitGenerator( |
| 125 | + name=tfds.Split.TRAIN, |
| 126 | + num_shards=20, |
| 127 | + gen_kwargs={ |
| 128 | + 'path': paths['train'], |
| 129 | + 'selection': self.builder_config.selection, |
| 130 | + }, |
| 131 | + ), |
| 132 | + tfds.core.SplitGenerator( |
| 133 | + name=tfds.Split.VALIDATION, |
| 134 | + num_shards=5, |
| 135 | + gen_kwargs={ |
| 136 | + 'path': paths['val'], |
| 137 | + 'selection': self.builder_config.selection, |
| 138 | + }, |
| 139 | + ), |
| 140 | + ] |
| 141 | + |
| 142 | + def _generate_examples(self, path, selection): |
| 143 | + """Yields examples.""" |
| 144 | + with h5py.File(path, 'r') as fid: |
| 145 | + sen1 = fid['sen1'] |
| 146 | + sen2 = fid['sen2'] |
| 147 | + label = fid['label'] |
| 148 | + for i in range(len(sen1)): |
| 149 | + if selection == 'rgb': |
| 150 | + yield { |
| 151 | + 'image': _create_rgb(sen2[i]), |
| 152 | + 'label': np.argmax(label[i]).astype(int), |
| 153 | + 'sample_id': i, |
| 154 | + } |
| 155 | + elif selection == 'all': |
| 156 | + yield { |
| 157 | + 'sentinel1': sen1[i].astype(np.float32), |
| 158 | + 'sentinel2': sen2[i].astype(np.float32), |
| 159 | + 'label': np.argmax(label[i]).astype(int), |
| 160 | + 'sample_id': i, |
| 161 | + } |
| 162 | + |
| 163 | + |
| 164 | +def _create_rgb(sen2_bands): |
| 165 | + return np.clip(sen2_bands[..., [2, 1, 0]] / _OPTICAL_MAX_VALUE * 255.0, 0, |
| 166 | + 255).astype(np.uint8) |
0 commit comments