"""Part of the training engine related to Python generators of array data.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings
import numpy as np

from .. import backend as K
from ..utils.data_utils import Sequence
from ..utils.data_utils import GeneratorEnqueuer
from ..utils.data_utils import OrderedEnqueuer
from ..utils.generic_utils import Progbar
from ..utils.generic_utils import to_list
from ..utils.generic_utils import unpack_singleton
from .. import callbacks as cbks


def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
    """See docstring for `Model.fit_generator`."""
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    model._make_train_function()
    if do_validation:
        model._make_test_function()

    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
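    # `steps_per_epoch` can only be inferred from a `keras.utils.Sequence`,
    # whose length gives the number of batches per epoch.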
    if steps_per_epoch is None:
        if is_sequence:
            steps_per_epoch = len(generator)
        else:
            raise ValueError('`steps_per_epoch=None` is only valid for a'
                             ' generator based on the '
                             '`keras.utils.Sequence`'
                             ' class. Please specify `steps_per_epoch` '
                             'or use the `keras.utils.Sequence` class.')

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next') or
               hasattr(validation_data, '__next__') or
               isinstance(validation_data, Sequence))
    if (val_gen and not isinstance(validation_data, Sequence) and
            not validation_steps):
        raise ValueError('`validation_steps=None` is only valid for a'
                         ' generator based on the `keras.utils.Sequence`'
                         ' class. Please specify `validation_steps` or use'
                         ' the `keras.utils.Sequence` class.')
    # Prepare display labels.
    out_labels = model.metrics_names
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    model.history = cbks.History()
    _callbacks = [cbks.BaseLogger(
        stateful_metrics=model.stateful_metric_names)]
    if verbose:
        _callbacks.append(
            cbks.ProgbarLogger(
                count_mode='steps',
                stateful_metrics=model.stateful_metric_names))
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)

    # it's possible to callback a different model than self:
    if hasattr(model, 'callback_model') and model.callback_model:
        callback_model = model.callback_model
    else:
        callback_model = model
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()
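    # The enqueuers created below are shut down in the `finally` block so
    # that worker threads or processes do not outlive training.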
    enqueuer = None
    val_enqueuer = None

    try:
        if do_validation:
            if val_gen and workers > 0:
                # Create an Enqueuer that can be reused
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer = OrderedEnqueuer(
                        val_data,
                        use_multiprocessing=use_multiprocessing)
                    validation_steps = len(val_data)
                else:
                    val_enqueuer = GeneratorEnqueuer(
                        val_data,
                        use_multiprocessing=use_multiprocessing)
                val_enqueuer.start(workers=workers,
                                   max_queue_size=max_queue_size)
                val_enqueuer_gen = val_enqueuer.get()
            elif val_gen:
                val_data = validation_data
                if isinstance(val_data, Sequence):
                    val_enqueuer_gen = iter(val_data)
                else:
                    val_enqueuer_gen = val_data
            else:
                # Prepare data for validation
                if len(validation_data) == 2:
                    val_x, val_y = validation_data
                    val_sample_weight = None
                elif len(validation_data) == 3:
                    val_x, val_y, val_sample_weight = validation_data
                else:
                    raise ValueError('`validation_data` should be a tuple '
                                     '`(val_x, val_y, val_sample_weight)` '
                                     'or `(val_x, val_y)`. Found: ' +
                                     str(validation_data))
                val_x, val_y, val_sample_weights = model._standardize_user_data(
                    val_x, val_y, val_sample_weight)
                val_data = val_x + val_y + val_sample_weights
                if model.uses_learning_phase and not isinstance(
                        K.learning_phase(), int):
                    val_data += [0.]
                for cbk in callbacks:
                    cbk.validation_data = val_data
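        # Set up the generator that yields training batches, optionally
        # backed by worker threads or processes via an enqueuer.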
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    shuffle=shuffle)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        callback_model.stop_training = False
        # Construct epoch logs.
        epoch_logs = {}
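        # Outer loop over epochs; each epoch consumes `steps_per_epoch`
        # batches from `output_generator`.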
        while epoch < epochs:
            for m in model.stateful_metric_functions:
                m.reset_states()
            callbacks.on_epoch_begin(epoch)
            steps_done = 0
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))

                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                # build batch logs
                batch_logs = {}
                if x is None or len(x) == 0:
                    # Handle data tensors support when no input given
                    # step-size = 1 for data tensors
                    batch_size = 1
                elif isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

                outs = model.train_on_batch(x, y,
                                            sample_weight=sample_weight,
                                            class_weight=class_weight)

                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = model.evaluate_generator(
                            val_enqueuer_gen,
                            validation_steps,
                            workers=0)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = model.evaluate(
                            val_x, val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    val_outs = to_list(val_outs)
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break
    finally:
        try:
            if enqueuer is not None:
                enqueuer.stop()
        finally:
            if val_enqueuer is not None:
                val_enqueuer.stop()

    callbacks.on_train_end()
    return model.history


def evaluate_generator(model, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
    """See docstring for `Model.evaluate_generator`."""
    model._make_test_function()

    stateful_metric_indices = []
    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names]
    else:
        stateful_metric_indices = []

    steps_done = 0
    wait_time = 0.01
    outs_per_batch = []
    batch_sizes = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)
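        # Accumulate per-batch metric outputs and batch sizes so the final
        # result can be a batch-size-weighted average.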
        while steps_done < steps:
            generator_output = next(output_generator)
            if not hasattr(generator_output, '__len__'):
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
            if len(generator_output) == 2:
                x, y = generator_output
                sample_weight = None
            elif len(generator_output) == 3:
                x, y, sample_weight = generator_output
            else:
                raise ValueError('Output of generator should be a tuple '
                                 '(x, y, sample_weight) '
                                 'or (x, y). Found: ' +
                                 str(generator_output))
            outs = model.test_on_batch(x, y, sample_weight=sample_weight)
            outs = to_list(outs)
            outs_per_batch.append(outs)

            if x is None or len(x) == 0:
                # Handle data tensors support when no input given
                # step-size = 1 for data tensors
                batch_size = 1
            elif isinstance(x, list):
                batch_size = x[0].shape[0]
            elif isinstance(x, dict):
                batch_size = list(x.values())[0].shape[0]
            else:
                batch_size = x.shape[0]
            if batch_size == 0:
                raise ValueError('Received an empty batch. '
                                 'Batches should contain '
                                 'at least one item.')
            steps_done += 1
            batch_sizes.append(batch_size)
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()
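    # Average the per-batch outputs, weighted by batch size; stateful metrics
    # already aggregate across batches, so only their last value is kept.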
    averages = []
    for i in range(len(outs)):
        if i not in stateful_metric_indices:
            averages.append(np.average([out[i] for out in outs_per_batch],
                                       weights=batch_sizes))
        else:
            averages.append(np.float64(outs_per_batch[-1][i]))
    return unpack_singleton(averages)


def predict_generator(model, generator,
                      steps=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
    """See docstring for `Model.predict_generator`."""
    model._make_predict_function()

    steps_done = 0
    wait_time = 0.01
    all_outs = []
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the `keras.utils.Sequence`'
                        ' class.'))
    if steps is None:
        if is_sequence:
            steps = len(generator)
        else:
            raise ValueError('`steps=None` is only valid for a generator'
                             ' based on the `keras.utils.Sequence` class.'
                             ' Please specify `steps` or use the'
                             ' `keras.utils.Sequence` class.')
    enqueuer = None

    try:
        if workers > 0:
            if is_sequence:
                enqueuer = OrderedEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing)
            else:
                enqueuer = GeneratorEnqueuer(
                    generator,
                    use_multiprocessing=use_multiprocessing,
                    wait_time=wait_time)
            enqueuer.start(workers=workers, max_queue_size=max_queue_size)
            output_generator = enqueuer.get()
        else:
            if is_sequence:
                output_generator = iter(generator)
            else:
                output_generator = generator

        if verbose == 1:
            progbar = Progbar(target=steps)
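        # Collect the per-batch predictions for each model output; they are
        # unpacked or concatenated after the loop.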
        while steps_done < steps:
            generator_output = next(output_generator)
            if isinstance(generator_output, tuple):
                # Compatibility with the generators
                # used for training.
                if len(generator_output) == 2:
                    x, _ = generator_output
                elif len(generator_output) == 3:
                    x, _, _ = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
            else:
                # Assumes a generator that only
                # yields inputs (not targets and sample weights).
                x = generator_output

            outs = model.predict_on_batch(x)
            outs = to_list(outs)

            if not all_outs:
                for out in outs:
                    all_outs.append([])

            for i, out in enumerate(outs):
                all_outs[i].append(out)
            steps_done += 1
            if verbose == 1:
                progbar.update(steps_done)

    finally:
        if enqueuer is not None:
            enqueuer.stop()
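    # Stitch the collected batches back into one array (or list of arrays)
    # matching the model's outputs.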
    if len(all_outs) == 1:
        if steps_done == 1:
            return all_outs[0][0]
        else:
            return np.concatenate(all_outs[0])
    if steps_done == 1:
        return [out[0] for out in all_outs]
    else:
        return [np.concatenate(out) for out in all_outs]