Newton’s Method Optimization#

Function in One Variable#

You will use Newton’s method to optimize a function \(f\left(x\right)\). Aiming to find a point where the derivative equals zero, start from some initial point \(x_0\), calculate the first and second derivatives (\(f'(x_0)\) and \(f''(x_0)\)), and step to the next point using the expression:

\[x_1 = x_0 - \frac{f'(x_0)}{f''(x_0)} \tag{1}\]

Repeat the process iteratively. The number of iterations \(n\) is usually also a parameter of the method.
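To see why this works, note that each step jumps to the minimum of the second-order Taylor approximation of \(f\) around the current point. For a quadratic such as \(f(x) = x^2 - 2x\), with \(f'(x) = 2x - 2\) and \(f''(x) = 2\), a single step from any starting point lands exactly on the minimum:

\[x_1 = x_0 - \frac{2x_0 - 2}{2} = 1\]

For non-quadratic functions the step is only approximate, which is why the update is repeated.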

Let’s optimize the function \(f\left(x\right)=e^x - \log(x)\) (defined for \(x>0\)) using Newton’s method. To implement it in code, define the function, its first derivative, and its second derivative:

\[f(x) = e^x - \log(x)\]
\[f'(x) = e^x - \frac{1}{x}\]
\[f''(x) = e^x + \frac{1}{x^2}\]
import numpy as np
import matplotlib.pyplot as plt
import jax, jax.numpy as jnp

def f_example_1(x):
    return jnp.exp(x) - jnp.log(x)

dfdx_example_1 = jax.grad(f_example_1)
d2fdx2_example_1 = jax.grad(dfdx_example_1)

x_0 = 1.6
print(f"f({x_0}) = {f_example_1(x_0)}")
print(f"f'({x_0}) = {dfdx_example_1(x_0)}")
print(f"f''({x_0}) = {d2fdx2_example_1(x_0)}")

Output:

f(1.6) = 4.483028888702393
f'(1.6) = 4.328032493591309
f''(1.6) = 5.343657493591309
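
As a quick sanity check, substituting these values into expression \((1)\) gives the first Newton step:

\[x_1 = 1.6 - \frac{f'(1.6)}{f''(1.6)} = 1.6 - \frac{4.32803}{5.34366} \approx 0.79006\]

which matches the first iterate printed by the implementation below.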

Implement Newton’s method described above.

import numpy as np
import matplotlib.pyplot as plt
import jax, jax.numpy as jnp

def newton_method(dfdx, d2fdx2, x, num_iterations=100):
    for iteration in range(num_iterations):
        x = x - dfdx(x) / d2fdx2(x)
        print(x)
    return x

num_iterations_example_1 = 25; x_initial = 1.6
newtons_example_1 = newton_method(dfdx_example_1, d2fdx2_example_1, x_initial, num_iterations_example_1)
print("Newton's method result: x_min =", newtons_example_1)

Output:

0.7900618
0.5436325
0.5665913
0.567143
0.5671433
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
0.56714326
Newton's method result: x_min = 0.56714326
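
In practice, the fixed iteration count is often replaced (or supplemented) by a stopping criterion on the derivative. Below is a minimal sketch of such a variant; the tolerance value is an illustrative choice (JAX uses float32 by default, so very small tolerances may never be reached) and is not part of the original exercise.

def newton_method_with_tolerance(dfdx, d2fdx2, x, tolerance=1e-6, max_iterations=100):
    # Same Newton update as above, but stop early once |f'(x)| is small enough.
    for iteration in range(max_iterations):
        x = x - dfdx(x) / d2fdx2(x)
        if abs(dfdx(x)) < tolerance:
            break
    return x

print(newton_method_with_tolerance(dfdx_example_1, d2fdx2_example_1, 1.6))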

Let’s compare with the gradient descent method.

import numpy as np
import matplotlib.pyplot as plt
import jax, jax.numpy as jnp

def gradient_descent(dfdx, x, learning_rate=0.1, num_iterations=100):
    for iteration in range(num_iterations):
        x = x - learning_rate * dfdx(x)
        print(x)
    return x

num_iterations = 35; learning_rate = 0.25; x_initial = 1.6
gd_example_1 = gradient_descent(dfdx_example_1, x_initial, learning_rate, num_iterations)
print("Gradient descent result: x_min =", gd_example_1) 

Output:

0.5179919
0.5809616
0.56434345
0.56776285
0.5670086
0.56717265
0.5671369
0.5671447
0.56714296
0.5671434
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
0.56714326
0.5671433
Gradient descent result: x_min = 0.5671433
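
Both methods converge to the same point, \(x \approx 0.5671\). In this run Newton’s method settles on the answer within about six iterations, while gradient descent with learning rate \(0.25\) needs roughly a dozen; the trade-off is that each Newton iteration also requires the second derivative.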

Function in Two Variables#

In the case of a function of two variables, Newton’s method requires even more computation. Starting from the initial point \((x_0, y_0)\), the step to the next point should be taken using the expression:

\[\begin{split} \begin{bmatrix}x_1 \\ y_1\end{bmatrix} = \begin{bmatrix}x_0 \\ y_0\end{bmatrix} - H^{-1}\left(x_0, y_0\right)\nabla f\left(x_0, y_0\right) \end{split}\]

where \(H^{-1}\left(x_0, y_0\right)\) is the inverse of the Hessian matrix at the point \((x_0, y_0)\) and \(\nabla f\left(x_0, y_0\right)\) is the gradient at that point. Strictly speaking, this update seeks a stationary point of \(f\); it points toward a minimum when the Hessian is positive definite, as it is at the minimum found below.

Let’s implement that in code. Define the function \(f(x, y)\) as in the videos, together with its gradient and Hessian:

\[f(x, y) = x^4 + 0.8 y^4 + 4x^2 + 2y^2 - xy - 0.2x^2y\]
\[\begin{split}\nabla f(x, y) = \begin{bmatrix}4x^3 + 8x - y - 0.4xy \\ 3.2y^3 + 4y - x - 0.2x^2\end{bmatrix}\end{split}\]
\[\begin{split}H(x, y) = \begin{bmatrix}12x^2 + 8 - 0.4y & -1 - 0.4x \\ -1 - 0.4x & 9.6y^2 + 4\end{bmatrix}\end{split}\]
import jax
import jax.numpy as jnp

def f_example_2(xy):
    x, y = xy[0], xy[1]
    return x**4 + 0.8*y**4 + 4*x**2 + 2*y**2 - x*y - 0.2*x**2*y

grad_f_example_2 = jax.grad(f_example_2)
hessian_f_example_2 = jax.hessian(f_example_2)

# Example usage:
xy_0 = jnp.array([4.0, 4.0])
print("f(4,4) =", f_example_2(xy_0))
print("grad f(4,4) =", grad_f_example_2(xy_0))
print("Hessian f(4,4) =", hessian_f_example_2(xy_0))

Output:

f(4,4) = 528.0
grad f(4,4) = [277.6 213.6]
Hessian f(4,4) = [[198.4  -2.6]
 [ -2.6 157.6]]
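
As an optional consistency check, the hand-derived gradient and Hessian above can be evaluated directly and compared with the autodiff versions. This is a small sketch that assumes grad_f_example_2 and hessian_f_example_2 from the previous cell are still in scope.

import jax.numpy as jnp

def grad_f_analytic(xy):
    # Gradient from the formula above.
    x, y = xy[0], xy[1]
    return jnp.array([4*x**3 + 8*x - y - 0.4*x*y,
                      3.2*y**3 + 4*y - x - 0.2*x**2])

def hessian_f_analytic(xy):
    # Hessian from the formula above.
    x, y = xy[0], xy[1]
    return jnp.array([[12*x**2 + 8 - 0.4*y, -1 - 0.4*x],
                      [-1 - 0.4*x, 9.6*y**2 + 4]])

xy_0 = jnp.array([4.0, 4.0])
print(jnp.allclose(grad_f_analytic(xy_0), grad_f_example_2(xy_0)))      # expected: True
print(jnp.allclose(hessian_f_analytic(xy_0), hessian_f_example_2(xy_0)))  # expected: True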
def newton_method_2(hessian_f, grad_f, x_y, num_iterations=100):
    for iteration in range(num_iterations):
        # Newton step in two variables: subtract H^{-1} times the gradient.
        x_y = x_y - jnp.linalg.inv(hessian_f(x_y)) @ grad_f(x_y)
        print(x_y)
    return x_y

num_iterations_example_2 = 25
x_y_initial = jnp.array([4.0, 4.0])
newtons_example_2 = newton_method_2(hessian_f_example_2, grad_f_example_2,
                                    x_y_initial, num_iterations=num_iterations_example_2)
print("Newton's method result: x_min, y_min =", newtons_example_2)

Output:

[2.5827386 2.6212888]
[1.5922568 1.6748161]
[0.870589 1.001821]
[0.33519423 0.49397618]
[0.04123583 0.12545902]
[0.00019466 0.00301028]
[-2.4869223e-08  3.5157427e-08]
[-1.7763568e-15  0.0000000e+00]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
Newton's method result: x_min, y_min = [0. 0.]
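
For completeness, the explicit matrix inverse can be replaced by solving the linear system \(H\,d = \nabla f\) for the step \(d\), which is the standard and more numerically stable way to take a Newton step. A minimal sketch, reusing the gradient and Hessian functions defined above:

def newton_method_2_solve(hessian_f, grad_f, x_y, num_iterations=25):
    for iteration in range(num_iterations):
        # Solve H d = grad f for the Newton step d instead of forming H^{-1}.
        x_y = x_y - jnp.linalg.solve(hessian_f(x_y), grad_f(x_y))
    return x_y

print(newton_method_2_solve(hessian_f_example_2, grad_f_example_2, jnp.array([4.0, 4.0])))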

Compare with the gradient descent method.

def gradient_descent_2(grad_f, x_y, learning_rate=0.1, num_iterations=100):
    for iteration in range(num_iterations):
        # Gradient descent step: move against the gradient.
        x_y = x_y - learning_rate * grad_f(x_y)
        print(x_y)
    return x_y

num_iterations_2 = 300; learning_rate_2 = 0.01; x_y_initial = jnp.array([4.0, 4.0])
gd_example_2 = gradient_descent_2(grad_f_example_2, x_y_initial, learning_rate_2, num_iterations_2)
print("Gradient descent result: x_min, y_min =", gd_example_2) 

Output:

[1.224     1.8640001]
[1.0804955 1.5974296]
[0.9664763 1.416231 ]
[0.872685  1.2802172]
[0.7935566 1.1721154]
[0.72552466 1.0828958 ]
[0.6661781 1.0072521]
[0.6138146  0.94181013]
[0.5671893 0.8842969]
[0.5253647  0.83311224]
[0.4876172  0.78708965]
[0.45337626 0.74535424]
[0.42218375 0.70723426]
[0.39366573 0.67220336]
[0.3675127  0.63984215]
[0.34346518 0.6098113 ]
[0.32130316 0.58183277]
[0.3008382 0.555676 ]
[0.2819075 0.5311478]
[0.26436916 0.5080848 ]
[0.24809869 0.48634768]
[0.23298606 0.46581665]
[0.21893358 0.446388  ]
[0.20585394 0.42797133]
[0.19366881 0.41048738]
[0.18230762 0.39386624]
[0.17170653 0.37804592]
[0.16180763 0.36297116]
[0.1525582 0.3485925]
[0.14391015 0.3348654 ]
[0.13581954 0.3217497 ]
[0.12824605 0.30920893]
[0.12115271 0.2972099 ]
[0.11450549 0.28572226]
[0.10827309 0.27471823]
[0.10242663 0.26417223]
[0.09693947 0.25406063]
[0.091787   0.24436162]
[0.08694644 0.23505495]
[0.08239673 0.22612175]
[0.07811836 0.21754445]
[0.07409324 0.20930661]
[0.07030462 0.20139283]
[0.06673691 0.19378866]
[0.06337569 0.1864805 ]
[0.06020753 0.17945556]
[0.05721997 0.17270173]
[0.05440142 0.16620758]
[0.05174111 0.15996228]
[0.04922901 0.15395558]
[0.04685579 0.14817773]
[0.04461276 0.14261946]
[0.04249183 0.13727196]
[0.04048547 0.13212684]
[0.03858665 0.12717609]
[0.03678881 0.12241207]
[0.03508585 0.11782748]
[0.03347206 0.11341536]
[0.03194214 0.10916902]
[0.0304911  0.10508209]
[0.02911432 0.10114844]
[0.02780745 0.09736223]
[0.02656644 0.09371783]
[0.02538751 0.09020985]
[0.02426712 0.08683313]
[0.02320194 0.0835827 ]
[0.02218887 0.08045381]
[0.021225   0.07744186]
[0.02030761 0.07454248]
[0.01943415 0.07175142]
[0.01860221 0.06906464]
[0.01780957 0.06647822]
[0.01705409 0.06398842]
[0.01633382 0.06159163]
[0.01564688 0.05928436]
[0.01499153 0.05706327]
[0.01436612 0.05492516]
[0.01376912 0.05286692]
[0.01319907 0.05088559]
[0.0126546  0.04897829]
[0.01213441 0.04714226]
[0.0116373  0.04537486]
[0.01116211 0.04367352]
[0.01070777 0.04203578]
[0.01027326 0.04045928]
[0.00985761 0.03894174]
[0.00945992 0.03748095]
[0.00907932 0.03607481]
[0.008715   0.03472127]
[0.0083662  0.03341838]
[0.00803218 0.03216426]
[0.00771226 0.03095707]
[0.00740579 0.02979508]
[0.00711214 0.0286766 ]
[0.00683074 0.0276    ]
[0.00656102 0.02656373]
[0.00630246 0.02556628]
[0.00605456 0.0246062 ]
[0.00581685 0.02368209]
[0.00558886 0.02279262]
[0.00537018 0.02193649]
[0.0051604  0.02111245]
[0.00495912 0.02031931]
[0.00476598 0.01955591]
[0.00458063 0.01882114]
[0.00440273 0.01811393]
[0.00423197 0.01743324]
[0.00406804 0.0167781 ]
[0.00391064 0.01614754]
[0.00375952 0.01554064]
[0.0036144  0.01495652]
[0.00347502 0.01439432]
[0.00334116 0.01385323]
[0.00321259 0.01333245]
[0.00308907 0.01283122]
[0.00297042 0.01234881]
[0.00285642 0.01188452]
[0.00274688 0.01143767]
[0.00264164 0.0110076 ]
[0.0025405  0.01059368]
[0.0024433  0.01019531]
[0.00234989 0.00981191]
[0.00226011 0.00944292]
[0.00217381 0.00908778]
[0.00209086 0.008746  ]
[0.00201113 0.00841705]
[0.00193448 0.00810047]
[0.00186079 0.00779579]
[0.00178994 0.00750255]
[0.00172182 0.00722034]
[0.00165633 0.00694874]
[0.00159336 0.00668735]
[0.0015328  0.00643579]
[0.00147458 0.00619368]
[0.00141858 0.00596067]
[0.00136474 0.00573643]
[0.00131295 0.00552062]
[0.00126315 0.00531292]
[0.00121526 0.00511303]
[0.00116919 0.00492066]
[0.00112489 0.00473553]
[0.00108227 0.00455735]
[0.00104128 0.00438588]
[0.00100186 0.00422086]
[0.00096393 0.00406204]
[0.00092746 0.0039092 ]
[0.00089237 0.00376211]
[0.00085861 0.00362055]
[0.00082614 0.00348431]
[0.0007949 0.0033532]
[0.00076485 0.00322702]
[0.00073595 0.00310559]
[0.00070813 0.00298872]
[0.00068138 0.00287626]
[0.00065564 0.00276802]
[0.00063088 0.00266386]
[0.00060705 0.00256361]
[0.00058413 0.00246714]
[0.00056208 0.00237429]
[0.00054086 0.00228494]
[0.00052044 0.00219895]
[0.0005008 0.0021162]
[0.0004819  0.00203656]
[0.00046372 0.00195992]
[0.00044623 0.00188616]
[0.00042939 0.00181517]
[0.0004132  0.00174686]
[0.00039761 0.00168112]
[0.00038262 0.00161785]
[0.00036819 0.00155696]
[0.00035431 0.00149837]
[0.00034095 0.00144198]
[0.00032809 0.00138771]
[0.00031572 0.00133548]
[0.00030382 0.00128522]
[0.00029237 0.00123685]
[0.00028135 0.0011903 ]
[0.00027075 0.0011455 ]
[0.00026054 0.00110239]
[0.00025073 0.0010609 ]
[0.00024128 0.00102097]
[0.00023219 0.00098254]
[0.00022344 0.00094556]
[0.00021502 0.00090997]
[0.00020692 0.00087573]
[0.00019912 0.00084277]
[0.00019162 0.00081105]
[0.0001844  0.00078052]
[0.00017746 0.00075114]
[0.00017077 0.00072287]
[0.00016434 0.00069567]
[0.00015815 0.00066948]
[0.00015219 0.00064428]
[0.00014646 0.00062004]
[0.00014094 0.0005967 ]
[0.00013564 0.00057424]
[0.00013053 0.00055263]
[0.00012561 0.00053183]
[0.00012088 0.00051181]
[0.00011633 0.00049255]
[0.00011195 0.00047401]
[0.00010773 0.00045617]
[0.00010368 0.000439  ]
[9.9772391e-05 4.2247464e-04]
[9.6015516e-05 4.0657338e-04]
[9.2400165e-05 3.9127062e-04]
[8.8921006e-05 3.7654382e-04]
[8.557290e-05 3.623713e-04]
[8.235090e-05 3.487322e-04]
[7.9250269e-05 3.3560643e-04]
[7.626642e-05 3.229747e-04]
[7.3394949e-05 3.1081837e-04]
[7.063163e-05 2.991196e-04]
[6.7972374e-05 2.8786113e-04]
[6.5413275e-05 2.7702641e-04]
[6.295055e-05 2.665995e-04]
[6.0580569e-05 2.5656502e-04]
[5.8299836e-05 2.4690822e-04]
[5.610499e-05 2.376149e-04]
[5.3992793e-05 2.2867137e-04]
[5.1960134e-05 2.2006445e-04]
[5.0004015e-05 2.1178149e-04]
[4.8121550e-05 2.0381027e-04]
[4.6309968e-05 1.9613908e-04]
[4.4566597e-05 1.8875662e-04]
[4.2888871e-05 1.8165202e-04]
[4.1274314e-05 1.7481484e-04]
[3.9720548e-05 1.6823498e-04]
[3.8225280e-05 1.6190279e-04]
[3.6786310e-05 1.5580893e-04]
[3.5401517e-05 1.4994443e-04]
[3.4068860e-05 1.4430068e-04]
[3.2786378e-05 1.3886933e-04]
[3.1552179e-05 1.3364243e-04]
[3.0364447e-05 1.2861226e-04]
[2.9221428e-05 1.2377142e-04]
[2.81214416e-05 1.19112774e-04]
[2.7062868e-05 1.1462948e-04]
[2.6044147e-05 1.1031493e-04]
[2.5063775e-05 1.0616278e-04]
[2.4120311e-05 1.0216691e-04]
[2.3212364e-05 9.8321434e-05]
[2.2338598e-05 9.4620700e-05]
[2.1497726e-05 9.1059257e-05]
[2.0688509e-05 8.7631866e-05]
[1.9909754e-05 8.4333478e-05]
[1.9160316e-05 8.1159240e-05]
[1.8439088e-05 7.8104473e-05]
[1.7745011e-05 7.5164688e-05]
[1.7077062e-05 7.2335548e-05]
[1.6434256e-05 6.9612899e-05]
[1.5815649e-05 6.6992725e-05]
[1.5220328e-05 6.4471176e-05]
[1.4647418e-05 6.2044535e-05]
[1.4096073e-05 5.9709229e-05]
[1.3565483e-05 5.7461821e-05]
[1.3054865e-05 5.5299002e-05]
[1.2563469e-05 5.3217591e-05]
[1.2090570e-05 5.1214523e-05]
[1.1635472e-05 4.9286849e-05]
[1.1197505e-05 4.7431731e-05]
[1.0776024e-05 4.5646437e-05]
[1.0370409e-05 4.3928339e-05]
[9.980061e-06 4.227491e-05]
[9.6044068e-06 4.0683713e-05]
[9.2428927e-06 3.9152408e-05]
[8.894986e-06 3.767874e-05]
[8.560176e-06 3.626054e-05]
[8.237968e-06 3.489572e-05]
[7.927889e-06 3.358227e-05]
[7.6294818e-06 3.2318258e-05]
[7.3423071e-06 3.1101823e-05]
[7.0659416e-06 2.9931172e-05]
[6.7999790e-06 2.8804585e-05]
[6.5440272e-06 2.7720402e-05]
[6.2977097e-06 2.6677026e-05]
[6.0606640e-06 2.5672922e-05]
[5.8325409e-06 2.4706611e-05]
[5.6130043e-06 2.3776673e-05]
[5.4017314e-06 2.2881735e-05]
[5.1984107e-06 2.2020484e-05]
[5.0027429e-06 2.1191649e-05]
[4.814440e-06 2.039401e-05]
[4.6332257e-06 1.9626395e-05]
[4.4588319e-06 1.8887671e-05]
[4.2910024e-06 1.8176752e-05]
[4.1294902e-06 1.7492592e-05]
[3.9740571e-06 1.6834183e-05]
[3.8244748e-06 1.6200556e-05]
[3.6805227e-06 1.5590778e-05]
[3.5419889e-06 1.5003952e-05]
[3.4086695e-06 1.4439214e-05]
[3.2803682e-06 1.3895732e-05]
[3.1568964e-06 1.3372706e-05]
[3.0380718e-06 1.2869367e-05]
[2.9237199e-06 1.2384973e-05]
[2.8136722e-06 1.1918812e-05]
[2.7077667e-06 1.1470196e-05]
[2.6058474e-06 1.1038466e-05]
[2.5077643e-06 1.0622986e-05]
Gradient descent result: x_min, y_min = [2.5077643e-06 1.0622986e-05]
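
Note the difference in convergence speed: Newton’s method reached the minimum \((0, 0)\) to within float32 precision in fewer than ten iterations, while gradient descent with learning rate \(0.01\) is still on the order of \(10^{-5}\) away after 300 iterations. The price is that each Newton iteration requires the Hessian and a linear solve (or inverse), not just the gradient.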