-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdiff_cpu.py
More file actions
258 lines (213 loc) · 9.24 KB
/
Copy pathdiff_cpu.py
File metadata and controls
258 lines (213 loc) · 9.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""
diff_cpu.py -- A computer that learns its own logic gates.
Part 1 -- Soft gates
Replace nand(a,b) with soft_nand(a,b) = 1 - a*b.
Exact on {0,1}. Differentiable everywhere. Every gate, adder, and
ALU operation becomes a polynomial over [0,1] that you can
backpropagate through.
Part 2 -- Learnable gates
Each gate has a trainable 4-entry truth table (via bilinear
interpolation over learned logits). A 1-bit full adder built from
5 of these gates starts with random parameters, sees all 8
(a, b, cin) -> (sum, carry) examples, and discovers XOR, AND, OR
purely through gradient descent. No logic is prescribed anywhere.
Run: python diff_cpu.py
"""
import random
import math
# ── Soft gates (product approximation) ───────────────────────────────────────
# Each is exact on {0,1} inputs and smoothly differentiable in between.
def soft_not(a): return 1.0 - a
def soft_and(a, b): return a * b
def soft_or(a, b): return a + b - a * b
def soft_xor(a, b): return a + b - 2.0 * a * b
def soft_nand(a, b): return 1.0 - a * b
def _soft_ha(a, b):
return soft_xor(a, b), soft_and(a, b) # sum, carry
def _soft_fa(a, b, cin):
s1, c1 = _soft_ha(a, b)
s2, c2 = _soft_ha(s1, cin)
return s2, soft_or(c1, c2) # sum, carry_out
def soft_add8(a_val, b_val):
"""8-bit ripple-carry adder using soft gates. Identical to alu.add8 on ints."""
a_bits = [(a_val >> i) & 1 for i in range(8)]
b_bits = [(b_val >> i) & 1 for i in range(8)]
result_bits, carry = [], 0.0
for i in range(8):
s, carry = _soft_fa(float(a_bits[i]), float(b_bits[i]), carry)
result_bits.append(s)
val = sum(round(result_bits[i]) << i for i in range(8))
return val, round(carry)
# ── Learnable gate ────────────────────────────────────────────────────────────
def _sig(x):
return 1.0 / (1.0 + math.exp(-max(-30.0, min(30.0, x))))
_GATE_NAMES = {
(0, 0, 0, 0): 'FALSE', (1, 1, 1, 1): 'TRUE',
(1, 1, 1, 0): 'NAND', (0, 0, 0, 1): 'AND',
(0, 1, 1, 0): 'XOR', (1, 0, 0, 1): 'XNOR',
(0, 1, 1, 1): 'OR', (1, 0, 0, 0): 'NOR',
(1, 1, 0, 0): 'NOT_A', (0, 0, 1, 1): 'A', # f(0,0),f(0,1),f(1,0),f(1,1)
(1, 0, 1, 0): 'NOT_B', (0, 1, 0, 1): 'B', # NOT_A=NOT(a), NOT_B=NOT(b)
}
class Gate:
"""
A gate with a learnable truth table.
logits[i] are the pre-sigmoid weights for truth table entries indexed as:
i=0: f(a=0, b=0) i=1: f(a=0, b=1)
i=2: f(a=1, b=0) i=3: f(a=1, b=1)
Forward: bilinear interpolation over the truth table.
out = Σ w_i * σ(logit_i)
where w = [(1-a)(1-b), (1-a)b, a(1-b), ab]
Backward: analytic gradients via chain rule.
"""
def __init__(self, label):
self.label = label
self.logits = [random.gauss(0, 2.0) for _ in range(4)]
def _sv(self):
return [_sig(l) for l in self.logits]
def fwd(self, a, b):
s = self._sv()
w = [(1-a)*(1-b), (1-a)*b, a*(1-b), a*b]
out = sum(w[i] * s[i] for i in range(4))
return out, (a, b, s, w)
def bwd(self, d_out, cache, lr):
"""Update logits; return (d_a, d_b)."""
a, b, s, w = cache
d_out = max(-5.0, min(5.0, d_out)) # gradient clip
for i in range(4):
self.logits[i] -= lr * d_out * w[i] * s[i] * (1.0 - s[i])
dw_da = [-(1-b), -b, (1-b), b ]
dw_db = [-(1-a), (1-a), -a, a ]
d_a = d_out * sum(dw_da[i] * s[i] for i in range(4))
d_b = d_out * sum(dw_db[i] * s[i] for i in range(4))
return d_a, d_b
def identify(self):
tt = tuple(round(_sig(l)) for l in self.logits)
return _GATE_NAMES.get(tt, str(list(tt)))
def show(self):
s = self._sv()
print(f" {self.label:22s} "
f"({s[0]:.2f}, {s[1]:.2f}, {s[2]:.2f}, {s[3]:.2f})"
f" -> {self.identify()}")
# ── Learnable 1-bit full adder ────────────────────────────────────────────────
class LearnableFullAdder:
"""
A 1-bit full adder from 5 learnable gates, mirroring the ripple-carry
structure of alu.py but with no prescribed logic.
Target function (not told to the gates -- they discover it):
g1(a, b) -> s1 [XOR]
g2(a, b) -> c1 [AND]
g3(s1, cin) -> sum [XOR]
g4(s1, cin) -> c2 [AND]
g5(c1, c2) -> cout [OR]
"""
def __init__(self):
self.g1 = Gate('g1(a, b) -> s1')
self.g2 = Gate('g2(a, b) -> c1')
self.g3 = Gate('g3(s1, cin) -> sum')
self.g4 = Gate('g4(s1, cin) -> c2')
self.g5 = Gate('g5(c1, c2) -> cout')
def fwd(self, a, b, cin):
s1, k1 = self.g1.fwd(a, b)
c1, k2 = self.g2.fwd(a, b)
s2, k3 = self.g3.fwd(s1, cin)
c2, k4 = self.g4.fwd(s1, cin)
co, k5 = self.g5.fwd(c1, c2)
return s2, co, (s1, c1, k1, k2, k3, k4, k5)
def bwd(self, d_s2, d_co, cache, lr):
s1, c1, k1, k2, k3, k4, k5 = cache
d_c1, d_c2 = self.g5.bwd(d_co, k5, lr)
d_s1_4, _ = self.g4.bwd(d_c2, k4, lr)
d_s1_3, _ = self.g3.bwd(d_s2, k3, lr)
d_s1 = d_s1_3 + d_s1_4
self.g2.bwd(d_c1, k2, lr)
self.g1.bwd(d_s1, k1, lr)
def step(self, a, b, cin, ts, tc, lr):
s2, co, cache = self.fwd(a, b, cin)
loss = (s2 - ts) ** 2 + (co - tc) ** 2
self.bwd(2.0 * (s2 - ts), 2.0 * (co - tc), cache, lr)
return loss
def accuracy(self):
ok = 0
for a in range(2):
for b in range(2):
for cin in range(2):
ts = (a + b + cin) & 1
tc = (a + b + cin) >> 1
s2, co, _ = self.fwd(float(a), float(b), float(cin))
ok += (round(s2) == ts and round(co) == tc)
return ok / 8
# ── Demos ─────────────────────────────────────────────────────────────────────
def demo_soft_gates():
print("=" * 64)
print(" Part 1: Soft gates -- exact on {0,1}, differentiable everywhere")
print()
print(" soft_nand(a,b) = 1 - a*b")
print(" soft_xor(a,b) = a + b - 2*a*b")
print(" soft_or(a,b) = a + b - a*b")
print()
print(" Gradients flow through the entire ALU. You can backpropagate")
print(" through any program's execution -- find what inputs produce")
print(" a given output, or how sensitive the output is to each bit.")
print("=" * 64)
print()
print(" soft_add8 vs. alu.add8:")
from alu import add8
cases = [(3, 5), (127, 1), (200, 56), (255, 1), (0, 0), (99, 99)]
for a, b in cases:
soft_r, _ = soft_add8(a, b)
hard_r, _ = add8(a, b)
match = 'ok' if soft_r == hard_r else 'FAIL'
print(f" {a:3d} + {b:3d} = {soft_r:3d} (hard: {hard_r:3d}) {match}")
print()
print(" Identical results. The entire computation is now a polynomial")
print(" you can differentiate.")
def demo_learnable_adder():
print()
print("=" * 64)
print(" Part 2: Learnable full adder")
print()
print(" 5 gates with random truth tables. Trained only on the 8")
print(" (a, b, cin) -> (sum, carry) examples. No logic prescribed.")
print(" Gradient descent must discover XOR, AND, and OR from scratch.")
print("=" * 64)
random.seed(42)
fa = LearnableFullAdder()
examples = [
(a, b, cin, (a+b+cin) & 1, (a+b+cin) >> 1)
for a in range(2) for b in range(2) for cin in range(2)
]
print(f"\n Truth table columns: f(0,0) f(0,1) f(1,0) f(1,1)")
print(f"\n Epoch 0 acc={fa.accuracy():.0%} (random init)")
for g in [fa.g1, fa.g2, fa.g3, fa.g4, fa.g5]:
g.show()
LR = 0.4
prev_acc = 0.0
for epoch in range(1, 5001):
random.shuffle(examples)
for a, b, cin, ts, tc in examples:
fa.step(float(a), float(b), float(cin), float(ts), float(tc), LR)
acc = fa.accuracy()
if epoch % 500 == 0 or (acc == 1.0 and prev_acc < 1.0):
print(f"\n Epoch {epoch:4d} acc={acc:.0%}")
for g in [fa.g1, fa.g2, fa.g3, fa.g4, fa.g5]:
g.show()
if acc == 1.0:
break
prev_acc = acc
print()
found = [g.identify() for g in [fa.g1, fa.g2, fa.g3, fa.g4, fa.g5]]
print(f" Found: {' '.join(f'{n:<6}' for n in found)}")
print()
if fa.accuracy() == 1.0:
print(" Accuracy 100%. The gates implement correct binary addition.")
print()
print(" Note: gradient descent is not required to find the canonical")
print(" XOR/AND/OR circuit. Any valid assignment works -- here it found")
print(" an XNOR-based equivalent. The circuit is different but correct.")
print(" Multiple valid implementations exist; backprop found one of them.")
else:
print(f" Converged to {fa.accuracy():.0%} -- try a different random seed.")
if __name__ == '__main__':
demo_soft_gates()
demo_learnable_adder()