Skip to content

How to run QONetwork learning cycle#

To run learning cycle, first you have to instantiate QOConstants object containing parameters of QONetwork and QOTracker for tracking learning process metrics.

1
2
3
4
5
6
7
8
9
constants = QOConstants(
    k=4.0,
    mass=1.0,
    x_left=-6.0,
    x_right=6.0,
    fb=0.0,
    sample_size=1200,
    tracker=QOTracker(),
)

Then you have to instantiate QONetwork object

network = QONetwork(constants=constants)

And after above steps you can run learning loop

1
2
3
4
5
6
7
8
for index, nn in enumerate(
    network.train_generations(
        QOParams(c=-2.0, c_step=0.16),
        generations=150,
        epochs=1000,
    )
):
    pass

After each generation of learning, body of loop will be executed, thus you can stuff any kind of plotting there.

For example you can use QOTracker.plot() (see full code snippet at the very bottom)

Learning graph

Example learning graph created using QOTracker.plot()

Full code snippet:

import gc
from pathlib import Path

import matplotlib
import tensorflow as tf
from matplotlib import pyplot as plt

from nneve.quantum_oscillator import (
    QOConstants,
    QONetwork,
    QOParams,
    QOTracker,
)
from nneve.utility.testing import disable_gpu_or_skip

EXAMPLES_CODE = Path(__file__).parent
EXAMPLES_DIR = EXAMPLES_CODE.parent
WEIGHTS_DIR = EXAMPLES_DIR / "weights"
PLOTS_DIR = EXAMPLES_DIR / "plots"


tf.random.set_seed(0)
disable_gpu_or_skip()

constants = QOConstants(
    k=4.0,
    mass=1.0,
    x_left=-6.0,
    x_right=6.0,
    fb=0.0,
    sample_size=1200,
    tracker=QOTracker(),
)
network = QONetwork(constants=constants, is_debug=True)

network.summary()
matplotlib.use("Agg")

for index, nn in enumerate(
    network.train_generations(
        QOParams(c=-2.0, c_step=0.16),
        generations=150,
        epochs=1000,
    )
):
    x = nn.constants.get_sample()
    y2, _ = nn(x)
    nn.constants.tracker.plot(y2, x)
    # savefig tends to create memory leaks
    plt.savefig(PLOTS_DIR / f"{index}.png")
    plt.cla()
    plt.clf()
    plt.close("all")
    gc.collect()
    nn.save(WEIGHTS_DIR / f"{index}.w")