-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnet_helpers.py
More file actions
1862 lines (1484 loc) · 84.1 KB
/
net_helpers.py
File metadata and controls
1862 lines (1484 loc) · 84.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import time # For debugging
import copy
import numpy as np
import matplotlib.pyplot as plt
import math
import seaborn as sns
import copy
import math, os, gc
import sys
import torch
from torch import nn
from torch.nn import functional as F
from scipy.stats import lognorm
from typing import Dict
import mpn_tasks
import helper
# Plot palettes, indexed by hue:
#   0 red, 1 blue, 2 green, 3 purple, 4 orange, 5 teal, 6 gray, 7 pink, 8 yellow.
# Base shades.
c_vals = [
    '#e53e3e', '#3182ce', '#38a169', '#805ad5', '#dd6b20',
    '#319795', '#718096', '#d53f8c', '#d69e2e',
]
# Lighter shades, same hue order as c_vals.
c_vals_l = [
    '#feb2b2', '#90cdf4', '#9ae6b4', '#d6bcfa', '#fbd38d',
    '#81e6d9', '#e2e8f0', '#fbb6ce', '#faf089',
]
# Darker shades, same hue order as c_vals.
c_vals_d = [
    '#9b2c2c', '#2c5282', '#276749', '#553c9a', '#9c4221',
    '#285e61', '#2d3748', '#97266d', '#975a16',
]
# Matplotlib line styles: named styles, their shorthand forms, and two custom
# (offset, on/off-sequence) dash patterns.
l_vals = [
    'solid', 'dashed', 'dotted', 'dashdot',
    '-', '--', '-.', ':',
    (0, (3, 1, 1, 1)), (0, (5, 10)),
]
# Matplotlib marker symbols.
markers_vals = [
    '.', 'o', 'v', '^', '<', '>', '1', '2', '3', '4', 's',
    'p', '*', 'h', 'H', '+', 'x', 'D', 'd', '|', '_',
]
# Progress-print period (in datasets); train_network overwrites it with its
# ``print_frequency`` argument.
GLOBAL_PRINT_FREQUENCY = 1
def _to_cpu(x):
# torch tensor -> cpu tensor; numpy stays; python stays
if isinstance(x, torch.Tensor):
return x.detach().cpu()
return x
def save_checkpoint(path, net, params, hyp_dict, seed,
                    test_bundle=None,
                    training_bundle=None):
    """
    Save a compact checkpoint suitable for reloading the trained net and running downstream analyses.

    The file written by ``torch.save`` holds a dict with keys ``seed``, ``hyp_dict``,
    ``params``, ``net_state_dict`` and ``net_class`` (``hyp_dict.get("chosen_network")``).
    It also holds ``net_hist`` when ``net`` has a ``hist`` attribute, and
    ``test_bundle`` / ``training_bundle`` when they are given.

    Args:
        path: str, e.g. "./ckpts/exp1.pt". Missing parent directories are created;
            a bare filename is written to the current working directory.
        net: trained model (anything exposing ``state_dict()``).
        params: tuple (task_params, train_params, net_params)
        hyp_dict: dict
        seed: int
        test_bundle: dict of tensors/arrays needed for analysis. Torch tensors are
            detached and moved to CPU so the file reloads on any device; other
            values are stored as-is.
        training_bundle: optional dict of large training traces (counter_lst, db_lst, etc.),
            stored as-is.
    """
    parent_dir = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    ckpt = {
        "seed": seed,
        "hyp_dict": hyp_dict,
        "params": params,
        "net_state_dict": net.state_dict(),
        "net_class": hyp_dict.get("chosen_network", None),  # e.g. "dmpn"/"gru"/"vanilla"
    }
    if hasattr(net, "hist"):
        ckpt["net_hist"] = net.hist  # learning curves; usually small
    if test_bundle is not None:
        # move torch tensors to CPU to make reload portable
        ckpt["test_bundle"] = {k: _to_cpu(v) for k, v in test_bundle.items()}
    if training_bundle is not None:
        # WARNING: this can be huge if you include db_lst/netout_lst
        ckpt["training_bundle"] = training_bundle
    torch.save(ckpt, path)
    print(f"Saved checkpoint to: {path}")
