forked from SeanNaren/deepspeech.torch
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathLoader.lua
More file actions
353 lines (289 loc) · 11.4 KB
/
Loader.lua
File metadata and controls
353 lines (289 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
require 'nn'
require 'torch'
require 'lmdb'
require 'xlua'
require 'paths'
tds = require 'tds'
threads = require 'threads'
-- Model-specific hooks. They stay nil until a model module is required;
-- the code below reads entry 2 of that module's table as cal_size and
-- entry 3 as get_min_width.
local model_t, cal_size, get_min_width
-- project helper module; used below for util.get_mean_std
local util = require 'Util'
--[[
this file defines Loader and loader:
- Loader returns the indices of the next batch (in one of several orders)
- loader loads the data from lmdb given those indices
NOTE:
- make sure calculateInputSizes() in DeepSpeechModel.lua is set correctly
--]]
-- Make torch.Tensor(...) in this file produce FloatTensors.
torch.setdefaulttensortype('torch.FloatTensor')
-- Register the Loader class with torch's class system; its methods are
-- attached to this table below.
local Loader = torch.class('Loader')
function Loader:__init(_dir, batch_size, feature, dataHeight, modelname)
    --[[
        Set up the batch-index state and start one worker thread that owns the
        lmdb environments and the mean/std statistics.
        input:
            _dir: directory containing the 'spect', 'label' and 'trans' lmdb
                  databases; the 'mean_std' file is looked up one level above it
            batch_size: number of samples per batch
            feature: is it spect or logfbank we are using
            dataHeight: typically 129 for spect; 26 for logfbank
            modelname: module name given to require(); entries 2 and 3 of the
                       table it returns are used as cal_size and get_min_width
    --]]
    -- constants to indicate the loading style
    self.DEFAULT = 1
    self.SAMELEN = 2
    self.SORTED = 3
    self.RANDOM = 4

    -- prep_sorted_inds() needs the dataset directory, so keep it on the instance
    self._dir = _dir
    self.modelname = modelname
    self.batch_size = batch_size
    self.dataHeight = dataHeight
    self.is_spect = feature == 'spect'
    self.cnt = 1            -- position of the next sample for the sequential modes
    self.sorted_inds = {}   -- filled by prep_sorted_inds()
    self.len_num = 0        -- number of unique seqLengths
    -- self.min_width is resolved lazily by prep_sorted_inds()

    -- Load the dataset mean/std, generating the file first when it is missing.
    local function preprocess()
        -- assume the super folder is the lmdb root folder
        local lmdb_path = _dir..'/../'
        local stats = {} -- mean/std
        if paths.filep(lmdb_path..'mean_std') then
            stats = torch.load(lmdb_path..'mean_std')
        else
            print('did not find previously saved stats, generating..')
            if feature == 'spect' then
                util.get_mean_std(lmdb_path)
            else
                util.get_mean_std(lmdb_path, dataHeight)
            end
            stats = torch.load(lmdb_path..'mean_std')
        end
        return stats
    end

    -- Runs first in the worker thread: load the modules the jobs rely on.
    -- `tds` is deliberately a global there; the batch job reads it as one.
    local function init() require('Mapper') require('lmdb') tds = require 'tds' end

    -- Runs second in the worker thread: publish the model hooks, the lmdb
    -- environments and the statistics as worker globals, and return the
    -- number of entries in the dataset.
    local function main(idx)
        torch.manualSeed(idx)
        torch.setnumthreads(1)
        -- get model specified methods
        model_t = require(modelname)
        _G.cal_size = model_t[2]
        _G.get_min_width = model_t[3]
        _G.db_spect = lmdb.env { Path = _dir .. '/spect', Name = 'spect' }
        _G.db_label = lmdb.env { Path = _dir .. '/label', Name = 'label' }
        _G.db_trans = lmdb.env { Path = _dir .. '/trans', Name = 'trans' }
        -- get the size of lmdb
        _G.db_spect:open()
        _G.db_label:open()
        _G.db_trans:open()
        local l1 = _G.db_spect:stat()['entries']
        local l2 = _G.db_label:stat()['entries']
        local l3 = _G.db_trans:stat()['entries']
        assert(l1 == l2 and l2 == l3, 'data sizes in each lmdb must agree')
        _G.db_spect:close()
        _G.db_label:close()
        _G.db_trans:close()
        _G.stats = preprocess()
        return l1
    end

    local pool, lmdb_size = threads.Threads(1, init, main)
    self.pool = pool
    -- first return value of main() from the first (only) worker
    self.lmdb_size = lmdb_size[1][1]
end

function Loader:prep_sorted_inds()
    --[[
        prep a table for sorted inds, can detect previously saved table in lmdb folder.
        Each entry of self.sorted_inds is { lmdb index, audio length }, ordered by
        ascending audio length. Samples shorter than self.min_width, or whose
        cal_size(length) is smaller than the label length, are dropped.
        On return self.lmdb_size is the number of kept samples and self.len_num
        the number of distinct audio lengths among them.
    --]]
    print('preparing sorted indices..')

    -- The model hooks are needed both for the cache file name (min_width) and
    -- for filtering, so load them before looking for the cache.
    model_t = require(self.modelname)
    cal_size = model_t[2]
    get_min_width = model_t[3]
    if self.min_width == nil then
        -- NOTE(review): called without arguments, mirroring the call that is
        -- commented out in the constructor -- confirm against the model file.
        self.min_width = get_min_width()
    end

    -- Count the distinct audio lengths in a list of { index, length } pairs.
    local function count_lengths(entries)
        local seen, count = {}, 0
        for _, entry in ipairs(entries) do
            if not seen[entry[2]] then
                seen[entry[2]] = true
                count = count + 1
            end
        end
        return count
    end

    local dir = self._dir
    local indicesFilePath = dir .. '/' .. 'sorted_inds_' .. self.min_width

    -- check if there is previously saved inds
    if paths.filep(indicesFilePath) then
        print('found previously saved inds..')
        self.sorted_inds = torch.load(indicesFilePath)
        print('original size: '..self.lmdb_size..' valid data: '..#self.sorted_inds)
        self.lmdb_size = #self.sorted_inds
        self.len_num = count_lengths(self.sorted_inds)
        return
    end

    -- if not make a new one
    print('did not find previously saved indices, generating.')
    self.db_spect = lmdb.env { Path = dir .. '/spect', Name = 'spect' }
    self.db_label = lmdb.env { Path = dir .. '/label', Name = 'label' }
    self.db_trans = lmdb.env { Path = dir .. '/trans', Name = 'trans' }
    self.db_spect:open(); local txn = self.db_spect:txn(true)
    self.db_label:open(); local txn_label = self.db_label:txn(true)
    self.lmdb_size = self.db_spect:stat()['entries']

    -- start from a clean list so a repeated call cannot duplicate entries
    self.sorted_inds = {}
    -- those shorter than min_width are ignored
    local true_size = 0
    for i = 1, self.lmdb_size do
        local lengthOfAudio
        if self.is_spect then
            lengthOfAudio = txn:get(i):size(2)
        else
            -- raw byte record: 4 bytes per float, dataHeight floats per frame
            lengthOfAudio = txn:get(i, true):size(1) / (4*self.dataHeight)
        end
        local lengthOfLabel = #(torch.deserialize(txn_label:get(i)))
        if lengthOfAudio >= self.min_width and cal_size(lengthOfAudio) >= lengthOfLabel then
            true_size = true_size + 1
            self.sorted_inds[true_size] = { i, lengthOfAudio }
        end
        -- report progress for every sample scanned, not only for the kept ones
        if i % 100 == 0 then xlua.progress(i, self.lmdb_size) end
    end
    print('original size: '..self.lmdb_size..' valid data: '..true_size)
    self.lmdb_size = true_size -- set size to true size
    txn:abort(); self.db_spect:close()
    txn_label:abort(); self.db_label:close()

    local function comp(a, b) return a[2] < b[2] end
    table.sort(self.sorted_inds, comp)
    self.len_num = count_lengths(self.sorted_inds) -- number of different seqLengths
    torch.save(indicesFilePath, self.sorted_inds)
end
function Loader:nxt_sorted_inds()
    --[[
        return the lmdb indices of the next batch in length-sorted order.
        nxt_inds() yields positions into self.sorted_inds; each position is
        replaced by the lmdb index stored as the first element of that entry.
    --]]
    local meta_inds = self:nxt_inds()
    local inds = meta_inds:clone()
    -- tensor method call; a bare size(...) is not defined in this file
    for i = 1, inds:size(1) do
        inds[i] = self.sorted_inds[meta_inds[i]][1]
    end
    return inds
end
function Loader:nxt_random_inds()
    --[[
        return the indices of a randomly chosen batch.
        self.perm is a permutation of batch numbers and self.offset a random
        shift; both are set by nxt_batch(). The batch is a run of batch_size
        consecutive indices that wraps around to 1 past self.lmdb_size.
    --]]
    -- self.idx starts at 1 and grows by the batch size, so (idx - 1) / batch_size
    -- is the 0-based batch counter. Using idx itself skipped perm[1] and ran
    -- past the end of perm when batch_size is 1.
    local batch_no = math.floor((self.idx - 1) / self.batch_size) + 1
    local bid = self.perm[batch_no] - 1
    local start = (self.offset + bid * self.batch_size) % self.lmdb_size + 1
    local inds = torch.linspace(start, start + self.batch_size - 1, self.batch_size)
    local overflow = inds[-1] - self.lmdb_size
    if overflow > 0 then
        -- wrap the tail of the batch around to the beginning of the dataset
        inds:narrow(1, self.batch_size - overflow + 1, overflow):copy(torch.linspace(1, overflow, overflow))
        -- NOTE(review): cnt is not read by this method; kept as in the
        -- original, where it looks copied from nxt_inds() -- confirm.
        self.cnt = overflow + 1
    end
    return inds
end
function Loader:nxt_same_len_inds()
    --[[
        return inds with same seqLength, a solution before zero-masking can work.
        Starting at self.cnt, collects every consecutive entry of
        self.sorted_inds with the same length, advances self.cnt past them
        (wrapping to 1 at the end) and returns the lmdb indices as a 1-D
        tensor, the same kind of value the other nxt_*_inds methods return.
    --]]
    local inds = {} -- was never declared, so the insert below hit a nil global
    local _len = self.sorted_inds[self.cnt][2]
    while (self.cnt <= self.lmdb_size and self.sorted_inds[self.cnt][2] == _len) do
        -- NOTE: true index store in table, instead of cnt
        inds[#inds + 1] = self.sorted_inds[self.cnt][1]
        self.cnt = self.cnt + 1
    end
    if self.cnt > self.lmdb_size then self.cnt = 1 end
    -- nxt_batch() calls :size(1) and :totable() on the result, so hand back a tensor
    return torch.Tensor(inds)
end
function Loader:nxt_inds()
    --[[
        return indices of the next batch: batch_size consecutive positions
        starting at self.cnt. Positions beyond self.lmdb_size are replaced by
        1, 2, ... and self.cnt is moved to the position after the batch.
    --]]
    local n = self.batch_size
    local first = self.cnt
    local batch = torch.linspace(first, first + n - 1, n)
    self.cnt = first + n

    local excess = batch[-1] - self.lmdb_size
    if excess <= 0 then
        return batch
    end

    -- overwrite the tail that ran past the end with indices from the start
    local wrapped = torch.linspace(1, excess, excess)
    batch:narrow(1, n - excess + 1, excess):copy(wrapped)
    self.cnt = excess + 1
    return batch
end
function Loader:convert_tensor(btensor)
    --[[
        convert a 1d byte tensor to 2d float tensor of size
        dataHeight x (number of floats / dataHeight).
        The float storage is built on btensor's data pointer, so presumably
        the result shares memory with btensor -- confirm before letting
        btensor go out of scope.
    --]]
    local bytes_per_float = 4 -- assume real data is float
    local n_floats = btensor:size(1) / bytes_per_float
    local data_ptr = tonumber(torch.data(btensor, true))
    local storage = torch.FloatStorage(n_floats, data_ptr)
    assert(n_floats % self.dataHeight == 0, 'something wrong with the tensor dims')
    local dims = torch.LongStorage{self.dataHeight, n_floats / self.dataHeight}
    return torch.FloatTensor(storage, 1, dims)
end
function Loader:nxt_batch(mode)
    --[[
        return an iterator over one pass of the dataset; batches are loaded
        from lmdb just-in-time by the worker thread.
        input:
            mode: should be Loader.DEFAULT/SAMELEN/SORTED/RANDOM; USE ONLY ONE MODE FOR ONE TRAINING
        output:
            a function that returns (n, sample) on each call and nil once no
            job is left. sample has the fields
                inputs:   batch x 1 x height x max_width tensor, zero padded,
                          with stats[1] subtracted and divided by stats[2]
                label:    table of deserialized labels
                sizes:    tensor holding the width of every sample
                labelcnt: sum of the label lengths
        TODO we allocate 2 * batch_size space
    --]]
    local pool = self.pool
    -- for random idx
    self.perm = torch.randperm(math.ceil(self.lmdb_size/self.batch_size))
    self.offset = torch.random(self.lmdb_size)
    --self.offset = 1
    local sample = nil
    self.idx = 1 -- counts the samples queued so far; cnt indicates the position

    -- Queue jobs for as long as there is data left and the pool accepts work.
    local function enqueue()
        while self.idx <= self.lmdb_size and pool:acceptsjob() do
            -- gen index for this iter batch
            local indices
            if mode == self.SAMELEN then
                assert(#self.sorted_inds > 0, 'call prep_sorted_inds before nxt_batch')
                indices = self:nxt_same_len_inds()
            elseif mode == self.SORTED then
                assert(#self.sorted_inds > 0, 'call prep_sorted_inds before nxt_batch')
                indices = self:nxt_sorted_inds()
            elseif mode == self.RANDOM then
                indices = self:nxt_random_inds()
            else -- default
                indices = self:nxt_inds()
            end
            pool:addjob(
                -- Runs in the worker thread. It only touches worker globals
                -- (_G.db_*, _G.stats, tds) and its argument, never `self`.
                function(indices)
                    local tensor_list = tds.Vec()
                    local label_list = {}
                    local sizes_array = torch.Tensor(#indices)
                    local labelcnt = 0
                    local max_w = 0
                    local h = 0
                    _G.db_spect:open(); local txn_spect = _G.db_spect:txn(true) -- readonly
                    _G.db_label:open(); local txn_label = _G.db_label:txn(true)
                    -- kept local so the flag no longer leaks into the worker's globals
                    local is_spect = true
                    for i, idx in ipairs(indices:totable()) do
                        local tensor
                        if is_spect then
                            tensor = txn_spect:get(idx)
                        else
                            -- tensor = self:convert_tensor(txn_spect:get(idx, true))
                            -- the byte-tensor path is disabled; fail clearly
                            -- instead of indexing a nil tensor below
                            error('loading non-spect features is not implemented')
                        end
                        local label = torch.deserialize(txn_label:get(idx))
                        h = tensor:size(1)
                        sizes_array[i] = tensor:size(2)
                        if max_w < tensor:size(2) then max_w = tensor:size(2) end -- find the max len in this batch
                        tensor_list:insert(tensor)
                        table.insert(label_list, label)
                        labelcnt = labelcnt + #label
                    end
                    -- store tensors into a fixed len tensor_array TODO should find a better way to do this
                    local tensor_array = torch.Tensor(indices:size(1), 1, h, max_w):zero()
                    for i, tensor in ipairs(tensor_list) do
                        tensor_array[i][1]:narrow(2, 1, tensor:size(2)):copy(tensor)
                    end
                    -- normalise with the statistics loaded when the worker started
                    tensor_array:csub(_G.stats[1])
                    tensor_array:div(_G.stats[2])
                    txn_spect:abort(); _G.db_spect:close()
                    txn_label:abort(); _G.db_label:close()
                    return {
                        inputs = tensor_array,
                        label = label_list,
                        sizes = sizes_array,
                        labelcnt = labelcnt,
                    }
                end,
                -- Runs in the main thread when the job result is collected.
                function(_sample_)
                    sample = _sample_
                end,
                indices)
            self.idx = self.idx + indices:size(1)
        end
    end

    local n = 0
    local function loop()
        enqueue()
        if not pool:hasjob() then
            return nil
        end
        -- collect one finished job; its end callback stores the batch in `sample`
        pool:dojob()
        if pool:haserror() then
            pool:synchronize()
        end
        -- refill the queue so the worker stays busy while the caller trains
        enqueue()
        n = n + 1
        return n, sample
    end
    return loop
end