Skip to content

Commit 2172ffa

Browse files
committed
Replace interp2d with RectBivariateSpline
interp2d is gone in scipy 1.14, and RecBivariateSpline is 5x faster to evaluate Fixes #270
1 parent 49e712b commit 2172ffa

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

flarestack/core/llh.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def create_acceptance_function(self):
826826
with open(acc_path, "rb") as f:
827827
[dec_bins, gamma_bins, acc] = pickle.load(f)
828828

829-
f = scipy.interpolate.interp2d(dec_bins, gamma_bins, acc.T, kind="linear")
829+
f = scipy.interpolate.RectBivariateSpline(dec_bins, gamma_bins, acc, kx=1, ky=1)
830830
return f
831831

832832
def new_acceptance(self, source, params=None):
@@ -845,7 +845,7 @@ def new_acceptance(self, source, params=None):
845845
dec = source["dec_rad"]
846846
gamma = params[-1]
847847

848-
return self.acceptance_f(dec, gamma)
848+
return self.acceptance_f(dec, gamma).squeeze()
849849

850850
def create_kwargs(self, data, pull_corrector, weight_f=None):
851851
kwargs = dict()
@@ -1480,9 +1480,8 @@ def create_kwargs(self, data, pull_corrector, weight_f=None):
14801480
coincident_data = data[coincident_nu_mask]
14811481
coincident_sources = self.sources[coincident_source_mask]
14821482

1483-
season_weight = lambda x: weight_f([1.0, x], self.season)[
1484-
coincident_source_mask
1485-
]
1483+
def season_weight(x):
1484+
return weight_f([1.0, x], self.season)[coincident_source_mask]
14861485

14871486
SoB_energy_cache = self.create_SoB_energy_cache(coincident_data)
14881487

flarestack/icecube_utils/reference_sensitivity.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import logging
2-
import os
32

43
import numpy as np
5-
from scipy.interpolate import interp1d, interp2d
4+
from scipy.interpolate import RectBivariateSpline
65

76
from flarestack.data.icecube.ic_season import get_published_sens_ref_dir
87

@@ -62,7 +61,9 @@ def reference_7year_sensitivity(sindec=np.array(0.0), gamma=2.0):
6261

6362
sens = np.vstack((sens[0], sens))
6463
sens = np.vstack((sens, sens[-1]))
65-
sens_ref = interp2d(np.array(sindecs), np.array(gammas), np.log(sens.T))
64+
sens_ref = RectBivariateSpline(
65+
np.array(sindecs), np.array(gammas), np.log(sens), kx=1, ky=1
66+
)
6667

6768
if np.array(sindec).ndim > 0:
6869
return np.array([np.exp(sens_ref(x, gamma))[0] for x in sindec])
@@ -98,7 +99,9 @@ def reference_7year_discovery_potential(sindec=0.0, gamma=2.0):
9899

99100
disc = np.vstack((disc[0], disc))
100101
disc = np.vstack((disc, disc[-1]))
101-
disc_ref = interp2d(np.array(sindecs), np.array(gammas), np.log(disc.T))
102+
disc_ref = RectBivariateSpline(
103+
np.array(sindecs), np.array(gammas), np.log(disc), kx=1, ky=1
104+
)
102105

103106
if np.array(sindec).ndim > 0:
104107
return np.array([np.exp(disc_ref(x, gamma))[0] for x in sindec])
@@ -130,7 +133,9 @@ def reference_10year_sensitivity(sindec=np.array(0.0), gamma=2.0):
130133
scaling = np.array([10 ** (3 * (i)) for i in range(2)])
131134
sens *= scaling
132135

133-
sens_ref = interp2d(np.array(sindecs), np.array(gammas), np.log(sens.T))
136+
sens_ref = RectBivariateSpline(
137+
np.array(sindecs), np.array(gammas), np.log(sens), kx=1, ky=1
138+
)
134139

135140
if np.array(sindec).ndim > 0:
136141
return np.array([np.exp(sens_ref(x, gamma))[0] for x in sindec])
@@ -162,7 +167,9 @@ def reference_10year_discovery_potential(sindec=np.array(0.0), gamma=2.0):
162167
scaling = np.array([10 ** (3 * i) for i in range(2)])
163168
sens *= scaling
164169

165-
sens_ref = interp2d(np.array(sindecs), np.array(gammas), np.log(sens.T))
170+
sens_ref = RectBivariateSpline(
171+
np.array(sindecs), np.array(gammas), np.log(sens), kx=1, ky=1
172+
)
166173

167174
if np.array(sindec).ndim > 0:
168175
return np.array([np.exp(sens_ref(x, gamma))[0] for x in sindec])

flarestack/utils/percentile_SoB.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,9 @@ def make_plot(hist, savepath, normed=True):
244244

245245
order = 1
246246

247-
spline = scipy.interpolate.interp2d(x, y, np.log(ratio))
247+
spline = scipy.interpolate.RectBivariateSpline(
248+
x, y, np.log(ratio.T), kx=order, ky=order
249+
)
248250

249251
for x_val in [2.0, 3.0, 7.0]:
250252
print(x_val, spline(0.0, x_val), spline(0.5, x_val))

0 commit comments

Comments
 (0)