def load_checkpoint(path, device, netFunction, rebuild_net_kwargs=None):
    """
    Load a checkpoint written by ``save_checkpoint`` and reconstruct the net.

    The network is rebuilt the same way ``train_network`` builds a fresh one,
    ``netFunction(net_params, **rebuild_net_kwargs)``. The saved state dict is then
    loaded, learning curves stored under ``net_hist`` (if any) are restored as
    ``net.hist``, and the net is moved to ``device`` and put in eval mode.

    Args:
        path: str to .pt
        device: torch.device
        netFunction: the class constructor used for training (e.g.
            mpn.DeepMultiPlasticNet / nets.GRU); called as ``netFunction(net_params, **kwargs)``.
        rebuild_net_kwargs: optional dict of extra keyword arguments for the
            constructor (e.g. ``{"verbose": False}``).

    Returns:
        net, (task_params, train_params, net_params), hyp_dict, test_bundle, training_bundle
        (each bundle is None when it was not saved)

    Note:
        Checkpoints are full pickles: they can contain numpy arrays (``test_bundle``
        values, ``task_params['rules_probs']`` after ``train_network``) and training
        traces, which the weights-only loader (the default since torch 2.6) rejects.
        They are therefore loaded with ``weights_only=False``. Only load files you trust.
    """
    try:
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` keyword
        ckpt = torch.load(path, map_location="cpu")
    task_params, train_params, net_params = ckpt["params"]
    hyp_dict = ckpt["hyp_dict"]
    rebuild_net_kwargs = rebuild_net_kwargs or {}

    # Rebuild the network with the same call train_network uses for a new net.
    net = netFunction(net_params, **rebuild_net_kwargs)
    net.load_state_dict(ckpt["net_state_dict"])
    if "net_hist" in ckpt:
        # learning curves saved by save_checkpoint (used e.g. by net_eta_lambda_analysis)
        net.hist = ckpt["net_hist"]
    net.to(device)
    net.eval()

    test_bundle = ckpt.get("test_bundle", None)
    training_bundle = ckpt.get("training_bundle", None)
    return net, (task_params, train_params, net_params), hyp_dict, test_bundle, training_bundle
def tail_mean_decay(lst, N, decay=0.95):
    """
    Exponentially weighted mean of the last ``N`` entries of ``lst``.

    The newest (last) entry has weight 1, the one before it ``decay``, then
    ``decay**2``, and so on. The weighted sum is divided by the sum of the weights.
    When ``lst`` has fewer than ``N`` entries, all of them are used.

    Args:
        lst: sequence of numbers (converted to a float numpy array).
        N: window length (positive int), or None to skip the computation.
        decay: per-step weight factor; 1.0 gives the plain mean of the tail.

    Returns:
        The weighted mean; None when ``N`` is None; nan when ``lst`` is empty.

    Raises:
        ValueError: if ``N`` < 1 (``lst[-0:]`` would silently select the whole list).
    """
    if N is None:
        return None
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    # Slicing past the start simply returns every element, so short lists need no special case.
    data = np.asarray(lst, dtype=float)[-N:]
    if len(data) == 0:
        return float("nan")
    # weights: 1.0 for the last element, decay**k for the k-th element before it
    w = decay ** np.arange(len(data) - 1, -1, -1)
    return np.sum(data * w) / np.sum(w)
def expand_and_freeze(net, option, nettype):
    """
    Re-create the input weight of ``net`` and freeze everything except its last input column.

    Two network layouts are handled.

    A) Linear front-end: ``net.input_layer_active`` is truthy and ``net.W_initial_linear``
       exists (it must be an ``nn.Linear``).
         option=0 -> a fresh Linear with one extra input feature. The old weights fill
                     the leading columns, the old bias is copied, and the new last
                     column is Kaiming-uniform initialised.
         option=1 -> a fresh Linear of the same shape holding copies of the old
                     weight and bias (replacing the module drops old hooks/state).
       Afterwards only ``W_initial_linear.weight`` requires grad; its bias and every
       other parameter are frozen.

    B) Raw input matrix: ``net.W_input`` with shape (n_hidden, n_input).
         option=0 -> one extra column is appended and Kaiming-uniform initialised;
                     ``net.n_input`` is incremented if the net has it.
         option=1 -> same shape, detached copy.
       The matrix is re-wrapped in a new ``nn.Parameter``. All other parameters are
       frozen, and each one's name and shape is printed.

    In both cases a gradient hook zeroes every column of the input-weight gradient
    except the last, and the net's original train/eval mode is re-applied.

    Args:
        net: network (nn.Module) to modify in place.
        option: 0 (add one input feature) or 1 (keep the shape).
        nettype: label used only in the log message.

    Returns:
        (net, handle): the modified net and the RemovableHandle of the gradient hook.

    Raises:
        AssertionError: layout A with a ``W_initial_linear`` that is not ``nn.Linear``.
        ValueError: ``option`` is neither 0 nor 1.
        RuntimeError: neither layout applies.
    """
    print(f"Expanding and Freezing on Network with Type: {nettype}")
    restore_training = net.training

    def _keep_last_column(grad: torch.Tensor) -> torch.Tensor:
        # grad has shape (rows, current number of inputs); zero all but the last column
        keep = torch.zeros_like(grad)
        keep[:, -1] = 1
        return grad * keep

    # ---- Layout A: nn.Linear front-end --------------------------------------
    if getattr(net, "input_layer_active", False) and hasattr(net, "W_initial_linear"):
        previous = net.W_initial_linear
        assert isinstance(previous, nn.Linear), \
            "net.W_initial_linear must be an nn.Linear."
        n_in, n_out = previous.in_features, previous.out_features
        with_bias = previous.bias is not None
        target_device = previous.weight.device
        target_dtype = previous.weight.dtype

        if option == 0:
            widen = True
        elif option == 1:
            widen = False
        else:
            raise ValueError("option must be 0 or 1")

        # Fresh module: built with default settings, moved to the old device, then cast.
        replacement = nn.Linear(n_in + 1 if widen else n_in, n_out, bias=with_bias).to(target_device)
        replacement.weight.data = replacement.weight.data.to(target_dtype)
        if with_bias:
            replacement.bias.data = replacement.bias.data.to(target_dtype)

        with torch.no_grad():
            if widen:
                # old (n_out, n_in) weights go into the leading columns ...
                replacement.weight[:, :n_in].copy_(previous.weight.detach())
                # ... and the extra last column is the one that will be trained
                nn.init.kaiming_uniform_(replacement.weight[:, -1:].t(),
                                         a=math.sqrt(5))
            else:
                replacement.weight.copy_(previous.weight.detach())
            if with_bias:
                replacement.bias.copy_(previous.bias.detach())

        net.W_initial_linear = replacement
        # net.train(mode) applies the mode recursively, including to the new submodule
        net.train(restore_training)

        for parameter in net.parameters():
            parameter.requires_grad = False
        # the whole weight gets gradients; the hook masks out all but the last column
        net.W_initial_linear.weight.requires_grad = True
        return net, net.W_initial_linear.weight.register_hook(_keep_last_column)

    # ---- Layout B: raw input matrix net.W_input (n_hidden, n_input) ---------
    if hasattr(net, "W_input"):
        current = net.W_input
        target_device, target_dtype = current.device, current.dtype
        n_hidden, n_input = current.shape

        if option == 0:
            fresh = torch.zeros(n_hidden, n_input + 1, device=target_device, dtype=target_dtype)
            with torch.no_grad():
                fresh[:, :n_input].copy_(current.detach())
                nn.init.kaiming_uniform_(fresh[:, -1:].t(), a=math.sqrt(5))
            if hasattr(net, "n_input"):
                # keep the net's own input-size bookkeeping in sync
                net.n_input = n_input + 1
        elif option == 1:
            fresh = current.detach().clone()
        else:
            raise ValueError("option must be 0 or 1")

        # a brand-new Parameter so requires_grad and hooks start clean
        net.W_input = nn.Parameter(fresh, requires_grad=True)
        net.train(restore_training)

        for param_name, parameter in net.named_parameters():
            parameter.requires_grad = False
            print(f"{param_name} is freezed: shape: {parameter.shape}")
        net.W_input.requires_grad = True  # only the input matrix stays trainable
        return net, net.W_input.register_hook(_keep_last_column)

    # ---- No supported input weight -----------------------------------------
    raise RuntimeError(
        "expand_and_freeze: network has neither W_initial_linear nor W_input; "
        "cannot apply expansion logic."
    )
def train_network(params, net=None, device=torch.device('cuda'), verbose=False,
                  train=True, hyp_dict=None, netFunction=None, test_input=None,
                  pretraining_shift=0, pretraining_shift_pre=0, print_frequency=1
                  ):
    """
    Train a network over a sequence of freshly generated 'multitask' datasets.

    A validation set is generated once and reused. For each of
    ``train_params['n_datasets']`` datasets a new training set is generated and,
    when ``train`` is True, three things happen in order:
      * At selected dataset indices, the net's outputs/states on every entry of
        ``test_input`` and its weights are snapshotted.
      * The net is fit with ``net.fit``; training may stop early on high validation accuracy.
      * ``task_params['rules_probs']`` is recomputed from the goodness history that
        ``net.fit`` returns.

    Args:
        params: tuple (task_params, train_params, net_params). ``task_params`` and
            (when ``pretraining_shift_pre`` > 0 and a new net is built) ``net_params``
            are mutated in place.
        net: network to continue training, or None to build one with
            ``netFunction(net_params, verbose=verbose)``.
        device: torch device for the network and the generated data.
        verbose: forwarded to ``netFunction`` when a new net is built.
        train: if False, the loop only generates training data; snapshotting,
            fitting and the probability update are skipped.
        hyp_dict: dict; keys read here are 'mode_for_all', 'run_mode', and
            'chosen_network' (only when ``net`` is given and ``pretraining_shift != 0``).
        netFunction: network constructor, used only when ``net`` is None.
        test_input: list of test input tensors (asserted to be a list).
        pretraining_shift: if nonzero and ``net`` is given, ``expand_and_freeze(net, option=1)``
            is applied first; it also fixes ``rules_probs`` to ``np.array([1])``.
        pretraining_shift_pre: if > 0 and a new net is built, ``net_params["n_neurons"][0]``
            is increased by this amount before construction.
            Both shifts are also forwarded to ``mpn_tasks.generate_trials_wrap``.
        print_frequency: stored in the module-level GLOBAL_PRINT_FREQUENCY; data
            generation is verbose, and validation accuracy is printed, every
            ``print_frequency`` datasets.

    Returns:
        (net, (train_data, valid_data),
         (counter_lst, netout_lst, db_lst, Winput_lst, Winputbias_lst,
          Woutput_lst, Wall_lst, marker_lst, loss_lst, acc_lst),
         dataset_idx_early_stop)
        ``dataset_idx_early_stop`` is the dataset index at which every entry of the
        validation accuracy history exceeded 0.97, or None if training never stopped early.

    Raises:
        AssertionError: ``test_input`` is not a list, or monitor_freq > n_epochs_per_set.
        ValueError: task type other than 'multitask'.
    """
    # 2025-11-12: pass GLOBAL_PRINT_FREQUENCY through external input
    global GLOBAL_PRINT_FREQUENCY
    GLOBAL_PRINT_FREQUENCY = print_frequency

    # NOTE(review): with the signature default test_input=None this assertion always
    # fails, so callers must pass a list (possibly empty).
    assert isinstance(test_input, list), "test_input must be a list"

    task_params, train_params, net_params = params

    # 2025-10-30: based on the coding logic, this condition needs to be satisfied
    assert net_params["monitor_freq"] <= train_params["n_epochs_per_set"]

    # A non-zero pretraining_shift on an existing net marks the post-training stage:
    # re-create its input weights and freeze everything except their last column.
    if net is not None and pretraining_shift != 0:
        net, _ = expand_and_freeze(net, option=1, nettype=hyp_dict['chosen_network'])

    # Data generators. They close over task_params, so updates made to it later in
    # this function (task_params['rules_probs']) are visible on each call
    # (presumably consumed by mpn_tasks.generate_trials_wrap -- confirm there).
    if task_params['task_type'] in ('multitask',):
        def generate_train_data(device='cuda', verbose=True):
            # ZIHAN: the correct thing to do is
            # "training batches should only be over one rule"
            # "but validation should mix them"
            train_data, (_, train_trails, _ ) = mpn_tasks.generate_trials_wrap(
                task_params, train_params['n_batches'], device=device, verbose=verbose, mode_input=hyp_dict['mode_for_all'], \
                pretraining_shift=pretraining_shift, pretraining_shift_pre=pretraining_shift_pre
            )
            return train_data, train_trails

        def generate_valid_data(device='cuda'):
            # validation trials are requested over all rules in task_params['rules']
            valid_data, (_, valid_trails, _) = mpn_tasks.generate_trials_wrap(
                task_params, train_params['valid_n_batch'], rules=task_params['rules'], device=device, mode_input=hyp_dict['mode_for_all'], \
                pretraining_shift=pretraining_shift, pretraining_shift_pre=pretraining_shift_pre
            )
            return valid_data, valid_trails
    else:
        raise ValueError('Task type not recognized.')

    # Create a new network
    if net is None:
        # overwrite the input information: enlarge n_neurons[0] (presumably the input
        # size) by pretraining_shift_pre; note this mutates the caller's net_params
        if pretraining_shift_pre > 0:
            n_neurons = net_params["n_neurons"]
            new_n_neurons = copy.deepcopy(n_neurons)
            new_n_neurons[0] += pretraining_shift_pre
            net_params["n_neurons"] = new_n_neurons
            print(net_params["n_neurons"])
        net = netFunction(net_params, verbose=verbose)

    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    # 2025-11-19: even if expand_and_freeze is used, pytorch will display the whole
    # input matrix as trainable, but internally in expand_and_freeze, we mask the
    # gradient flow so that only the last column (contextual entry) could be updated
    print(f"Trainable parameters: {num_params:,}")
    # 2025-11-16: print all trainable parameters and their shapes
    for name, param in net.named_parameters():
        if param.requires_grad:
            print(f"{name}: {tuple(param.shape)}")
            # 2025-11-19: double check the setting for the recurrent weight matrix
            if name == "W_rec":
                print(param)

    # Puts network on device
    net.to(device)

    # 2025-10-29: the validation dataset is only generated ONCE and used throughout
    valid_data, valid_trails = generate_valid_data(device=device)

    counter_lst = []
    # to customize if multiple test data are used: one history list per test input
    netout_lst, db_lst = [[] for _ in range(len(test_input))], [[] for _ in range(len(test_input))]
    Winput_lst, Woutput_lst, Winputbias_lst, Wall_lst, marker_lst, loss_lst, acc_lst = [], [], [], [], [], [], []
    net_lst = []

    total_dataset = train_params['n_datasets']
    dataset_idx_early_stop = None
    # NOTE(review): if n_datasets == 0, train_data is never bound and the return fails.
    for dataset_idx in range(total_dataset):
        # generate training data
        train_data, train_trails = generate_train_data(device=device, verbose=(dataset_idx % GLOBAL_PRINT_FREQUENCY == 0))
        # net.fit receives new_thresh=True only for the first dataset
        new_thresh = True if dataset_idx == 0 else False

        if train:
            # Jul 19th: test the network's output on the testing dataset at the different stage of the network
            # save and register the network's parameter and output
            # Snapshots are taken where helper.is_power_of_n_or_zero(dataset_idx, 32) holds
            # and at the last dataset.
            if test_input is not None and (helper.is_power_of_n_or_zero(dataset_idx, 32) or dataset_idx == train_params['n_datasets'] - 1):
                # print(f"How about Test Data at dataset {dataset_idx}")
                counter_lst.append(dataset_idx)
                # test data for each stage
                for test_input_index, test_input_ in enumerate(test_input):
                    # 2025-07-16: should we separate and load the input, and stack the output later
                    # print(f"test_input_: {test_input_.shape}")
                    # Run in minibatches of 8 along dim 0 and move results to CPU/numpy
                    # immediately, freeing device memory between batches.
                    minibatch = 8
                    net_out_np = []
                    db = []
                    for start in range(0, test_input_.shape[0], minibatch):
                        end = min(start + minibatch, test_input_.shape[0])
                        test_input_batch = test_input_[start:end]
                        # presumably both outputs live on `device` (CUDA in the default setup)
                        net_out_batch, _, db_batch = net.iterate_sequence_batch(test_input_batch, run_mode='track_states')
                        # register the values
                        net_out_np.append(net_out_batch.detach().cpu().numpy())
                        db.append({k: v.detach().cpu().numpy() for k, v in db_batch.items()})
                        del net_out_batch, db_batch
                        gc.collect(); torch.cuda.empty_cache()
                    # stack the minibatches back along dim 0 (per key for the state dict)
                    net_out_np = np.concatenate(net_out_np, axis=0)
                    db_all = {}
                    for key in db[0].keys():
                        db_key = np.concatenate([db_[key] for db_ in db], axis=0)
                        db_all[key] = db_key
                    netout_lst[test_input_index].append(net_out_np)
                    db_lst[test_input_index].append(db_all)

                # weight snapshots (one entry per snapshot index)
                Woutput_lst.append(net.W_output.detach().cpu().numpy())
                if net_params["input_layer_add"] and net_params["net_type"] == "dmpn":
                    Winput_lst.append(net.W_initial_linear.weight.detach().cpu().numpy())
                    if net_params["input_layer_bias"]:
                        Winputbias_lst.append(net.W_initial_linear.bias.detach().cpu().numpy())
                elif net_params["input_layer_add"] and net_params["net_type"] == "vanilla":
                    Winput_lst.append(net.W_input.detach().cpu().numpy())
                    # no input bias is set for vanilla RNN in kyle's original design
                W_all_ = []
                if params[2]["net_type"] == "dmpn":
                    for i in range(len(net.mp_layers)):
                        W_all_.append(net.mp_layers[i].W.detach().cpu().numpy())
                Wall_lst.append(W_all_)
                marker_lst.append(dataset_idx)

            _, monitor_loss, monitor_acc, goodness_history, valid_acc_history = net.fit(train_params, train_data, train_trails,
                                                                                       valid_batch=valid_data, valid_trails=valid_trails,
                                                                                       new_thresh=new_thresh, run_mode=hyp_dict['run_mode'], datanum=dataset_idx)

            if dataset_idx % GLOBAL_PRINT_FREQUENCY == 0:
                # 2025-11-13: we should consider loss temporal convolution with various decay factor
                # to make sure the network has good performance across different levels
                # i.e. if decay_factor is 1.00, equally focusing across the temporal window
                # if decay_factor < 1.00, then allow gradual weight decaying for the past history
                print(f"valid_acc_history: {valid_acc_history}")

            # if None, then no early stop option
            # 2025-11-16: have some manual control to let the network has a minimum exposure to the
            # training dataset (no early stop before dataset index pretrain_min)
            if train_params["valid_check"] is not None and dataset_idx >= train_params["pretrain_min"]:
                if all(v > 0.97 for v in valid_acc_history):
                    print(f"valid_acc_history > 0.97; early stop!")
                    dataset_idx_early_stop = dataset_idx
                    break

            # calculate the change of task sampling proportion
            # aim for multi-task training (but clearly work for single-task)
            # --- inputs ---------------------------------------------------------------
            df = task_params["adjust_task_decay"] # scalar decay factor
            # 2025-07-19: for pretraining
            # Rows of goodness_history may differ in length: right-pad with NaN into a matrix.
            max_len = max(len(a) for a in goodness_history)
            goodness_history = [
                np.pad(a.astype(float), # ensure float → can hold NaN
                       (0, max_len - len(a)), # (pad_left, pad_right)
                       constant_values=np.nan) # fill value
                for a in goodness_history
            ]
            # print(f"goodness_history: {goodness_history}")
            g = np.vstack(goodness_history)
            # sanity-check: every row must have the same length
            # (always true after the padding above, so this check cannot fire)
            if len({len(row) for row in goodness_history}) != 1:
                raise ValueError("`last_group_goodness` rows have unequal lengths.")
            if pretraining_shift == 0:
                # --- compute weighted sum --------------------------------------------------
                # exponent sequence: R, R-1, … , 1 (replicates original i_ = len-adji);
                # row 0 is weighted df**R and the last row df**1
                exp = np.arange(g.shape[0], 0, -1)[:, None] # shape (R,1) for broadcasting
                weights = df ** exp # shape (R,1)
                # NOTE(review): NaN padding propagates through .sum(); if rows really
                # differ in length, np.nansum may be what is intended.
                all_adjust = (g * weights).sum(axis=0) # shape (C,)
                # update the sampling probability
                assert len(all_adjust) == len(task_params["rules"])
                task_params['rules_probs'] = mpn_tasks.normalize_to_one(all_adjust)
            else: # by default only one task, so probability trivially 1
                task_params['rules_probs'] = np.array([1])

            # record monitored loss/accuracy where helper.is_power_of_n_or_zero(dataset_idx, 8)
            # holds and at the last dataset
            if test_input is not None and (helper.is_power_of_n_or_zero(dataset_idx, 8) or dataset_idx == train_params['n_datasets'] - 1):
                loss_lst.append(monitor_loss)
                acc_lst.append(monitor_acc)

    return net, (train_data, valid_data), (counter_lst, netout_lst, db_lst, Winput_lst, Winputbias_lst, \
                                           Woutput_lst, Wall_lst, marker_lst, loss_lst, acc_lst), dataset_idx_early_stop
def net_eta_lambda_analysis(net, net_params, hyp_dict=None, verbose=False):
    """
    Plot how the plasticity parameters eta and lambda of each MP layer evolved during training.

    Only acts when ``net_params['net_type']`` is "dmpn"; for other net types it does nothing.
      * Figure 1: eta and lambda against ``net.hist['iters_monitor']`` for every layer
        in ``net.mp_layers``.
      * Figure 2 (only with more than one MP layer): per-layer histograms of the
        first vs. last stored eta values.

    History keys are ``'eta{k}'`` / ``'lam{k}'``, where k is the layer's position in
    ``net.mp_layers``, shifted by one when ``net_params["input_layer_add"]`` is set.

    Args:
        net: network exposing ``mp_layers`` (each with ``eta_type`` / ``lam_type``) and ``hist``.
        net_params: dict with keys "input_layer_add" and "net_type".
        hyp_dict: dict with "ruleset", "chosen_network" and "addon_name"; required when
            ``verbose`` is True because it names the saved figures.
        verbose: if True, save the figures under ./results/.

    Raises:
        ValueError: ``verbose`` is True for a dmpn net but ``hyp_dict`` is None.
    """
    layer_index = 0  # 1 layer MPN
    if net_params["input_layer_add"]:
        layer_index += 1  # 2 layer MPN: history keys are shifted by one

    # only make sense for dmpn for eta and lambda information extraction
    if net_params['net_type'] not in ("dmpn",):
        return
    if verbose and hyp_dict is None:
        raise ValueError("hyp_dict is required when verbose=True (it names the saved figures).")

    vector_types = ('pre_vector', 'post_vector', 'matrix',)
    mp_layers = net.mp_layers

    def _stacked_history(key, param_type):
        # Vector/matrix parameters: one flattened row per stored entry, giving an
        # (n_entries, n_elements) array; scalar parameters are plotted as stored.
        history = net.hist[key]
        if param_type in vector_types:
            return np.concatenate([v.flatten()[np.newaxis, :] for v in history], axis=0)
        return history

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    for mpl_idx, mp_layer in enumerate(mp_layers):
        hist_idx = mpl_idx + layer_index
        full_eta = _stacked_history('eta{}'.format(hist_idx), mp_layer.eta_type)
        full_lam = _stacked_history('lam{}'.format(hist_idx), mp_layer.lam_type)
        color = c_vals[mpl_idx % len(c_vals)]  # cycle the palette when there are many layers
        ax1.plot(net.hist['iters_monitor'], full_eta, color=color, label='MPL{}'.format(hist_idx))
        ax2.plot(net.hist['iters_monitor'], full_lam, color=color, label='MPL{}'.format(hist_idx))
    ax1.axhline(0.0, color='k', linestyle='dashed')
    ax1.set_ylabel('Eta')
    ax2.set_ylabel('Lambda')
    # Show legends only when the last layer's eta is a scalar (one curve per layer);
    # presumably this avoids one legend entry per element for vector/matrix etas.
    if len(mp_layers) > 0 and mp_layers[-1].eta_type not in vector_types:
        ax1.legend()
        ax2.legend()
    if verbose:
        fig.savefig(f"./results/eta_lambda_{hyp_dict['ruleset']}_{hyp_dict['chosen_network']}_{hyp_dict['addon_name']}.png")

    # only for deep mpn with multiple layers: first vs. last eta distributions
    n_mplayers = len(mp_layers)
    if n_mplayers > 1:
        fig, axs = plt.subplots(1, n_mplayers, figsize=(4+4*n_mplayers, 4))
        n_bins = 50
        for mpl_idx, ax in enumerate(axs):
            eta_history = net.hist['eta{}'.format(mpl_idx + layer_index)]
            init_eta = eta_history[0].flatten()
            final_eta = eta_history[-1].flatten()
            max_eta = np.max((np.max(np.abs(init_eta)), np.max(np.abs(final_eta))))
            if max_eta == 0:
                # all-zero etas would give identical (non-increasing) bin edges,
                # which ax.hist rejects
                max_eta = 1.0
            bins = np.linspace(-max_eta, max_eta, n_bins+1)[1:]
            ax.hist(init_eta, bins=bins, color=c_vals_l[mpl_idx % len(c_vals_l)], alpha=0.3, label='init')
            ax.hist(final_eta, bins=bins, color=c_vals[mpl_idx % len(c_vals)], alpha=0.3, label='final')
            ax.set_ylabel('Count')
            ax.set_xlabel('Eta value')
            ax.legend()
        if verbose:
            fig.savefig(f"./results/eta_distribution_{hyp_dict['ruleset']}_{hyp_dict['chosen_network']}_{hyp_dict['addon_name']}.png")
def rand_weight_init(n_inputs, n_outputs=None, init_type='gaussian', cell_types=None,
                     ei_balance_val=None, sparsity=None, weight_norm=None, self_couplings=True):
    """
    Return a random weight initialization of the specified type as a numpy array.

    The shape is (n_inputs,) when ``n_outputs`` is None, otherwise (n_outputs, n_inputs).
    Numpy's global RNG is used throughout; convert to a tensor after calling.

    Args:
        n_inputs: number of inputs (columns).
        n_outputs: number of outputs (rows), or None for a 1d vector.
        init_type: 'xavier', 'gaussian', 'sparse_gaussian', 'sparse_gaussian_ln_in',
            'mirror_gaussian', 'log_normal', 'sparse_log_normal', 'ones', 'sparse_ones',
            'rand_one_hot', 'zeros', or a float (every entry set to that value).
        cell_types: optional torch tensor of shape (1, n_inputs). In the 2d case each
            weight becomes ``cell_type * |weight|``, so its sign follows its input's cell type.
        ei_balance_val: balances strength of excitation and inhibition; it multiplies
            the weights of inputs whose cell type is negative.
        sparsity: override sparsity; for the 'sparse_*' types, roughly the fraction
            of nonzero weights (defaults to 1.0).
        weight_norm: overrides the default scale (1/sqrt(n_inputs), or the Xavier bound).
        self_couplings: whether or not diagonal couplings can be nonzero. If False,
            sparse types require square weights, and dense square 2d weights get a zero
            diagonal; dense non-square weights are left unchanged.

    Raises:
        NotImplementedError: unknown ``init_type``, or 'ones'/'sparse_ones'/float with n_outputs > 1.
        ValueError: 'sparse_log_normal'/'sparse_gaussian_ln_in' without ``n_outputs``,
            or too few log-normal in-degrees drawn below 1.0.
    """
    sparsity_p = 1.0 if sparsity is None else sparsity
    if n_outputs is not None: # 2d case
        weight_shape = (n_outputs, n_inputs,)
    else: # 1d case
        weight_shape = (n_inputs,)

    if init_type == 'xavier':
        if n_outputs is not None: # 2d case
            xavier_bound = np.sqrt(6/(n_inputs + n_outputs)) if weight_norm is None else weight_norm
        else: # 1d case
            xavier_bound = np.sqrt(6/(n_inputs)) if weight_norm is None else weight_norm
        rand_weights = np.random.uniform(low=-xavier_bound, high=xavier_bound, size=weight_shape)
    elif init_type in ('gaussian', 'sparse_gaussian', 'sparse_gaussian_ln_in'):
        norm_factor = 1/np.sqrt(n_inputs) if weight_norm is None else weight_norm
        rand_weights = norm_factor * np.random.normal(scale=1.0, size=weight_shape)
    elif init_type in ('mirror_gaussian',):
        # Two gaussian peaks centered at opposite means (fixed mean/scale ratio for now)
        norm_factor = 1/np.sqrt(n_inputs) if weight_norm is None else weight_norm
        rand_weights = norm_factor * (
            np.random.choice([-1, 1], size=weight_shape) * np.random.normal(loc=1.0, scale=0.5, size=weight_shape)
        )
    elif init_type in ('log_normal', 'sparse_log_normal'):
        norm_factor = 1/np.sqrt(n_inputs) if weight_norm is None else weight_norm
        rand_weights = norm_factor * np.random.lognormal(mean=0.0, sigma=1.0, size=weight_shape)
    elif isinstance(init_type, float) or init_type in ('ones', 'sparse_ones'):
        if n_outputs is not None and n_outputs > 1:
            raise NotImplementedError('Normalization is weird for this, should think about it if we use this again')
        norm_factor = 1/np.sqrt(n_inputs) if weight_norm is None else weight_norm
        if isinstance(init_type, float):
            rand_weights = init_type * np.ones(weight_shape)
        else:
            rand_weights = norm_factor * np.ones(weight_shape)
    elif init_type in ('rand_one_hot',):
        assert len(weight_shape) == 2
        shuffle_idxs = np.random.permutation(max(weight_shape))
        # Shuffle an identity matrix that is the larger of the two weight dimensions
        square_weights = np.eye(max(weight_shape))[shuffle_idxs, :]
        # Clip to appropriate size
        rand_weights = square_weights[:weight_shape[0], :weight_shape[1]]
    elif init_type in ('zeros',):
        rand_weights = np.zeros(weight_shape)
    else:
        raise NotImplementedError('Random weight init type {} not recognized!'.format(init_type))

    # Creates sparsification masks (masks that determine which elements are zero)
    if init_type in ('sparse_gaussian', 'sparse_log_normal', 'sparse_ones', 'sparse_gaussian_ln_in'):
        if init_type in ('sparse_gaussian', 'sparse_ones',):
            # Each weight is kept independently with probability sparsity_p, so the
            # in-degree distribution is binomial (approximately Gaussian)
            weight_mask = np.random.uniform(0, 1, size=rand_weights.shape) > sparsity_p # Which weights to zero
        else: # 'sparse_log_normal', 'sparse_gaussian_ln_in'
            if n_outputs is None:
                raise ValueError("init_type '{}' requires n_outputs (2d weights).".format(init_type))
            LOG_NORM_SHAPE = 0.25 # This could be adjusted to experimental data eventually
            # Adjusts scale relative to standard log normal to set desired mean
            standard_ln = lognorm(s=LOG_NORM_SHAPE)
            scaled_ln = lognorm(s=LOG_NORM_SHAPE, scale=sparsity_p / standard_ln.mean())
            # Draws twice as many as needed since will be truncating all > 1.0
            in_degrees = scaled_ln.rvs(size=2*n_outputs)
            in_degrees = in_degrees[in_degrees <= 1.0] # Truncates to <= 1.0 in-degrees
            if in_degrees.shape[0] < n_outputs:
                raise ValueError('Did not draw enough in-degrees that met threshold (only {}).'.format(
                    in_degrees.shape[0]
                ))
            weight_mask = np.zeros(weight_shape, dtype=bool)
            for out_idx in range(n_outputs): # Generates mask one output at a time
                in_degree_mask = np.zeros((n_inputs,), dtype=bool) # True for elements set to zero
                # Set all elements beyond index to be True, corresponding to no connection
                # (small in_degree[out_idx] corresponds to majority 1s)
                in_degree_mask[int(np.floor(in_degrees[out_idx] * n_inputs)):] = True
                np.random.shuffle(in_degree_mask)
                weight_mask[out_idx] = in_degree_mask

        if not self_couplings: # Sets all diagonal elements of mask so weights are set to zero
            assert n_inputs == n_outputs
            for neuron_idx in range(n_inputs):
                weight_mask[neuron_idx, neuron_idx] = True

        rand_weights[weight_mask] = 0.0
    elif not self_couplings and n_outputs == n_inputs:
        # Dense square weights: honour self_couplings=False by zeroing the diagonal
        np.fill_diagonal(rand_weights, 0.0)

    if (cell_types is not None) and (n_outputs is not None): # Assigns cell types
        cell_types = cell_types.detach().cpu().numpy()
        assert cell_types.shape[0] == 1
        assert cell_types.shape[1] == n_inputs
        if ei_balance_val is not None:
            # Signed based on cell type, but also enhances strength appropriately
            cell_weights = np.where(cell_types < 0, ei_balance_val * cell_types, cell_types)
        else:
            cell_weights = cell_types # Just 1s or -1s to correct sign
        rand_weights = cell_weights * np.abs(rand_weights)

    return rand_weights
def append_to_average(avg_raw, raw, n_window=1):
    """ Rolling average.

    Appends the mean of the last ``n_window`` entries of ``raw`` (or of all of
    ``raw`` when it is shorter than the window) to ``avg_raw`` in place, and
    returns ``avg_raw``.
    """
    window = raw if len(raw) < n_window else raw[-n_window:]
    avg_raw.append(np.mean(window))
    return avg_raw
def accumulate_decay(decay_raw, raw, n_window=10, base_val=0.0):
    """ Average by accumulation and decay (an exponential moving average).

    Appends ``gamma * previous + (1/n_window) * raw[-1]`` to ``decay_raw`` in place,
    with ``gamma = 1 - 1/n_window``. Here ``previous`` is the last entry of
    ``decay_raw``, or ``base_val`` when ``decay_raw`` is empty. A NaN newest sample
    is skipped: ``previous`` is appended unchanged.

    Returns:
        decay_raw, now one element longer.
    """
    gamma = 1. - 1./n_window
    previous = base_val if len(decay_raw) == 0 else decay_raw[-1]
    newest = raw[-1]
    if np.isnan(newest):
        updated = previous
    else:
        updated = gamma*previous + 1/n_window * newest
    decay_raw.append(updated)
    return decay_raw
def relu_fn(x):
    """ReLU for torch tensors: elementwise max(x, 0)."""
    zero = torch.zeros_like(x)
    return torch.maximum(x, zero)


def relu_fn_np(x):
    """Numpy ReLU: elementwise max(x, 0)."""
    zero = np.zeros_like(x)
    return np.maximum(x, zero)


def relu_fn_p(x):
    """ReLU derivative as a step function; x = 0 maps to 0, just like the PyTorch default."""
    zero = torch.zeros_like(x)
    return torch.heaviside(x, zero)
def sigmoid(x):
    """Logistic sigmoid for torch tensors: 1 / (1 + exp(-x))."""
    denominator = 1 + torch.exp(-x)
    return 1 / denominator


def sigmoid_np(x):
    """Numpy logistic sigmoid: 1 / (1 + exp(-x))."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def sigmoid_p(x):
    """Derivative of the sigmoid: s * (1 - s) with s = sigmoid(x)."""
    s = sigmoid(x)
    complement = 1 - s
    return s * complement
def tanh_p(x):
    """Derivative of tanh for torch tensors: sech(x)**2."""
    sech = 1 / torch.cosh(x)
    return sech ** 2


def tanh_re(x):
    """Rectified tanh for torch tensors: max(tanh(x), 0)."""
    floor = torch.zeros_like(x)
    return torch.maximum(torch.tanh(x), floor)


def tanh_re_np(x):
    """Numpy rectified tanh: max(tanh(x), 0)."""
    floor = np.zeros_like(x)
    return np.maximum(np.tanh(x), floor)


def tanh_re_p(x):
    """
    Derivative of tanh_re: sech(x)**2 where x > 1e-6, otherwise 0.

    ``torch.where`` is used instead of ``maximum`` so that x = 0 maps to 0 like
    ReLU does, even though sech(0)**2 = 1.
    """
    sech_sq = (1 / torch.cosh(x)) ** 2
    positive = x > 1e-6
    return torch.where(positive, sech_sq, 0.)
def tanh_re_super(x, alpha=2.0):
    """Supra-linear rectified tanh for torch tensors: max(tanh(alpha*x), 0)."""
    return torch.maximum(torch.tanh(alpha*x), torch.zeros_like(x))


def tanh_re_super_np(x, alpha=2.0):
    """Numpy version of tanh_re_super: max(tanh(alpha*x), 0)."""
    return np.maximum(np.tanh(alpha*x), np.zeros_like(x))


def tanh_re_super_p(x, alpha=2.0):
    """
    Derivative of tanh_re_super: alpha * sech(alpha*x)**2 for x > 0, otherwise 0.

    ``alpha`` must match the value used in ``tanh_re_super`` (same default).
    ``torch.where`` is used instead of ``maximum`` so that x = 0 maps to 0 like
    ReLU does, even though sech(0)**2 = 1.
    """
    sech_sq = (1 / torch.cosh(alpha * x)) ** 2
    return torch.where(x > 0., alpha * sech_sq, 0.)
def linear_fn(x):
    """Identity activation; used for both the torch and the numpy variant."""
    return x


def linear_fn_p(x):
    """Derivative of the identity: ones with the shape of ``x``."""
    return torch.ones_like(x)
def cubed_re(x):
    """Rectified cube for torch tensors: max(x**3, 0), i.e. x**3 for x > 0, else 0."""
    return torch.maximum(x**3, torch.zeros_like(x))


def cubed_re_p(x):
    """Derivative of cubed_re: 3*x**2 for x > 0, 0 for x <= 0."""
    # max(3*x**2, 0) would equal 3*x**2 everywhere (it is never negative), but the
    # rectified cube is flat for x <= 0, so its slope there must be 0.
    return torch.where(x > 0, 3 * x**2, torch.zeros_like(x))
def tukey_fn(x):
    """Tukey (raised-cosine) ramp for torch tensors: 0 for x < 0, (1 - cos(pi*x))/2 on [0, 1], 1 for x > 1."""
    raw = 1/2 - 1/2 * torch.cos(np.pi * x)
    raw[x < 0.] = 0.0
    raw[x > 1.] = 1.0
    return raw


def tukey_fn_np(x):
    """Numpy version of tukey_fn (expects an array, since it assigns through boolean masks)."""
    raw = 1/2 - 1/2 * np.cos(np.pi * x)
    raw[x < 0.] = 0.0
    raw[x > 1.] = 1.0
    return raw


def tukey_fn_p(x):
    """Derivative of tukey_fn: (pi/2) * sin(pi*x) on [0, 1], 0 outside."""
    raw = np.pi / 2 * torch.sin(np.pi * x)
    raw[x < 0.] = 0.0
    raw[x > 1.] = 0.0
    return raw


def heaviside_p(x):
    """Derivative of the Heaviside step (a delta at 0, zero elsewhere); not implemented."""
    raise NotImplementedError()
# 2025-11-17: add softplus function to match with LD's paper
def softplus_fn(x):
return torch.nn.functional.softplus(x)
def softplus_fn_np(x):
# stable: log(1 + exp(x))
return np.log1p(np.exp(x))
def softplus_fn_p(x):
    """Derivative of softplus, which is the logistic sigmoid of ``x``."""
    slope = torch.sigmoid(x)
    return slope
def get_activation_function(act_fn):
    """
    Look up an activation function by name.

    Args:
        act_fn: one of 'ReLU', 'sigmoid', 'tanh_re', 'tanh_re_super', 'tanh',
            'tukey', 'linear', 'cubed_re', 'heaviside', 'softplus'.

    Returns:
        Tuple (torch_fn, numpy_fn, torch_derivative_fn). ``numpy_fn`` is None
        for 'cubed_re'. Some derivative entries may be unimplemented and raise
        when called.

    Raises:
        ValueError: if ``act_fn`` is not a recognized name.
    """
    def heaviside_torch(x):
        # torch.heaviside requires `values` to have the same dtype as `x`.
        # Build the x == 0 value (0.5) with x's dtype and device instead of
        # a default float32 CPU tensor.
        return torch.heaviside(x, torch.tensor(0.5, dtype=x.dtype, device=x.device))

    def heaviside_np(x):
        return np.heaviside(x, 0.5)

    activations = {
        'ReLU': (relu_fn, relu_fn_np, relu_fn_p),
        'sigmoid': (sigmoid, sigmoid_np, sigmoid_p),
        'tanh_re': (tanh_re, tanh_re_np, tanh_re_p),
        'tanh_re_super': (tanh_re_super, tanh_re_super_np, tanh_re_super_p),  # supra-linear tanh
        'tanh': (torch.tanh, np.tanh, tanh_p),
        'tukey': (tukey_fn, tukey_fn_np, tukey_fn_p),
        'linear': (linear_fn, linear_fn, linear_fn_p),
        'cubed_re': (cubed_re, None, cubed_re_p),
        'heaviside': (heaviside_torch, heaviside_np, heaviside_p),
        'softplus': (softplus_fn, softplus_fn_np, softplus_fn_p),
    }
    try:
        return activations[act_fn]
    except (KeyError, TypeError):  # TypeError: unhashable names are also unrecognized
        raise ValueError('Activation function: {} not recoginized!'.format(act_fn)) from None
def cosine_similarity_loss(output, pred):
    """
    Mean absolute cosine similarity between ``output`` and ``pred``.

    Both inputs are expected to be (B, Ny). Similarity is taken along the last
    dimension, and |cos| is averaged over the batch.
    NOTE(review): the value grows as the vectors become more aligned (or
    anti-aligned), so minimizing it pushes them toward orthogonality. Confirm
    this is the intended training objective.
    """
    per_sample = nn.CosineSimilarity(dim=-1)(output, pred)
    return per_sample.abs().mean()
def mse_loss_weighted(output, pred, weight):
    """
    Weighted mean squared error: mean(weight * (output - pred)**2).

    ``weight`` is multiplied elementwise (with broadcasting) before averaging
    over all elements.
    """
    squared_err = (output - pred) ** 2
    return torch.mean(weight * squared_err)
def shuffle_dataset(inputs, labels, masks=None):
    """
    Shuffle a dataset along its first (batch) axis with one shared permutation.

    ``inputs``, ``labels`` and, if given, ``masks`` are reordered identically,
    so sample i stays aligned across all of them. Any rank >= 1 is supported;
    the objects only need to support integer-array indexing on axis 0.
    The permutation comes from NumPy's global RNG (np.random.shuffle).

    Returns:
        (inputs, labels, masks) after shuffling. ``masks`` is None when not given.
    """
    assert inputs.shape[0] == labels.shape[0]  # batch sizes must agree
    perm = np.arange(inputs.shape[0])
    np.random.shuffle(perm)
    # Index only the batch axis so arrays of any rank work (not just 3-D).
    inputs = inputs[perm]
    labels = labels[perm]
    if masks is not None:
        assert inputs.shape[0] == masks.shape[0]
        masks = masks[perm]
    return inputs, labels, masks
def round_to_values(array, round_vals):
    """
    Snap every element of ``array`` to its nearest entry in ``round_vals`` (NumPy).

    ``round_vals`` must be 1-D. Ties go to the entry that appears first in
    ``round_vals``, following np.argmin. The result has ``array``'s shape.
    """
    assert len(round_vals.shape) == 1
    # One leading singleton axis per dimension of `array`, so the candidate
    # values line up against a trailing axis added to `array`.
    lead_axes = tuple(range(len(array.shape)))
    samples = np.expand_dims(array, axis=-1)               # (*array.shape, 1)
    candidates = np.expand_dims(round_vals, axis=lead_axes)  # (1, ..., 1, K)
    nearest_idx = np.argmin(np.abs(samples - candidates), axis=-1)
    return round_vals[nearest_idx]
def round_to_values_torch(array, round_vals):
    """
    Snap every element of ``array`` to its nearest entry in ``round_vals`` (PyTorch).

    ``round_vals`` must be 1-D. Ties go to the entry that appears first, via
    torch.argmin. The result has ``array``'s shape.
    """
    assert len(round_vals.shape) == 1
    # Broadcasting (..., 1) against (K,) gives (..., K): distance to every candidate.
    gaps = (array.unsqueeze(-1) - round_vals).abs()
    return round_vals[torch.argmin(gaps, dim=-1)]
def one_hot_argidx(x: torch.Tensor, check: bool = False) -> torch.Tensor:
    """
    Convert a 3-D one-hot tensor into the index of its hot entry along the last axis.

    Args:
        x: 3-D tensor that is one-hot along dim -1 (e.g. (B, T, D) -> (B, T)).
        check: if True, raise ValueError unless every slice along the last
            dimension has exactly one non-zero entry and that entry equals 1.
            Off by default, so the unchecked behavior is preserved.

    Returns:
        Tensor of shape x.shape[:-1] holding the argmax along dim -1. Without
        ``check``, a slice that is not one-hot still yields its argmax (the
        first maximal index).

    Raises:
        ValueError: if x is not 3-D, or (with check=True) x is not one-hot.
    """
    if x.dim() != 3:
        raise ValueError(f'Input must be 3-D, got {x.shape}')
    if check:
        ones_per_slice = (x == 1).sum(dim=-1)
        nonzero_per_slice = (x != 0).sum(dim=-1)
        if not bool(((ones_per_slice == 1) & (nonzero_per_slice == 1)).all()):
            raise ValueError('Input is not one-hot along the last dimension')
    # For a valid one-hot slice, argmax is the position of the single 1.
    return torch.argmax(x, dim=-1)
def unique_nonzero_value(x: torch.Tensor):
    """
    Return the single non-zero value of a tensor whose entries are {0, v}.

    An (N, 1) input is squeezed to (N,) first; torch.unique is then taken over
    all entries. If the tensor holds exactly two distinct values and one of
    them is 0, the other is returned as a Python number via ``.item()``.
    In every other case, 0 is returned.
    """
    if x.ndim == 2 and x.shape[1] == 1:
        x = x.squeeze(1)
    distinct = torch.unique(x)
    has_zero = bool((distinct == 0).any())
    if distinct.numel() == 2 and has_zero:
        return distinct[distinct != 0].item()
    return 0
def mean_by_group(A: torch.Tensor, B: torch.Tensor) -> Dict[int, float]:
    """
    Compute the mean of A for each integer label in B.

    Parameters
    ----------
    A : (N,) or (N,1) tensor
        Numeric values whose means you want. Integer values are averaged in
        float64, so the means are not truncated.
    B : (N,) or (N,1) tensor of ints
        Integer labels; same number of elements as A. Values are cast to
        int64 and must be non-negative.

    Returns
    -------
    dict {label: mean_A_for_that_label}, containing only labels present in B.
    An empty input returns {}.

    Raises
    ------
    ValueError
        If A and B differ in element count, or B contains negative labels.
    """
    if A.numel() != B.numel():
        raise ValueError(f"A and B must have the same number of elements "
                         f"(got {A.numel()} vs {B.numel()}).")
    # reshape(-1) instead of squeeze(): squeeze turns N=1 input into a 0-d
    # tensor, which torch.bincount rejects.
    A = A.reshape(-1)
    B = B.reshape(-1).to(torch.int64)
    if B.numel() == 0:
        return {}
    if B.min() < 0:
        raise ValueError("B must contain non-negative integers for bincount-based method.")
    if not A.is_floating_point():
        A = A.to(torch.float64)

    # Vectorized per-label sums and counts. Labels need not be contiguous.
    max_label = int(B.max())
    sums = torch.bincount(B, weights=A, minlength=max_label + 1)
    counts = torch.bincount(B, minlength=max_label + 1)

    # Only labels that actually occur; this also avoids dividing by zero.
    present = torch.nonzero(counts, as_tuple=True)[0]
    means = sums[present] / counts[present]
    return {int(label): float(mean) for label, mean in zip(present.tolist(), means.tolist())}
##############################################################################################
##############################################################################################
##############################################################################################
class BaseNetworkFunctions(nn.Module):
    """
    Helper routines shared by full networks and their subnetwork layers.

    Kept separate from BaseNetwork so that a subnetwork layer can reuse these
    helpers without constructing an entire network.
    """

    def __init__(self):
        super().__init__()

    def parameter_or_buffer(self, name, tensor=None, verbose=False):
        """
        Register attribute ``name`` as an nn.Parameter or as a buffer.

        The attribute becomes a parameter when ``name`` is in ``self.params``
        and a buffer otherwise. ``self.params`` is not set in this class; it is
        expected to be provided by the subclass.

        If an attribute called ``name`` already exists:
          * and it is already registered as the right kind, nothing changes
            and ``tensor`` is ignored;
          * otherwise its current value is detached, cloned, deleted, and
            re-registered as the right kind. ``tensor`` is ignored in favor of
            the existing value, so it need not be passed when converting.
        If no such attribute exists, ``tensor`` is registered directly.
        """
        if hasattr(self, name):
            is_buffer = any(buf_name == name for buf_name, _ in self.named_buffers())
            is_param = any(par_name == name for par_name, _ in self.named_parameters())
            assert not (is_buffer and is_param)

            already_correct = (is_buffer and (name not in self.params)) or (is_param and (name in self.params))
            if already_correct:
                return
            if verbose:
                print('Updating atribute: ', name)
            tensor = getattr(self, name).detach().clone()
            delattr(self, name)

        if name in self.params:
            self.register_parameter(name, nn.Parameter(tensor))
        else:
            self.register_buffer(name, tensor)
class BaseNetwork(BaseNetworkFunctions):
def __init__(self, net_params, verbose=False):
super().__init__()
self.loss_type = net_params.get('loss_type', 'XE')
self.acc_measure = net_params.get('acc_measure', 'angle')
if self.loss_type == 'XE':
self.loss_fn = F.cross_entropy
elif self.loss_type in ('MSE', 'MSE_grads',):
self.loss_fn = F.mse_loss
elif self.loss_type == 'MSE_weighted':
self.loss_fn = mse_loss_weighted
elif self.loss_type == 'CS': # cosine similarity
self.loss_fn = cosine_similarity_loss
elif self.loss_type in ('rl_direct_output', 'rl_output_deviations',): # RL loss functions implemented manually for now