diff --git a/mlp/errors.py b/mlp/errors.py
index a61b757..0b0d603 100644
--- a/mlp/errors.py
+++ b/mlp/errors.py
@@ -154,9 +154,9 @@ class CrossEntropySoftmaxError(object):
         Returns:
             Scalar error function value.
         """
-        probs = np.exp(outputs - outputs.max(-1)[:, None])
-        probs /= probs.sum(-1)[:, None]
-        return -np.mean(np.sum(targets * np.log(probs), axis=1))
+        # Compute log-softmax directly via the log-sum-exp trick: shifting by
+        # the row-wise max keeps np.exp from overflowing, and working in log
+        # space avoids log(0) for probabilities that underflow to zero.
+        norm_outputs = outputs - outputs.max(-1)[:, None]
+        # Sum over the class axis only (axis=-1), keeping a column vector so
+        # it broadcasts against norm_outputs.
+        log_prob = norm_outputs - np.log(np.sum(np.exp(norm_outputs), -1)[:, None])
+        return -np.mean(np.sum(targets * log_prob, axis=1))
 
     def grad(self, outputs, targets):
         """Calculates gradient of error function with respect to outputs.