diff --git a/mlp/errors.py b/mlp/errors.py index a141784..3f0ae4f 100644 --- a/mlp/errors.py +++ b/mlp/errors.py @@ -155,7 +155,7 @@ class CrossEntropySoftmaxError(object): Scalar error function value. """ normOutputs = outputs - outputs.max(-1)[:, None] - logProb = normOutputs - np.log(np.sum(np.exp(normOutputs))) + logProb = normOutputs - np.log(np.sum(np.exp(normOutputs), axis=-1)[:, None]) return -np.mean(np.sum(targets * logProb, axis=1)) def grad(self, outputs, targets):