import numpy as np


def outer_product_tensor(phi_long, phi_lat):
    """Return the 4-D outer product of two 2-D arrays via ``np.einsum``.

    The result satisfies ``Phi[i, j, a, b] == phi_long[i, a] * phi_lat[j, b]``,
    so for ``phi_long`` of shape ``(Ns, Msplus2)`` and ``phi_lat`` of shape
    ``(Nt, Mt)`` the result has shape ``(Ns, Nt, Msplus2, Mt)``.

    Args:
        phi_long: 2-D array, first factor.
        phi_lat: 2-D array, second factor.

    Returns:
        The 4-D tensor described above.
    """
    # Subscript labels match the index order used by the loop references
    # below, so the two can be compared term by term.
    return np.einsum("ia,jb->ijab", phi_long, phi_lat)


def reference_outer_product(phi_long, phi_lat):
    """Return the same tensor as ``outer_product_tensor`` using explicit loops.

    Used as an independent, readable reference for the einsum version.
    """
    Ns, Msplus2 = phi_long.shape
    Nt, Mt = phi_lat.shape
    out = np.zeros(
        (Ns, Nt, Msplus2, Mt), dtype=np.result_type(phi_long, phi_lat)
    )
    for i, j, a, b in np.ndindex(out.shape):
        out[i, j, a, b] = phi_long[i, a] * phi_lat[j, b]
    return out


def reference_flatten(Phi):
    """Flatten a 4-D tensor ``(Ns, Nt, Msplus2, Mt)`` to ``(Ns*Nt, Msplus2*Mt)`` by hand.

    Row index is ``i * Nt + j`` and column index is ``a * Mt + b``, which is
    the layout a C-order ``Phi.reshape((Ns * Nt, Msplus2 * Mt))`` produces.
    """
    Ns, Nt, Msplus2, Mt = Phi.shape
    out = np.zeros((Ns * Nt, Msplus2 * Mt), dtype=Phi.dtype)
    for i, j, a, b in np.ndindex(Phi.shape):
        out[i * Nt + j, a * Mt + b] = Phi[i, j, a, b]
    return out


def main(Ns=30, Msplus2=32, Nt=20, Mt=20, seed=0):
    """Check the einsum outer product and its 2-D reshape against loop references.

    Args:
        Ns, Msplus2: shape of the first random factor ``phi_long``.
        Nt, Mt: shape of the second random factor ``phi_lat``.
        seed: seed for the random generator, so runs are reproducible.

    Raises:
        AssertionError: if the shape or any entry disagrees with the reference.
    """
    rng = np.random.default_rng(seed)
    phi_long = rng.random((Ns, Msplus2))
    phi_lat = rng.random((Nt, Mt))

    Phi = outer_product_tensor(phi_long, phi_lat)
    # np.testing checks are not stripped under `python -O`, unlike `assert`.
    np.testing.assert_equal(Phi.shape, (Ns, Nt, Msplus2, Mt))

    print("Bytes needed for Phi:", Phi.nbytes)

    np.testing.assert_array_equal(
        Phi, reference_outer_product(phi_long, phi_lat)
    )

    # Collapse (i, j) into rows and (a, b) into columns.
    reshaped_Phi = Phi.reshape((Ns * Nt, Msplus2 * Mt))
    np.testing.assert_array_equal(reshaped_Phi, reference_flatten(Phi))

    print("Julien loves tensor.")


if __name__ == "__main__":
    main()