From b9147c269cbfb452d74869a765da4c82589765f8 Mon Sep 17 00:00:00 2001
From: AntreasAntoniou
Date: Mon, 13 Nov 2017 23:52:39 +0000
Subject: [PATCH] Numerically stable softmax

---
 mlp/errors.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlp/errors.py b/mlp/errors.py
index a61b757..0b0d603 100644
--- a/mlp/errors.py
+++ b/mlp/errors.py
@@ -154,9 +154,9 @@ class CrossEntropySoftmaxError(object):
         Returns:
             Scalar error function value.
         """
-        probs = np.exp(outputs - outputs.max(-1)[:, None])
-        probs /= probs.sum(-1)[:, None]
-        return -np.mean(np.sum(targets * np.log(probs), axis=1))
+        normOutputs = outputs - outputs.max(-1)[:, None]
+        logProb = normOutputs - np.log(np.sum(np.exp(normOutputs), -1)[:, None])
+        return -np.mean(np.sum(targets * logProb, axis=1))
 
     def grad(self, outputs, targets):
         """Calculates gradient of error function with respect to outputs.
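
Note (not part of the patch): the change works in log space so the log of the softmax is never taken on an underflowed probability. Below is a minimal standalone sketch of the log-sum-exp shift the patch applies, assuming NumPy; the stable_log_softmax helper name is hypothetical and for illustration only.

    import numpy as np

    def stable_log_softmax(outputs):
        # Shift by the per-row maximum so the largest exponent is exp(0) = 1,
        # which avoids overflow in np.exp for large logits.
        shifted = outputs - outputs.max(-1)[:, None]
        # log softmax = shifted logits minus the log of the (shifted) normaliser.
        return shifted - np.log(np.sum(np.exp(shifted), -1)[:, None])

    # Quick check: with large logits the naive route overflows to nan,
    # while the shifted version stays finite.
    logits = np.array([[1000., 1001., 1002.]])
    naive = np.log(np.exp(logits) / np.exp(logits).sum(-1)[:, None])  # nan (exp overflows)
    stable = stable_log_softmax(logits)                               # finite values

Computing logProb directly (rather than exponentiating and then taking np.log, as the removed lines did) also avoids log(0) when a shifted logit is very negative and its probability underflows to zero.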