152 changes: 80 additions & 72 deletions site_forecast_app/models/pvnet/model.py
@@ -6,19 +6,16 @@
import logging
import os
import shutil
import warnings

import numpy as np
import pandas as pd
import torch
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.torch_datasets.datasets.site import (
SitesDataset,
)
from ocf_data_sampler.torch_datasets.sample.base import (
batch_to_tensor,
)
from pvnet.models.base_model import BaseModel as PVNetBaseModel
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor

from pvnet.models.base_model import BaseModel as PVNetBaseModel
from site_forecast_app.data.satellite import download_satellite_data

from .consts import (
@@ -29,6 +26,7 @@
site_netcdf_path,
site_path,
)

from .utils import (
NWPProcessAndCacheConfig,
populate_data_config_sources,
@@ -38,9 +36,7 @@
set_night_time_zeros,
)

# Global settings for running the model

# Model will use GPU if available
# Setup device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log = logging.getLogger(__name__)
@@ -96,20 +92,18 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:

normed_preds = []
with torch.no_grad():

# note this is only running one site
Contributor: could you add these comments back in please?

samples = self.dataset.valid_t0_and_site_ids
samples_with_same_t0 = samples[samples["t0"] == timestamp]

if len(samples_with_same_t0) == 0:

sample_t0 = samples.iloc[-1].t0
sample_site_id = samples.iloc[-1].site_id

log.warning(
"Timestamp different from the one in the batch: "
f"{timestamp} != {sample_t0} (batch)"
f"The other timestamps are: {samples['t0'].unique()}",
f"The other timestamps are: {samples['t0'].unique()}"
)
else:
sample_t0 = samples_with_same_t0.iloc[0].t0
Expand All @@ -120,7 +114,7 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:

if site_uuid != sample_site_id:
log.warning(
f"Site id different from the one in the batch: {site_uuid} != {sample_site_id}",
f"Site id different from the one in the batch: {site_uuid} != {sample_site_id}"
)

# for i, batch in enumerate(self.dataloader):
@@ -167,7 +161,7 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:

# t0 time not included in forecasts
valid_times = pd.to_datetime(
[sample_t0 + dt.timedelta(minutes=15 * (i+1)) for i in range(n_times)],
[sample_t0 + dt.timedelta(minutes=15 * (i + 1)) for i in range(n_times)]
)

# index of the 50th percentile, assumed number of p values odd and in order
@@ -181,7 +175,7 @@ def predict(self, site_uuid: str, timestamp: dt.datetime) -> dict:
"forecast_power_kw": int(v * capacity_kw),
}
for i, v in enumerate(normed_preds[0, :, middle_plevel_index])
],
]
)
# remove any negative values
values_df["forecast_power_kw"] = values_df["forecast_power_kw"].clip(lower=0.0)
@@ -220,7 +214,7 @@ def add_probabilistic_values(
# add 10th and 90th percentiles
values_df["p10"] = normed_preds[0, :, idx_10] * capacity_kw
values_df["p90"] = normed_preds[0, :, idx_90] * capacity_kw
# change to intergers
# change to integers
values_df["p10"] = values_df["p10"].astype(int)
values_df["p90"] = values_df["p90"].astype(int)
values_df["probabilistic_values"] = values_df[["p10", "p90"]].apply(
@@ -234,66 +228,79 @@ def _prepare_data_sources(self) -> None:
"""Pull and prepare data sources required for inference."""
log.info("Preparing data sources")

# Create root data directory if not exists
with contextlib.suppress(FileExistsError):
os.mkdir(root_data_path)
# Load remote zarr source
use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)
satellite_backup_source_file_path = os.getenv("SATELLITE_BACKUP_ZARR_PATH", None)

# only load nwp that we need
nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
),
)
if "mo_global" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
),
try:
Contributor: Sorry, this won't do it. You need to actually look at the configuration and see if that data is available in the NWP data / timestamps.
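A minimal sketch of the pre-flight check the reviewer is describing — inspect the cached NWP store and confirm it actually covers the forecast window, instead of catching the failure afterwards. The coordinate names (`init_time_utc`, `step`) and the helper itself are illustrative assumptions, not the repo's confirmed schema:

```python
import pandas as pd
import xarray as xr


def nwp_covers_window(zarr_path: str, t0: pd.Timestamp, horizon_hours: int) -> bool:
    """Return True if the NWP zarr has a run initialised at/before t0 reaching t0 + horizon."""
    ds = xr.open_zarr(zarr_path)
    init_times = pd.to_datetime(ds["init_time_utc"].values)
    steps = pd.to_timedelta(ds["step"].values)
    usable = init_times[init_times <= t0]  # runs initialised at or before t0
    if len(usable) == 0:
        return False
    # Does the most recent usable run extend far enough into the future?
    return usable.max() + steps.max() >= t0 + pd.Timedelta(hours=horizon_hours)
```

With a check like this, `_prepare_data_sources` could fail fast with a precise message (which source, which timestamps are missing) rather than relying on a blanket `except Exception`.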

# Create root data directory if not exists
with contextlib.suppress(FileExistsError):
os.mkdir(root_data_path)

# Load remote zarr source
use_satellite = os.getenv("USE_SATELLITE", "true").lower() == "true"
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)
satellite_backup_source_file_path = os.getenv("SATELLITE_BACKUP_ZARR_PATH", None)

# only load nwp that we need
nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
),
)
if "mo_global" in nwp_keys:
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
),
)

# Remove local cached zarr if already exists
for nwp_config in nwp_configs:
# Process/cache remote zarr locally
process_and_cache_nwp(nwp_config)
if use_satellite and "satellite" in self.config["input_data"]:
download_satellite_data(
satellite_source_file_path,
satellite_path,
self.satellite_scaling_method,
satellite_backup_source_file_path,
)

log.info("Preparing Site data sources")
# Clear local cached site data if already exists
shutil.rmtree(site_path, ignore_errors=True)
os.mkdir(site_path)

# Save generation data as netcdf file
generation_xr = self.generation_data["data"]

forecast_timesteps = pd.date_range(
start=self.t0 - pd.Timedelta("52h"),
periods=int(4 * 24 * 4.5),
freq="15min",
)

# Remove local cached zarr if already exists
for nwp_config in nwp_configs:
# Process/cache remote zarr locally
process_and_cache_nwp(nwp_config)
if use_satellite and "satellite" in self.config["input_data"]:
download_satellite_data(satellite_source_file_path,
satellite_path,
self.satellite_scaling_method,
satellite_backup_source_file_path)

log.info("Preparing Site data sources")
# Clear local cached site data if already exists
shutil.rmtree(site_path, ignore_errors=True)
os.mkdir(site_path)

# Save generation data as netcdf file
generation_xr = self.generation_data["data"]

forecast_timesteps = pd.date_range(
start=self.t0 - pd.Timedelta("52h"),
periods=4 * 24 * 4.5,
freq="15min",
)
generation_xr = generation_xr.reindex(time_utc=forecast_timesteps, fill_value=0.00001)
log.info(forecast_timesteps)

generation_xr = generation_xr.reindex(time_utc=forecast_timesteps, fill_value=0.00001)
log.info(forecast_timesteps)
generation_xr.to_netcdf(site_netcdf_path, engine="h5netcdf")

generation_xr.to_netcdf(site_netcdf_path, engine="h5netcdf")
# Save metadata as csv
self.generation_data["metadata"].to_csv(site_metadata_path, index=False)

# Save metadata as csv
self.generation_data["metadata"].to_csv(site_metadata_path, index=False)
except Exception as e:
error_message = (
"Could not run the forecast because there wasn't enough NWP data. "
"Please check your NWP input files and time range."
)
log.error(error_message)
log.error(f"Underlying error: {e}")
warnings.warn(error_message)
raise RuntimeError(error_message) from e
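A side note on the `periods=int(4 * 24 * 4.5)` change inside the try block: recent pandas versions deprecate a non-integer `periods`, and `4 * 24 * 4.5` evaluates to the float `432.0`. The cast keeps the same window — 432 steps at 15-minute frequency, i.e. 108 hours of timestamps starting 52 h before `t0`:

```python
import pandas as pd

t0 = pd.Timestamp("2024-01-05 12:00")  # illustrative t0
timesteps = pd.date_range(
    start=t0 - pd.Timedelta("52h"),
    periods=int(4 * 24 * 4.5),  # 4 steps/hour * 24 h * 4.5 days = 432 (a float without the cast)
    freq="15min",
)
assert len(timesteps) == 432  # spans t0-52h .. t0+55h45m
```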

def _get_config(self) -> None:
"""Setup dataloader with prepared data sources."""
@@ -308,6 +315,7 @@ def _get_config(self) -> None:
)

# Populate the data config with production data paths

populated_data_config_filename = "data/data_config.yaml"
log.info(populated_data_config_filename)
# if the file already exists, remove it