-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDiscrepancyGP.py
More file actions
398 lines (346 loc) · 16 KB
/
DiscrepancyGP.py
File metadata and controls
398 lines (346 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import sys
import math
import torch
import gpytorch
import pandas as pd
import numpy as np
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.kernels import MultitaskKernel, ScaleKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import MultitaskMean
from torch import nn, optim
from torch.optim.lr_scheduler import LRScheduler
from autoemulate.callbacks.early_stopping import EarlyStopping, EarlyStoppingException
from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.types import (
DeviceLike,
GaussianLike,
GaussianProcessLike,
TensorLike,
)
from autoemulate.emulators.base import GaussianProcessEmulator
from autoemulate.emulators.gaussian_process.exact import GaussianProcess, create_gp_subclass
from autoemulate.emulators.gaussian_process import CovarModuleFn, MeanModuleFn
from autoemulate.transforms.standardize import StandardizeTransform
from autoemulate.transforms.utils import make_positive_definite
from autoemulate.emulators.gaussian_process.kernel import (
matern_3_2_kernel,
matern_5_2_kernel,
matern_5_2_plus_rq,
rbf_kernel,
rbf_plus_constant,
rbf_plus_linear,
rbf_times_linear,
rq_kernel,
)
from autoemulate.emulators.gaussian_process.mean import constant_mean, linear_mean, poly_mean, zero_mean
class aKernel(gpytorch.kernels.Kernel):
    """
    Reference-emulator covariance scaled by a squared weight.

    ``forward`` evaluates::

        a**2 * (K_ref(x1, x2) + (task_noises + noise) * I)

    where ``K_ref`` is ``ref_model.covar_module`` and the noise terms are read
    from ``ref_likelihood``. One instance represents one reference emulator;
    several instances are summed by the caller to build the full kernel.

    Parameters
    ----------
    a
        Weight tensor; it is squared and indexed with ``shape[0]`` in
        ``forward``, so it must have at least one dimension
        (presumably one weight per task -- confirm against the caller).
    ref_model
        Reference model exposing a ``covar_module`` callable.
    ref_likelihood
        Reference likelihood exposing ``task_noises`` and ``noise``.
    length_prior, length_constraint
        Accepted for signature compatibility; not used by this class.
    **kwargs
        Forwarded to ``gpytorch.kernels.Kernel.__init__``.
    """

    def __init__(self, a, ref_model, ref_likelihood, length_prior=None, length_constraint=None, **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.ref_model = ref_model
        self.ref_likelihood = ref_likelihood

    @staticmethod
    def _reverse_dims(tensor):
        """Reverse all dimensions of ``tensor`` (what ``.T`` did for N-d tensors)."""
        # ``Tensor.T`` on tensors that are not 2-D is deprecated in PyTorch;
        # an explicit permute gives the same result without the deprecation.
        return tensor.permute(*reversed(range(tensor.dim())))

    def forward(self, x1, x2, **params):
        """
        Evaluate the weighted reference covariance between ``x1`` and ``x2``.

        Returns the product of the squared weight with the reference kernel
        plus the reference noise placed on a (possibly rectangular) identity.
        """
        # The reference emulator is only queried, so keep it in eval mode.
        self.ref_model.eval()
        self.ref_likelihood.eval()

        a2 = self.a**2
        # For a 1-D ``a`` of length T this yields shape (T, 1, 1) so the weight
        # broadcasts over the two data dimensions.
        scale = self._reverse_dims(a2[None, None])
        noise = self._reverse_dims(
            (self.ref_likelihood.task_noises + self.ref_likelihood.noise)[None, None]
        )

        # Build the identity on the same device/dtype as the inputs; the default
        # (CPU, default dtype) breaks as soon as the model lives on another device.
        # NOTE(review): the noise is added on eye(n1, n2) even when x1 and x2 are
        # different point sets (cross-covariance) -- confirm this is intended.
        identity = torch.eye(
            x1.shape[0], x2.shape[0], dtype=x1.dtype, device=x1.device
        ).tile([a2.shape[0], 1, 1])

        diff = scale * (self.ref_model.covar_module(x1, x2) + (noise * identity))
        return diff
class DiscrepancyMeanCohort(gpytorch.means.Mean):
    """
    Mean function combining a discrepancy mean with weighted reference means.

    ``forward`` returns the wrapped ``mean_module`` output plus, for every
    reference emulator ``i``, ``a[i]`` times the mean of
    ``ref_likelihood[i](ref_model[i](x))``.

    Parameters
    ----------
    input_size
        Accepted for signature compatibility; not used by this class.
    mean_module
        Mean module for the discrepancy term; its ``forward`` is called on ``x``.
    ref_model
        Indexable collection of reference models.
    ref_likelihood
        Indexable collection of reference likelihoods, aligned with ``ref_model``.
    a
        Weights, indexed by reference-emulator position.
    batch_shape, bias
        Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, input_size, mean_module, ref_model, ref_likelihood, a, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.a = a
        self.ref_model = ref_model
        self.ref_likelihood = ref_likelihood
        self.mean_module = mean_module

    def forward(self, x):
        """Return the discrepancy mean plus the weighted sum of reference means."""
        # Transposed so it lines up with the reference means added below;
        # the sum is transposed back before returning.
        discrepancy_mean = self.mean_module.forward(x).T

        weighted_refs = 0
        for idx, model in enumerate(self.ref_model):
            likelihood = self.ref_likelihood[idx]
            # Reference emulators are only queried, never trained here.
            model.eval()
            likelihood.eval()
            weighted_refs += self.a[idx] * likelihood(model(x)).mean

        combined = discrepancy_mean + weighted_refs
        return combined.T
class DiscrepancyGaussianProcess(GaussianProcess, gpytorch.models.ExactGP):
    """
    Discrepancy Gaussian Process Emulator.

    This class implements an exact Discrepancy Gaussian Process emulator using the
    GPyTorch library. The mean is a weighted sum of reference-emulator means plus a
    discrepancy mean, and the kernel is a sum of weighted reference kernels plus a
    discrepancy kernel; the weights are the learnable parameter ``a``.

    It supports:
    - multi-task Gaussian processes
    - custom mean and kernel specification
    """

    # TODO: refactor to work more like PyTorchBackend once any subclasses implemented
    optimizer_cls: type[optim.Optimizer] = optim.AdamW
    optimizer: optim.Optimizer
    lr: float = 2e-1
    scheduler_cls: type[LRScheduler] | None = None

    def __init__(
        self,
        x: TensorLike,
        y: TensorLike,
        ref_model,
        ref_likelihood,
        standardize_x: bool = False,
        standardize_y: bool = True,
        likelihood_cls: type[MultitaskGaussianLikelihood] = MultitaskGaussianLikelihood,
        mean_module_fn: MeanModuleFn = constant_mean,
        covar_module_fn: CovarModuleFn = rbf_plus_constant,
        fixed_mean_params: bool = False,
        fixed_covar_params: bool = False,
        posterior_predictive: bool = False,
        epochs: int = 50,
        lr: float = 2e-1,
        early_stopping: EarlyStopping | None = None,
        device: DeviceLike | None = None,
        scheduler_cls: type[LRScheduler] | None = None,
        scheduler_params: dict | None = None,
    ):
        """
        Initialize the GaussianProcess emulator.

        Parameters
        ----------
        x: TensorLike
            Input features, expected to be a 2D tensor of shape (n_samples, n_features).
        y: TensorLike
            Target values, expected to be a 2D tensor of shape (n_samples, n_tasks).
        ref_model: list
            List of reference emulators to build discrepancy emulator
        ref_likelihood: list
            List of reference likelihoods to build discrepancy emulator
        standardize_x: bool
            If True, a `StandardizeTransform` is stored as `x_transform`, otherwise
            `x_transform` is None. Defaults to False.
        standardize_y: bool
            If True, a `StandardizeTransform` is stored as `y_transform`, otherwise
            `y_transform` is None. Defaults to True.
        likelihood_cls: type[MultitaskGaussianLikelihood]
            Likelihood class to use for the model. It is instantiated with
            `num_tasks=n_tasks`. Defaults to `MultitaskGaussianLikelihood`.
        mean_module_fn: MeanModuleFn
            Factory called with `(n_features, torch.Size([n_tasks]))` to build the
            discrepancy mean module. Defaults to `constant_mean`.
        covar_module_fn: CovarModuleFn
            Factory called with `(n_features, torch.Size([n_tasks]))` to build the
            discrepancy kernel; the result is wrapped in a `ScaleKernel` unless it
            already is one. Defaults to `rbf_plus_constant`.
        fixed_mean_params: bool
            Accepted for signature compatibility; not used by this constructor.
        fixed_covar_params: bool
            Accepted for signature compatibility; not used by this constructor.
        posterior_predictive: bool
            Stored on the instance as `posterior_predictive`. Defaults to False.
        epochs: int
            Number of training epochs. Defaults to 50.
        lr: float
            Learning rate for the optimizer. Defaults to 2e-1.
        early_stopping: EarlyStopping | None
            An optional EarlyStopping callback. Defaults to None.
        device: DeviceLike | None
            Device to run the model on. If None, uses the default device (usually CPU or
            GPU). Defaults to None.
        scheduler_cls: type[LRScheduler] | None
            Learning rate scheduler class. If None, no scheduler is used. Defaults to
            None.
        scheduler_params: dict | None
            Additional keyword arguments for the learning rate scheduler.
        """
        # Init device
        TorchDeviceMixin.__init__(self, device=device)
        x, y = self._convert_to_tensors(x, y)
        x, y = self._move_tensors_to_device(x, y)

        # Local variables for number of features and tasks
        n_features = x.shape[1]
        num_tasks = y.shape[1]
        num_tasks_torch = torch.Size([num_tasks])

        # Initialize the mean and covariance modules
        self.mean_module_fn = mean_module_fn
        self.covar_module_fn = covar_module_fn
        mean_module = self.mean_module_fn(n_features, num_tasks_torch)
        covar_module = self.covar_module_fn(n_features, num_tasks_torch)
        self.ref_model = ref_model
        self.ref_likelihood = ref_likelihood

        # If the combined kernel is not a ScaleKernel, wrap it in one
        covar_module = (
            covar_module
            if isinstance(covar_module, ScaleKernel)
            else ScaleKernel(covar_module, batch_shape=num_tasks_torch)
        )

        # Init likelihood. Honour the caller-supplied class instead of always
        # building a MultitaskGaussianLikelihood.
        likelihood = likelihood_cls(num_tasks=num_tasks)
        likelihood = likelihood.to(self.device)

        # Init must be called with preprocessed data
        gpytorch.models.ExactGP.__init__(
            self, train_inputs=x, train_targets=y, likelihood=likelihood
        )
        self.likelihood = likelihood
        self.x_transform = StandardizeTransform() if standardize_x else None
        self.y_transform = StandardizeTransform() if standardize_y else None
        self.covar_module = covar_module

        # One weight per (reference emulator, task). Created directly on the
        # target device: rows of `a` are handed to aKernel below, and a later
        # `self.to(device)` to a *different* device would swap the parameter's
        # storage and leave those row views pointing at the old one.
        self.register_parameter(
            name="a",
            parameter=torch.nn.Parameter(
                torch.randn([len(self.ref_model), num_tasks], device=self.device)
            ),
        )
        self.mean_module = DiscrepancyMeanCohort(
            input_size=n_features,
            mean_module=mean_module,
            ref_model=self.ref_model,
            ref_likelihood=self.ref_likelihood,
            a=self.a,
        )

        # Iterate over ref_models to build additive kernel
        # TODO: Put this loop inside of the aKernel function
        for i, ref in enumerate(self.ref_model):
            ref_lik = self.ref_likelihood[i]
            # Reference emulators are fixed: freeze them explicitly rather than
            # reaching back through the composed kernel's `kernels[0]`.
            ref.requires_grad_(False)
            ref_lik.requires_grad_(False)
            self.covar_module = aKernel(self.a[i], ref, ref_lik) + self.covar_module

        self.epochs = epochs
        self.lr = lr
        self.optimizer = self.optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
        self.scheduler_cls = scheduler_cls
        self.scheduler_params = scheduler_params or {}
        self.scheduler_setup(self.scheduler_params)
        self.early_stopping = early_stopping
        self.posterior_predictive = posterior_predictive
        self.num_tasks = num_tasks
        self.to(self.device)
def create_dgp_subclass(
    # Very slight alteration of create_gp_subclass: adds the ref_model and
    # ref_likelihood arguments.
    name: str,
    gp_base_class: type[DiscrepancyGaussianProcess],
    covar_module_fn: CovarModuleFn,
    mean_module_fn: MeanModuleFn,
    ref_model: list,
    ref_likelihood: list,
    auto_register: bool = True,
    overwrite: bool = True,
    **fixed_kwargs,
) -> type[GaussianProcess]:
    """
    Create a subclass of a discrepancy GaussianProcess with given fixed_kwargs.

    This function creates a subclass of `gp_base_class` where certain parameters
    are fixed to specific values, reducing the parameter space for tuning.

    The created subclass is automatically registered with the main emulator Registry
    (unless auto_register=False), making it discoverable by AutoEmulate.

    Parameters
    ----------
    name : str
        Name for the created subclass.
    gp_base_class : type[DiscrepancyGaussianProcess]
        Base class to inherit from (typically DiscrepancyGaussianProcess).
    covar_module_fn : CovarModuleFn
        Covariance module function to use in the subclass.
    mean_module_fn : MeanModuleFn
        Mean module function to use in the subclass.
    ref_model : list
        Reference emulators; used as the default `ref_model` argument of the
        subclass constructor.
    ref_likelihood : list
        Reference likelihoods; used as the default `ref_likelihood` argument of
        the subclass constructor.
    auto_register : bool
        Whether to automatically register the created subclass with the main emulator
        Registry. Defaults to True.
    overwrite : bool
        Whether to allow overwriting an existing class with the same name in the
        main Registry. Useful for interactive development in notebooks. Defaults to
        True.
    **fixed_kwargs
        Keyword arguments to fix in the subclass. The keys `standardize_x`,
        `standardize_y`, `fixed_mean_params`, `fixed_covar_params`,
        `posterior_predictive`, `epochs`, `lr`, `early_stopping` and `device`
        become the defaults of the subclass constructor; every key given here is
        removed from the dict returned by `get_tune_params()`.

    Returns
    -------
    type[GaussianProcess]
        A new subclass of `gp_base_class` with the specified parameters fixed.
        The returned class can be pickled and used like any other GP emulator.

    Raises
    ------
    ValueError
        If `name` matches `model_name()` or `short_name()` of an already registered
        emulator in the main Registry and `overwrite=False`.
        NOTE(review): presumably raised by `register`; it is not raised in this
        function body -- confirm against the registry implementation.

    Notes
    -----
    - Fixed parameters are automatically excluded from `get_tune_params()` to prevent
      them from being included in hyperparameter optimization.
    - Pickling: The created subclass is registered in the caller's module namespace,
      ensuring it can be pickled and unpickled correctly even when created in downstream
      code that uses autoemulate as a dependency.
    - If auto_register=True (default), the class is also added to the main Registry.
    """
    # Defaults for the subclass constructor: take the fixed value when one was
    # supplied, otherwise fall back to the listed default.
    standardize_x = fixed_kwargs.get("standardize_x", False)
    standardize_y = fixed_kwargs.get("standardize_y", True)
    fixed_mean_params = fixed_kwargs.get("fixed_mean_params", False)
    fixed_covar_params = fixed_kwargs.get("fixed_covar_params", False)
    posterior_predictive = fixed_kwargs.get("posterior_predictive", False)
    epochs = fixed_kwargs.get("epochs", 50)
    lr = fixed_kwargs.get("lr", 2e-1)
    early_stopping = fixed_kwargs.get("early_stopping")
    device = fixed_kwargs.get("device")

    class DGaussianProcessSubclass(gp_base_class):
        # The constructor's defaults are bound to the values captured from the
        # enclosing call; everything is forwarded unchanged to `gp_base_class`.
        def __init__(
            self,
            x: TensorLike,
            y: TensorLike,
            ref_model=ref_model,
            ref_likelihood=ref_likelihood,
            standardize_x: bool = standardize_x,
            standardize_y: bool = standardize_y,
            likelihood_cls: type[
                MultitaskGaussianLikelihood
            ] = MultitaskGaussianLikelihood,
            mean_module_fn: MeanModuleFn = mean_module_fn,
            covar_module_fn: CovarModuleFn = covar_module_fn,
            fixed_mean_params: bool = fixed_mean_params,
            fixed_covar_params: bool = fixed_covar_params,
            posterior_predictive: bool = posterior_predictive,
            epochs: int = epochs,
            lr: float = lr,
            early_stopping: EarlyStopping | None = early_stopping,
            device: DeviceLike | None = device,
            **scheduler_params,
        ):
            # Extra keyword arguments (collected in `scheduler_params`) are
            # expanded again here, so they reach the base constructor as
            # ordinary keyword arguments.
            super().__init__(
                x,
                y,
                ref_model=ref_model,
                ref_likelihood=ref_likelihood,
                standardize_x=standardize_x,
                standardize_y=standardize_y,
                likelihood_cls=likelihood_cls,
                mean_module_fn=mean_module_fn,
                covar_module_fn=covar_module_fn,
                fixed_mean_params=fixed_mean_params,
                fixed_covar_params=fixed_covar_params,
                posterior_predictive=posterior_predictive,
                epochs=epochs,
                lr=lr,
                early_stopping=early_stopping,
                device=device,
                **scheduler_params,
            )

        @staticmethod
        def get_tune_params():
            """Get tunable parameters, excluding those that are fixed."""
            tune_params = gp_base_class.get_tune_params()
            # Remove fixed parameters from tuning
            tune_params.pop("mean_module_fn", None)
            tune_params.pop("covar_module_fn", None)
            for key in fixed_kwargs:
                tune_params.pop(key, None)
            return tune_params

    # Create a more descriptive docstring that includes fixed parameters
    mean_covar_and_fixed_kwargs = {
        "mean_module_fn": mean_module_fn,
        "covar_module_fn": covar_module_fn,
        **fixed_kwargs,
    }
    # Callables are shown by name, everything else by its string form.
    fixed_params_str = "\n ".join(
        f"- {k} = {v.__name__ if callable(v) else v}"
        for k, v in mean_covar_and_fixed_kwargs.items()
    )
    DGaussianProcessSubclass.__doc__ = f"""
{gp_base_class.__doc__}
Notes
-----
{name} is a subclass of {gp_base_class.__name__} and has the following parameters
set during initialization:
{fixed_params_str}
For any parameters set with this approach, they are also excluded from the search
space when tuning. For example, if the `covar_module_fn` is set to `rbf_kernel`,
the RBF kernel will always be used as the `covar_module`. Note that in this case
the associated hyperparameters (such as lengthscale) will still be fitted during
model training and are not fixed.
"""
    # Determine the caller's module for proper pickling support.
    # When called from autoemulate itself, use __name__.
    # When called from user code, use the caller's module
    # NOTE: frame depth 1 is the direct caller of this function, so this lookup
    # must stay in this function body (not be moved into a helper).
    caller_frame = sys._getframe(1)
    caller_module_name = caller_frame.f_globals.get("__name__", __name__)

    # Set the class name and module
    DGaussianProcessSubclass.__name__ = name
    DGaussianProcessSubclass.__qualname__ = name
    DGaussianProcessSubclass.__module__ = caller_module_name

    # Register class in the caller's module globals for pickling
    # This ensures the class can be pickled/unpickled correctly
    caller_frame.f_globals[name] = DGaussianProcessSubclass

    # Also register in the caller's module if it's a real module (not __main__)
    if caller_module_name in sys.modules and caller_module_name != "__main__":
        setattr(sys.modules[caller_module_name], name, DGaussianProcessSubclass)

    # Automatically register with the main emulator Registry if requested
    if auto_register:
        # Lazy import to avoid circular dependency with __init__.py
        from autoemulate.emulators import register  # noqa: PLC0415
        register(DGaussianProcessSubclass, overwrite=overwrite)
    return DGaussianProcessSubclass