diff --git a/mlp/models.py b/mlp/models.py index b2be888..63d034f 100644 --- a/mlp/models.py +++ b/mlp/models.py @@ -131,7 +131,7 @@ class MultipleLayerModel(object): inputs = activations[-i - 2] outputs = activations[-i - 1] grads_wrt_inputs = layer.bprop(inputs, outputs, grads_wrt_outputs) - if isinstance(layer, LayerWithParameters): + if isinstance(layer, LayerWithParameters) or isinstance(layer, StochasticLayerWithParameters): grads_wrt_params += layer.grads_wrt_params( inputs, grads_wrt_outputs)[::-1] grads_wrt_outputs = grads_wrt_inputs