Skip to content

Commit 7f3979e

Browse files
authored
Merge pull request #1320 from cta-observatory/coszd_RF_interpolation
Interpolation of RF predictions with cosZD, for homogeneous performance
2 parents 0e1ab4d + 324a8c6 commit 7f3979e

File tree

3 files changed

+261
-16
lines changed

3 files changed

+261
-16
lines changed

lstchain/data/lstchain_standard_config.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
"pointing_wise_weights": true
7070
},
7171

72+
"random_forest_zd_interpolation": {
73+
"interpolate_energy": true,
74+
"interpolate_gammaness": true,
75+
"interpolate_direction": true
76+
},
77+
7278
"random_forest_energy_regressor_args": {
7379
"max_depth": 30,
7480
"min_samples_leaf": 10,

lstchain/reco/dl1_to_dl2.py

Lines changed: 204 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import joblib
1515
import numpy as np
1616
import pandas as pd
17+
import warnings
1718
from astropy.coordinates import SkyCoord, Angle
1819
from astropy.time import Time
1920
from pathlib import Path
@@ -38,10 +39,12 @@
3839
logger = logging.getLogger(__name__)
3940

4041
__all__ = [
42+
'add_zd_interpolation_info',
4143
'apply_models',
4244
'build_models',
4345
'get_expected_source_pos',
4446
'get_source_dependent_parameters',
47+
'predict_with_zd_interpolation',
4548
'train_disp_norm',
4649
'train_disp_sign',
4750
'train_disp_vector',
@@ -51,6 +54,158 @@
5154
'update_disp_with_effective_focal_length'
5255
]
5356

57+
def add_zd_interpolation_info(dl2table, training_pointings):
    """
    Compute the parameters needed to interpolate RF predictions between the
    zenith pointings of the MC training nodes on which the RFs were trained.

    For each event, the two training nodes closest in telescope elevation
    (among those on the same side of culmination as the event, when
    available) are selected, and the weights of a linear interpolation in
    cos(zenith) between them are computed.

    Parameters
    ----------
    dl2table : pandas.DataFrame
        DataFrame containing DL2 information, including 'alt_tel' and
        'az_tel' (used here in radians).

    training_pointings : astropy.table.Table
        Table containing the pointings (columns 'zd' and 'az', convertible
        to radians) of the MC training nodes. At least two nodes are needed,
        since the closest and second-closest ones are used.

    Returns
    -------
    pandas.DataFrame
        New DL2 dataframe with four additional columns:
        alt0, alt1 : telescope elevation (rad) of the closest and
        second-closest training nodes for each event.
        w0, w1 : weights (w0 + w1 = 1) that, multiplied by the RF
        predictions at alt0 and alt1 respectively, give the value
        interpolated to the event's actual pointing.

    """
    alt_tel = np.array(dl2table['alt_tel'])
    az_tel = np.array(dl2table['az_tel'])

    training_alt_rad = np.pi / 2 - training_pointings['zd'].to(u.rad).value
    training_az_rad = training_pointings['az'].to(u.rad).value

    # Event quantities as column vectors: broadcasting against the 1D node
    # arrays yields (n_events, n_nodes) matrices.
    event_alt = alt_tel[:, np.newaxis]
    event_az = az_tel[:, np.newaxis]

    delta_alt = np.abs(training_alt_rad - event_alt)

    # Select training nodes only on the same side of the source culmination
    # as the event (same sign of sin(azimuth)):
    same_side_of_culmination = (np.sin(training_az_rad) *
                                np.sin(event_az)) > 0
    # Just fill a large value for pointings on the other side of culmination,
    # so that they are sorted last:
    delta_alt = np.where(same_side_of_culmination, delta_alt, np.pi / 2)

    # Node indices ordered according to distance in telescope elevation
    sorted_indices = np.argsort(delta_alt, axis=1)
    closest_alt = training_alt_rad[sorted_indices[:, 0]]
    second_closest_alt = training_alt_rad[sorted_indices[:, 1]]

    # Cosines of the zenith angles of the two nodes and of the event:
    c0 = np.cos(np.pi / 2 - closest_alt)
    c1 = np.cos(np.pi / 2 - second_closest_alt)
    cos_tel_zd = np.cos(np.pi / 2 - alt_tel)

    # Weights w0, w1 that, multiplied by the RF predictions at the closest (0)
    # and 2nd-closest (1) nodes, result in the value interpolated linearly in
    # cos(zenith). If the two nodes have (nearly) the same zenith, both get
    # equal weight.
    same_zenith = np.isclose(closest_alt, second_closest_alt,
                             atol=1e-4, rtol=0)
    # np.where evaluates both branches, so guard the denominator to avoid
    # division-by-zero warnings for nodes with identical zenith. The guarded
    # entries are replaced by 0.5 anyway.
    safe_denominator = np.where(same_zenith, 1.0, c1 - c0)
    w1 = np.where(same_zenith, 0.5, (cos_tel_zd - c0) / safe_denominator)
    w0 = 1 - w1

    # Return an updated copy of the dataframe:
    with pd.option_context('mode.copy_on_write', True):
        dl2table = dl2table.assign(alt0=closest_alt,
                                   alt1=second_closest_alt,
                                   w0=w0,
                                   w1=w1)

    return dl2table
129+
130+
131+
def predict_with_zd_interpolation(rf, param_array, features):
    """
    Obtain a RF prediction which takes into account the difference between
    the telescope elevation (alt_tel, i.e. 90 deg - zenith) and those of the
    MC training nodes. The dependence of image parameters (for a shower of
    given characteristics) with zenith is strong at angles beyond ~50 deg,
    due to the change in airmass. Given the way Random Forests work, if the
    training is performed with a discrete distribution of pointings,
    the prediction of the RF will be biased for pointings in between those
    used in training. If zenith is used as one of the RF features, there will
    be a sudden jump in the predictions halfway between the training nodes.

    To solve this, we compute here two predictions for each event, one using
    the elevation (alt_tel) of the training pointing which is closest to the
    telescope pointing, and another one using the elevation of the
    second-closest pointing. Then the values are interpolated (linearly in
    cos(zenith)) to the actual zenith pointing (90 deg - alt_tel) of the event.

    Parameters
    ----------
    rf : sklearn.ensemble.RandomForestRegressor or RandomForestClassifier
        The random forest we want to apply (must contain alt_tel among the
        training parameters).
    param_array : pandas.DataFrame
        Dataframe containing the features needed by the RF.
        It must also contain four additional columns: alt0, alt1, w0, w1, which
        can be added with the function add_zd_interpolation_info. These are the
        event-wise telescope elevations for the closest and 2nd-closest training
        pointings (alt0 and alt1), and the event-wise weights (w0 and w1) which
        must be applied to the RF prediction at the two pointings to obtain the
        interpolated value at the actual telescope pointing. Since the weights
        are the same (for a given event) for different RFs, it does not make
        sense to compute them here - they are pre-calculated by
        `add_zd_interpolation_info`.
    features : list of str
        List of the names of the image features used by the RF. It must
        include 'alt_tel'. The list is not modified.

    Returns
    -------
    numpy.ndarray
        Interpolated RF predictions, with the same shape as the output of
        the RF: 1D for single-output regressors (log energy, disp_norm),
        2D (events, # of outputs) for multi-output regressors (disp vector),
        2D (events, # of classes) for classifiers.

    Raises
    ------
    ValueError
        If 'alt_tel' is not among `features`.

    """
    # Type of RF (classifier or regressor):
    is_classifier = isinstance(rf, RandomForestClassifier)

    # Work on a copy, so that the caller's feature list is left untouched
    node_features = list(features)
    alt_index_in_features = node_features.index('alt_tel')

    def _predict_at(alt_column):
        # RF prediction with alt_tel replaced by the given node elevation
        node_features[alt_index_in_features] = alt_column
        inputs = param_array[node_features].to_numpy()
        if is_classifier:
            return rf.predict_proba(inputs)
        return rf.predict(inputs)

    with warnings.catch_warnings():
        # The features are passed unnamed (as an array), because alt_tel is
        # replaced by alt0, then by alt1. Silence only the resulting
        # scikit-learn warning about missing feature names, rather than
        # every warning raised during the prediction.
        warnings.filterwarnings("ignore",
                                message="X does not have valid feature names",
                                category=UserWarning)
        # First use alt_tel of closest MC training node:
        prediction_0 = _predict_at('alt0')
        # Now the alt_tel value of the second-closest node:
        prediction_1 = _predict_at('alt1')

    w0 = param_array['w0'].to_numpy()
    w1 = param_array['w1'].to_numpy()

    # Interpolated RF prediction. Transposing puts the event axis last, so
    # that the event-wise weights broadcast correctly for both 1D and 2D
    # predictions (for 1D arrays the transpose is a no-op).
    prediction = (prediction_0.T * w0 + prediction_1.T * w1).T

    return prediction
54209

55210
def train_energy(train, custom_config=None):
56211
"""
@@ -60,7 +215,7 @@ def train_energy(train, custom_config=None):
60215
Parameters
61216
----------
62217
train: `pandas.DataFrame`
63-
custom_config: dictionnary
218+
custom_config : dict
64219
Modified configuration to update the standard one
65220
66221
Returns
@@ -602,6 +757,8 @@ def apply_models(dl1,
602757
cls_disp_sign=None,
603758
effective_focal_length=29.30565 * u.m,
604759
custom_config=None,
760+
interpolate_rf=None,
761+
training_pointings=None
605762
):
606763
"""
607764
Apply previously trained Random Forests to a set of data
@@ -629,6 +786,13 @@ def apply_models(dl1,
629786
effective_focal_length: `astropy.unit`
630787
custom_config: dictionary
631788
Modified configuration to update the standard one
789+
interpolate_rf : dict
790+
Contains three booleans, 'energy_regression',
791+
'particle_classification', 'disp', indicating which RF predictions
792+
should be interpolated linearly in cos(zenith).
793+
training_pointings : astropy.table.Table
794+
Table with azimuth (az), zenith (zd) pointings of the MC sample used
795+
in the training. Needed for the interpolation of RF predictions.
632796
633797
Returns
634798
-------
@@ -643,6 +807,12 @@ def apply_models(dl1,
643807
classification_features = config["particle_classification_features"]
644808
events_filters = config["events_filters"]
645809

810+
# If no settings are provided for RF interpolation, it is switched off:
811+
if interpolate_rf is None:
812+
interpolate_rf = {'energy_regression': False,
813+
'particle_classification': False,
814+
'disp': False}
815+
646816
dl2 = utils.filter_events(dl1,
647817
filters=events_filters,
648818
finite_params=config['disp_regression_features']
@@ -659,30 +829,52 @@ def apply_models(dl1,
659829
# taking into account the aberration effect using the effective focal length
660830
is_simu = 'disp_norm' in dl2.columns
661831
if is_simu:
662-
dl2 = update_disp_with_effective_focal_length(dl2, effective_focal_length = effective_focal_length)
663-
832+
dl2 = update_disp_with_effective_focal_length(dl2,
833+
effective_focal_length=effective_focal_length)
834+
835+
if True in interpolate_rf.values():
836+
# Interpolation of RF predictions is switched on
837+
dl2 = add_zd_interpolation_info(dl2, training_pointings)
664838

665839
# Reconstruction of Energy and disp_norm distance
666840
if isinstance(reg_energy, (str, bytes, Path)):
667841
reg_energy = joblib.load(reg_energy)
668-
dl2['log_reco_energy'] = reg_energy.predict(dl2[energy_regression_features])
842+
if interpolate_rf['energy_regression']:
843+
# Interpolation of RF predictions (linear in cos(zenith)):
844+
dl2['log_reco_energy'] = predict_with_zd_interpolation(reg_energy, dl2,
845+
energy_regression_features)
846+
else:
847+
dl2['log_reco_energy'] = reg_energy.predict(dl2[energy_regression_features])
669848
del reg_energy
670849
dl2['reco_energy'] = 10 ** (dl2['log_reco_energy'])
671850

672851
if config['disp_method'] == 'disp_vector':
673852
if isinstance(reg_disp_vector, (str, bytes, Path)):
674853
reg_disp_vector = joblib.load(reg_disp_vector)
675-
disp_vector = reg_disp_vector.predict(dl2[disp_regression_features])
854+
if interpolate_rf['disp']:
855+
disp_vector = predict_with_zd_interpolation(reg_disp_vector, dl2,
856+
disp_regression_features)
857+
else:
858+
disp_vector = reg_disp_vector.predict(dl2[disp_regression_features])
676859
del reg_disp_vector
677860
elif config['disp_method'] == 'disp_norm_sign':
678861
if isinstance(reg_disp_norm, (str, bytes, Path)):
679862
reg_disp_norm = joblib.load(reg_disp_norm)
680-
disp_norm = reg_disp_norm.predict(dl2[disp_regression_features])
863+
if interpolate_rf['disp']:
864+
disp_norm = predict_with_zd_interpolation(reg_disp_norm, dl2,
865+
disp_regression_features)
866+
else:
867+
disp_norm = reg_disp_norm.predict(dl2[disp_regression_features])
681868
del reg_disp_norm
682869

683870
if isinstance(cls_disp_sign, (str, bytes, Path)):
684871
cls_disp_sign = joblib.load(cls_disp_sign)
685-
disp_sign_proba = cls_disp_sign.predict_proba(dl2[disp_classification_features])
872+
if interpolate_rf['disp']:
873+
disp_sign_proba = predict_with_zd_interpolation(cls_disp_sign, dl2,
874+
disp_classification_features)
875+
else:
876+
disp_sign_proba = cls_disp_sign.predict_proba(dl2[disp_classification_features])
877+
686878
col = list(cls_disp_sign.classes_).index(1)
687879
disp_sign = np.where(disp_sign_proba[:, col] > 0.5, 1, -1)
688880
del cls_disp_sign
@@ -748,7 +940,11 @@ def apply_models(dl1,
748940

749941
if isinstance(classifier, (str, bytes, Path)):
750942
classifier = joblib.load(classifier)
751-
probs = classifier.predict_proba(dl2[classification_features])
943+
if interpolate_rf['particle_classification']:
944+
probs = predict_with_zd_interpolation(classifier, dl2,
945+
classification_features)
946+
else:
947+
probs = classifier.predict_proba(dl2[classification_features])
752948

753949
# This check is valid as long as we train on only two classes (gammas and protons)
754950
if probs.shape[1] > 2:

0 commit comments

Comments
 (0)