forked from acornprover/acorn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path training.py
More file actions
29 lines (25 loc) · 777 Bytes
/
training.py
File metadata and controls
29 lines (25 loc) · 777 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def one_pass(dataloader, model, criterion, optimizer=None):
    """
    Do a pass through all the data.
    Return the average loss.
    If optimizer is provided, we do a round of training.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batch pairs.
        model: Module to run. It is put in train mode when ``optimizer`` is
            given and in eval mode otherwise.
        criterion: Loss function, called as
            ``criterion(outputs, labels.unsqueeze(1))``.
        optimizer: Optional optimizer. When given, each batch is followed by
            ``zero_grad()``, a backward pass, and ``step()``. When omitted,
            gradients are disabled for the whole pass.

    Returns:
        The per-sample average loss over the pass, as a Python float.
        Returns ``float("nan")`` if the dataloader yields no samples.
    """
    # Imported locally so this function does not rely on a module-level import.
    import torch

    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_samples = 0
    # Outside of training, do not build an autograd graph. The gradients would
    # never be used, and recording them costs memory and time on every batch.
    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            # Forward pass.
            # NOTE(review): unsqueeze(1) presumably turns (batch,) labels into
            # (batch, 1) to match the model output — confirm against callers.
            outputs = model(inputs)
            loss = criterion(outputs, labels.unsqueeze(1))
            if training:
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Track the loss, weighted by batch size so that a short final
            # batch does not skew the average. This assumes the criterion
            # averages over the batch (reduction="mean") — TODO confirm.
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        # Nothing was seen. Like the mean of an empty tensor, the result is NaN
        # rather than a ZeroDivisionError.
        return float("nan")
    return total_loss / total_samples