diff --git a/mlp/errors.py b/mlp/errors.py index 8412d4c..a57decc 100644 --- a/mlp/errors.py +++ b/mlp/errors.py @@ -170,7 +170,9 @@ class CrossEntropySoftmaxError(object): Returns: Gradient of error function with respect to outputs. """ - probs = np.exp(outputs) + # subtract max inside exponential to improve numerical stability - + # when we divide through by sum this term cancels + probs = np.exp(outputs - outputs.max(-1)[:, None]) probs /= probs.sum(-1)[:, None] return (probs - targets) / outputs.shape[0]