site_forecast_app/models/pvnet/model.py: 243 changes (99 additions, 144 deletions)
@@ -6,17 +6,14 @@
import logging
import os
import shutil
import warnings

import numpy as np
import pandas as pd
import torch
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.torch_datasets.datasets.site import (
SitesDataset,
)
from ocf_data_sampler.torch_datasets.sample.base import (
batch_to_tensor,
)
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor
from pvnet.models.base_model import BaseModel as PVNetBaseModel

from site_forecast_app.data.satellite import download_satellite_data
@@ -38,9 +35,7 @@
set_night_time_zeros,
)

# Global settings for running the model

# Model will use GPU if available
# Setup device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log = logging.getLogger(__name__)
@@ -72,7 +67,6 @@ def __init__(

self.client = os.getenv("CLIENT_NAME", "nl")
self.hf_token = os.getenv("HUGGINGFACE_TOKEN", None)

if self.hf_token is not None:
log.info("We are using a Hugging Face token for authentication.")
else:
@@ -90,19 +84,110 @@ def __init__(
log.exception("Failed to prepare data sources or load model.")
log.exception(f"Error: {e}")

def _get_config(self):
Contributor

could you move these back to the same position they were in? Otherwise it's very hard to tell what has changed.

log.info("Stub _get_config() called - skipping config load for test")
self.config = {
"input_data": {
"nwp": {"ecmwf": {}},
"site": {
"interval_start_minutes": 0,
"time_resolution_minutes": 15,
},
}
}
self.populated_data_config_filename = "data/data_config.yaml"
self.t0_idx = 0 # mock index

def _prepare_data_sources(self) -> None:
log.info("Preparing data sources")
try:
with contextlib.suppress(FileExistsError):
os.mkdir(root_data_path)

use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)
satellite_backup_source_file_path = os.getenv("SATELLITE_BACKUP_ZARR_PATH", None)

nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()

if "ecmwf" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
)
)
if "mo_global" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
)
)

for nwp_config in nwp_configs:
process_and_cache_nwp(nwp_config)

if use_satellite and "satellite" in self.config["input_data"]:
download_satellite_data(
satellite_source_file_path,
satellite_path,
self.satellite_scaling_method,
satellite_backup_source_file_path,
)

log.info("Preparing Site data sources")
shutil.rmtree(site_path, ignore_errors=True)
os.mkdir(site_path)

generation_xr = self.generation_data["data"]
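# build a 15-minute index of 4 * 24 * 4.5 = 432 steps (4.5 days), starting 52h
# before t0, so it extends to roughly 56h after t0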
forecast_timesteps = pd.date_range(
start=self.t0 - pd.Timedelta("52h"),
periods=int(4 * 24 * 4.5),
freq="15min",
)
generation_xr = generation_xr.reindex(time_utc=forecast_timesteps, fill_value=0.00001)
log.info(forecast_timesteps)

generation_xr.to_netcdf(site_netcdf_path, engine="h5netcdf")
self.generation_data["metadata"].to_csv(site_metadata_path, index=False)

except Exception as e:
error_message = (
"Could not run the forecast because there wasn't enough NWP data. "
"Please check your NWP input files and time range."
)
log.error(error_message)
log.error(f"Underlying error: {e}")
warnings.warn(error_message)
raise RuntimeError(error_message) from e

def _create_dataloader(self) -> None:
if not os.path.exists(self.populated_data_config_filename):
raise FileNotFoundError(f"Data config file not found: {self.populated_data_config_filename}")
self.dataset = SitesDataset(config_filename=self.populated_data_config_filename)

def _load_model(self) -> PVNetBaseModel:
log.info(f"Loading model: {self.id} - {self.version} ({self.name})")
return PVNetBaseModel.from_pretrained(
model_id=self.id,
revision=self.version,
token=self.hf_token,
).to(DEVICE)

def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:
"""Make a prediction for the model."""
capacity_kw = self.generation_data["metadata"].iloc[0]["capacity_kwp"]

normed_preds = []
with torch.no_grad():

# note this only runs for one site
Contributor

could you add these comments back in please?

samples = self.dataset.valid_t0_and_site_ids
samples_with_same_t0 = samples[samples["t0"] == timestamp]

if len(samples_with_same_t0) == 0:

sample_t0 = samples.iloc[-1].t0
sample_site_id = samples.iloc[-1].site_id

@@ -123,54 +208,42 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:
f"Site id different from the one in the batch: {site_uuid} != {sample_site_id}",
)

# for i, batch in enumerate(self.dataloader):
log.info(f"Predicting for batch: {i}, for {sample_t0=}, {sample_site_id=}")

batch = stack_np_samples_into_batch([batch])
batch = batch_to_tensor(batch)

# to cover both site_cos_time and cos_time we duplicate some keys
# this should get removed in an upgrade of pvnet
for key in ["time_cos", "time_sin", "date_cos", "date_sin"]:
if key in batch:
batch[f"site_{key}"] = batch[key]

# set MO GLOBAL cloud_cover_total to 0
mo_global_nan_total_cloud_cover = (
os.getenv("MO_GLOBAL_ZERO_TOTAL_CLOUD_COVER", "1") == "1"
)
if "mo_global" in self.config["input_data"]["nwp"] and mo_global_nan_total_cloud_cover:
log.warning("Setting MO Global total cloud cover variables to nans")
# In training cloud_cover_total was 0, let's do the same here
channels = list(batch["nwp"]["mo_global"]["nwp_channel_names"])
idx = channels.index("cloud_cover_total")

batch["nwp"]["mo_global"]["nwp"][:, :, idx] = 0

# save batch
save_batch(batch=batch, i=i, model_name=self.name, site_uuid=site_uuid)

# Run batch through model
preds = self.model(batch).detach().cpu().numpy()

preds = set_night_time_zeros(batch, preds, t0_idx=self.t0_idx)

# Store predictions
normed_preds += [preds]

# log max prediction
log.info(f"Max prediction: {np.max(preds, axis=1)}")
log.info(f"Completed batch: {i}")

normed_preds = np.concatenate(normed_preds)
n_times = normed_preds.shape[1]

# t0 time not included in forecasts
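# the first valid time is therefore t0 + 15 minutes, and the horizon spans n_times * 15 minutes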
valid_times = pd.to_datetime(
[sample_t0 + dt.timedelta(minutes=15 * (i+1)) for i in range(n_times)],
[sample_t0 + dt.timedelta(minutes=15 * (i + 1)) for i in range(n_times)],
)

# index of the 50th percentile, assuming the number of p-levels is odd and sorted in ascending order
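# e.g. with 7 plevels this picks index 3, consistent with idx_10 = 1 and idx_90 = 5 used below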
middle_plevel_index = normed_preds.shape[2] // 2

values_df = pd.DataFrame(
@@ -183,7 +256,6 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:
for i, v in enumerate(normed_preds[0, :, middle_plevel_index])
],
)
# remove any negative values
values_df["forecast_power_kw"] = values_df["forecast_power_kw"].clip(lower=0.0)

values_df = self.add_probabilistic_values(capacity_kw, normed_preds, values_df)
@@ -217,10 +289,8 @@ def add_probabilistic_values(
idx_10 = 1
idx_90 = 5

# add 10th and 90th percentiles
values_df["p10"] = normed_preds[0, :, idx_10] * capacity_kw
values_df["p90"] = normed_preds[0, :, idx_90] * capacity_kw
# change to integers
values_df["p10"] = values_df["p10"].astype(int)
values_df["p90"] = values_df["p90"].astype(int)
values_df["probabilistic_values"] = values_df[["p10", "p90"]].apply(
@@ -230,118 +300,3 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:
values_df.drop(columns=["p10", "p90"], inplace=True)
return values_df

def _prepare_data_sources(self) -> None:
"""Pull and prepare data sources required for inference."""
log.info("Preparing data sources")

# Create root data directory if not exists
with contextlib.suppress(FileExistsError):
os.mkdir(root_data_path)
# Load remote zarr source
use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)
satellite_backup_source_file_path = os.getenv("SATELLITE_BACKUP_ZARR_PATH", None)

# only load nwp that we need
nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
),
)
if "mo_global" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
),
)

# Remove local cached zarr if already exists
for nwp_config in nwp_configs:
# Process/cache remote zarr locally
process_and_cache_nwp(nwp_config)
if use_satellite and "satellite" in self.config["input_data"]:
download_satellite_data(satellite_source_file_path,
satellite_path,
self.satellite_scaling_method,
satellite_backup_source_file_path)

log.info("Preparing Site data sources")
# Clear local cached site data if already exists
shutil.rmtree(site_path, ignore_errors=True)
os.mkdir(site_path)

# Save generation data as netcdf file
generation_xr = self.generation_data["data"]

forecast_timesteps = pd.date_range(
start=self.t0 - pd.Timedelta("52h"),
periods=4 * 24 * 4.5,
freq="15min",
)

generation_xr = generation_xr.reindex(time_utc=forecast_timesteps, fill_value=0.00001)
log.info(forecast_timesteps)

generation_xr.to_netcdf(site_netcdf_path, engine="h5netcdf")

# Save metadata as csv
self.generation_data["metadata"].to_csv(site_metadata_path, index=False)

def _get_config(self) -> None:
"""Setup dataloader with prepared data sources."""
log.info("Creating configuration")

# Pull the data config from huggingface

data_config_filename = PVNetBaseModel.get_data_config(
self.id,
revision=self.version,
token=self.hf_token,
)

# Populate the data config with production data paths
populated_data_config_filename = "data/data_config.yaml"
log.info(populated_data_config_filename)
# if the file already exists, remove it
if os.path.exists(populated_data_config_filename):
os.remove(populated_data_config_filename)

self.config = populate_data_config_sources(
data_config_filename,
populated_data_config_filename,
)
self.populated_data_config_filename = populated_data_config_filename

# set t0_idx
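# e.g. interval_start_minutes = -60 with 15-minute resolution gives t0_idx = 4 (illustrative values)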
site_config = self.config["input_data"]["site"]
self.t0_idx = int(
-site_config["interval_start_minutes"] / site_config["time_resolution_minutes"],
)

def _create_dataloader(self) -> None:

if not os.path.exists(self.populated_data_config_filename):
raise FileNotFoundError(
f"Data config file not found: {self.populated_data_config_filename}",
)

# Location and time datapipes
self.dataset = SitesDataset(config_filename=self.populated_data_config_filename)

def _load_model(self) -> PVNetBaseModel:
"""Load model."""
log.info(f"Loading model: {self.id} - {self.version} ({self.name})")

return PVNetBaseModel.from_pretrained(
model_id=self.id,
revision=self.version,
token=self.hf_token,
).to(DEVICE)