-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
44 lines (31 loc) · 1.09 KB
/
main.py
File metadata and controls
44 lines (31 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import logging
import os

# import numpy
import torch

import data_utils
import dataset
from dataset import FrEnDataset
# Logging config: records at INFO and above are written to ./logs/test.log.
LOG_DIR = './logs'
LOG_FILE = os.path.join(LOG_DIR, 'test.log')
# basicConfig's FileHandler opens the file immediately, so the directory must
# exist first; otherwise importing this module raises FileNotFoundError
# (e.g. on a fresh clone, since git does not track an empty logs/ directory).
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(filename=LOG_FILE, level=logging.INFO)
# Developed against torch 1.6.0 + CUDA 10.1; record the runtime version.
logging.info('Torch Version %s', torch.__version__)
# Default directory of the fr-en parallel corpus, relative to the working dir.
DATA_PATH = './data/fr-en'
def main(argv=None):
    """Load the French-English corpus and build a DataLoader over the English side.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Load the fr-en parallel corpus.')
    parser.add_argument(
        '--data-path', default=DATA_PATH,
        help='directory containing fr-en.en.txt and fr-en.fr.txt '
             '(default: %(default)s)')
    args = parser.parse_args(argv)
    data_path = args.data_path
    logging.info('Data path: %s', data_path)

    # Load Dataset
    data_utils.load_data_from_file(data_path, data_path)
    # Same directory with a trailing separator, matching the previously
    # hard-coded './data/fr-en/' (kept in case FrEnDataset concatenates paths
    # rather than joining them -- not verified here).
    root_dir = os.path.join(data_path, '')
    en_dataset = FrEnDataset(txt_files=os.path.join(data_path, 'fr-en.en.txt'),
                             root_dir=root_dir)
    fr_dataset = FrEnDataset(txt_files=os.path.join(data_path, 'fr-en.fr.txt'),
                             root_dir=root_dir)
    en_loader = dataset.DataLoader(en_dataset, batch_size=4,
                                   shuffle=True, num_workers=1)

    # fr_dataset and en_loader are not consumed yet; training and evaluation
    # are still to be written.
    # TODO: Train
    # TODO: Test


if __name__ == '__main__':
    main()