Quantum Convolutional Neural Network

This tutorial implements a simplified Quantum Convolutional Neural Network (QCNN), a proposed quantum analogue to a classical convolutional neural network that is also translationally invariant.

This example demonstrates how to detect certain properties of a quantum data source, such as a quantum sensor or a complex simulation from a device. Here the quantum data source is a cluster state that may or may not have an excitation, and the QCNN learns to detect that excitation. (The dataset used in the paper was SPT phase classification.)

Setup

pip install tensorflow==2.15.0

Install TensorFlow Quantum:

pip install tensorflow-quantum==0.7.3
# Update package resources to account for version changes.
import importlib, pkg_resources

importlib.reload(pkg_resources)

Now import TensorFlow and the module dependencies:

import tensorflow as tf
import tensorflow_quantum as tfq

import cirq
import sympy
import numpy as np

# visualization tools
%matplotlib inline
import matplotlib.pyplot as plt
from cirq.contrib.svg import SVGCircuit

1. Build a QCNN

1.1 Assemble circuits in a TensorFlow graph

TensorFlow Quantum (TFQ) provides layer classes designed for in-graph circuit construction. One example is the tfq.layers.AddCircuit layer, which inherits from tf.keras.layers.Layer. This layer can either prepend or append a circuit to each circuit in the input batch, as shown in the following figure.

The following snippet uses this layer:

qubit = cirq.GridQubit(0, 0)

# Define some circuits.
circuit1 = cirq.Circuit(cirq.X(qubit))
circuit2 = cirq.Circuit(cirq.H(qubit))

# Convert to a tensor.
input_circuit_tensor = tfq.convert_to_tensor([circuit1, circuit2])

# Define a circuit that we want to append
y_circuit = cirq.Circuit(cirq.Y(qubit))

# Instantiate our layer
y_appender = tfq.layers.AddCircuit()

# Run our circuit tensor through the layer and save the output.
output_circuit_tensor = y_appender(input_circuit_tensor, append=y_circuit)

Examine the input tensor:

print(tfq.from_tensor(input_circuit_tensor))
[cirq.Circuit([
     cirq.Moment(
         cirq.X(cirq.GridQubit(0, 0)),
     ),
 ])
 cirq.Circuit([
     cirq.Moment(
         cirq.H(cirq.GridQubit(0, 0)),
     ),
 ])                                   ]

And examine the output tensor:

print(tfq.from_tensor(output_circuit_tensor))
[cirq.Circuit([
     cirq.Moment(
         cirq.X(cirq.GridQubit(0, 0)),
     ),
     cirq.Moment(
         cirq.Y(cirq.GridQubit(0, 0)),
     ),
 ])
 cirq.Circuit([
     cirq.Moment(
         cirq.H(cirq.GridQubit(0, 0)),
     ),
     cirq.Moment(
         cirq.Y(cirq.GridQubit(0, 0)),
     ),
 ])                                   ]

While it is possible to run the examples below without using tfq.layers.AddCircuit, it's a good opportunity to understand how complex functionality can be embedded into TensorFlow compute graphs.
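To make the batch semantics concrete, here is a toy sketch in plain Python. It does not use TFQ at all: each "circuit" is stood in for by a list of gate names, and `append_to_batch` and `prepend_to_batch` are hypothetical helpers written only for illustration. It mirrors the behavior seen in the output above, where the Y circuit ends up after every circuit in the batch.

```python
# Toy stand-in for the batch semantics of an AddCircuit-style layer.
# Assumption: a "circuit" is just a list of gate names, and both helpers
# below are hypothetical illustrations, not TFQ functions.

def append_to_batch(batch, extra):
    """Return a new batch with `extra` placed after every circuit."""
    return [circuit + extra for circuit in batch]


def prepend_to_batch(batch, extra):
    """Return a new batch with `extra` placed before every circuit."""
    return [extra + circuit for circuit in batch]


batch = [["X"], ["H"]]          # plays the role of input_circuit_tensor
y_circuit = ["Y"]               # plays the role of y_circuit

appended = append_to_batch(batch, y_circuit)
prepended = prepend_to_batch(batch, y_circuit)

print(appended)   # [['X', 'Y'], ['H', 'Y']]
print(prepended)  # [['Y', 'X'], ['Y', 'H']]
```

The prepend variant is the one that matters later in this tutorial, where a fixed state-preparation circuit has to run before each datapoint circuit.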

1.2 Problem overview

You will prepare a cluster state and train a quantum classifier to detect whether it is "excited". The cluster state is highly entangled but not necessarily difficult for a classical computer to simulate. For clarity, this is a simpler dataset than the one used in the paper.

For this classification task you will implement a deep MERA-like QCNN architecture since:

  1. Like the QCNN, the cluster state on a ring is translationally invariant.
  2. The cluster state is highly entangled.

This architecture should be effective at reducing entanglement, obtaining the classification by reading out a single qubit.

An "excited" cluster state is defined as a cluster state that had a cirq.rx gate applied to any of its qubits. Qconv and QPool are discussed later in this tutorial.

1.3 Building blocks for TensorFlow

One way to solve this problem with TensorFlow Quantum is to implement the following:

  1. The input to the model is a circuit tensor: each circuit applies an X rotation to a particular qubit, and the rotation angle determines whether it counts as an excitation.
  2. The rest of the model's quantum components are constructed with tfq.layers.AddCircuit layers.
  3. For inference a tfq.layers.PQC layer is used. This reads \(\langle \hat{Z} \rangle\) and compares it to a label of 1 or -1 that encodes whether the state is excited.

1.4 Data

Before building your model, generate your data. In this case the data consists of excitations to the cluster state. (The original paper uses a more complicated dataset.) Excitations are represented with cirq.rx gates, and the label depends on the rotation angle: the code below assigns label 1 to angles in \([-\pi/2, \pi/2]\) and label -1 to angles outside that range.

def generate_data(qubits):
    """Generate training and testing data."""
    n_rounds = 20  # Produces n_rounds * n_qubits datapoints.
    excitations = []
    labels = []
    for n in range(n_rounds):
        for bit in qubits:
            rng = np.random.uniform(-np.pi, np.pi)
            excitations.append(cirq.Circuit(cirq.rx(rng)(bit)))
            labels.append(1 if (-np.pi / 2) <= rng <= (np.pi / 2) else -1)

    split_ind = int(len(excitations) * 0.7)
    train_excitations = excitations[:split_ind]
    test_excitations = excitations[split_ind:]

    train_labels = labels[:split_ind]
    test_labels = labels[split_ind:]

    return tfq.convert_to_tensor(train_excitations), np.array(train_labels), \
        tfq.convert_to_tensor(test_excitations), np.array(test_labels)

As in classical machine learning, the data is split into a training set and a test set for benchmarking the model. You can quickly look at some datapoints with:

sample_points, sample_labels, _, __ = generate_data(cirq.GridQubit.rect(1, 4))
print('Input:', tfq.from_tensor(sample_points)[0], 'Output:', sample_labels[0])
print('Input:', tfq.from_tensor(sample_points)[1], 'Output:', sample_labels[1])
Input: (0, 0): ───X^-0.701─── Output: -1
Input: (0, 1): ───X^-0.073─── Output: 1
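The labeling rule in generate_data can be checked against the two sample points above. The sketch below is plain Python and rests on one assumption: that an rx rotation by angle theta is displayed as X^t with t = theta / pi, which is how the printed exponents are read here. The helper names are hypothetical.

```python
import math


def label_for_angle(theta):
    """Mirror the rule in generate_data: 1 inside [-pi/2, pi/2], else -1."""
    return 1 if -math.pi / 2 <= theta <= math.pi / 2 else -1


def angle_from_exponent(t):
    """Assumption: a gate printed as X^t corresponds to an angle of t * pi."""
    return t * math.pi


# Exponents copied from the sample output above.
for t in (-0.701, -0.073):
    theta = angle_from_exponent(t)
    print(f"X^{t}: theta = {theta:+.3f} rad, label = {label_for_angle(theta)}")
```

Under that assumption the exponent -0.701 lies outside the half-turn window and gets label -1, while -0.073 lies inside it and gets label 1, which agrees with the printed labels.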

1.5 Define layers

Now define the layers shown in the figure above in TensorFlow.

1.5.1 Cluster state

The first step is to define the cluster state using Cirq, a Google-provided framework for programming quantum circuits. Since this is a static part of the model, embed it using the tfq.layers.AddCircuit functionality.

def cluster_state_circuit(bits):
    """Return a cluster state on the qubits in `bits`."""
    circuit = cirq.Circuit()
    circuit.append(cirq.H.on_each(bits))
    for this_bit, next_bit in zip(bits, bits[1:] + [bits[0]]):
        circuit.append(cirq.CZ(this_bit, next_bit))
    return circuit
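The `zip(bits, bits[1:] + [bits[0]])` idiom is what closes the chain of CZ gates into a ring. A minimal sketch with integer indices in place of qubits (the helper name `ring_pairs` is hypothetical) shows which pairs it produces:

```python
def ring_pairs(n):
    """Index pairs that receive a CZ gate on a ring of n qubits."""
    indices = list(range(n))
    # Same pairing as in cluster_state_circuit: shift the list by one and
    # wrap the first element around to the end.
    return list(zip(indices, indices[1:] + [indices[0]]))


print(ring_pairs(4))  # [(0, 1), (1, 2), (2, 3), (3, 0)]
```

The last pair, (3, 0), is the wrap-around link that makes the state translationally invariant on the ring.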

Display a cluster state circuit for a rectangle of cirq.GridQubits:

SVGCircuit(cluster_state_circuit(cirq.GridQubit.rect(1, 4)))

svg

1.5.2 QCNN layers

Define the layers that make up the model, following the Cong and Lukin QCNN paper. There are a few prerequisites:

  • The one- and two-qubit parameterized unitary matrices from the Tucci paper.
  • A general parameterized two-qubit pooling operation.

def one_qubit_unitary(bit, symbols):
    """Make a Cirq circuit enacting a rotation of the Bloch sphere about the X,
    Y and Z axes, that depends on the values in `symbols`.
    """
    return cirq.Circuit(
        cirq.X(bit)**symbols[0],
        cirq.Y(bit)**symbols[1],
        cirq.Z(bit)**symbols[2])


def two_qubit_unitary(bits, symbols):
    """Make a Cirq circuit that creates an arbitrary two qubit unitary."""
    circuit = cirq.Circuit()
    circuit += one_qubit_unitary(bits[0], symbols[0:3])
    circuit += one_qubit_unitary(bits[1], symbols[3:6])
    circuit += [cirq.ZZ(*bits)**symbols[6]]
    circuit += [cirq.YY(*bits)**symbols[7]]
    circuit += [cirq.XX(*bits)**symbols[8]]
    circuit += one_qubit_unitary(bits[0], symbols[9:12])
    circuit += one_qubit_unitary(bits[1], symbols[12:])
    return circuit


def two_qubit_pool(source_qubit, sink_qubit, symbols):
    """Make a Cirq circuit to do a parameterized 'pooling' operation, which
    attempts to reduce entanglement down from two qubits to just one."""
    pool_circuit = cirq.Circuit()
    sink_basis_selector = one_qubit_unitary(sink_qubit, symbols[0:3])
    source_basis_selector = one_qubit_unitary(source_qubit, symbols[3:6])
    pool_circuit.append(sink_basis_selector)
    pool_circuit.append(source_basis_selector)
    pool_circuit.append(cirq.CNOT(source_qubit, sink_qubit))
    pool_circuit.append(sink_basis_selector**-1)
    return pool_circuit

To see what you created, print out the one-qubit unitary circuit:

SVGCircuit(one_qubit_unitary(cirq.GridQubit(0, 0), sympy.symbols('x0:3')))

svg

And the two-qubit unitary circuit:

SVGCircuit(two_qubit_unitary(cirq.GridQubit.rect(1, 2), sympy.symbols('x0:15')))

svg

And the two-qubit pooling circuit:

SVGCircuit(two_qubit_pool(*cirq.GridQubit.rect(1, 2), sympy.symbols('x0:6')))

svg
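The building blocks above raise Pauli gates to symbolic powers, for example `cirq.X(bit)**symbols[0]`. As a rough illustration of what such a power means, the sketch below writes out the 2x2 matrix of X raised to a real power t in plain Python, using the standard formula for a Pauli X power with no extra global phase shift. The helper names are hypothetical; only the formula is taken as given.

```python
import cmath
import math


def x_pow_matrix(t):
    """2x2 matrix of Pauli X raised to the real power t."""
    g = cmath.exp(1j * math.pi * t / 2)   # overall phase factor
    c = math.cos(math.pi * t / 2)
    s = math.sin(math.pi * t / 2)
    return [[g * c, -1j * g * s],
            [-1j * g * s, g * c]]


def matmul2(a, b):
    """Product of two 2x2 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def close(a, b, tol=1e-12):
    """Element-wise comparison of two 2x2 matrices."""
    return all(abs(a[i][j] - b[i][j]) < tol
               for i in range(2) for j in range(2))


identity = [[1, 0], [0, 1]]
pauli_x = [[0, 1], [1, 0]]
half = x_pow_matrix(0.5)

print(close(x_pow_matrix(0), identity))      # True
print(close(x_pow_matrix(1), pauli_x))       # True
print(close(matmul2(half, half), pauli_x))   # True
```

The exponent therefore interpolates continuously between the identity (t = 0) and a full X gate (t = 1), which is what lets a trainable symbol control the size of the rotation.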

1.5.2.1 Quantum convolution

As in the Cong and Lukin paper, define the 1D quantum convolution as the application of a two-qubit parameterized unitary to every pair of adjacent qubits with a stride of one.

def quantum_conv_circuit(bits, symbols):
    """Quantum Convolution Layer following the above diagram.
    Return a Cirq circuit with the cascade of `two_qubit_unitary` applied
    to all pairs of qubits in `bits` as in the diagram above.
    """
    circuit = cirq.Circuit()
    for first, second in zip(bits[0::2], bits[1::2]):
        circuit += two_qubit_unitary([first, second], symbols)
    for first, second in zip(bits[1::2], bits[2::2] + [bits[0]]):
        circuit += two_qubit_unitary([first, second], symbols)
    return circuit
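The two loops in quantum_conv_circuit cover the ring in two passes. A small sketch with integer indices in place of qubits (the helper name `conv_pairs` is hypothetical) lists the pairs each pass visits:

```python
def conv_pairs(n):
    """Index pairs visited by the two passes of quantum_conv_circuit."""
    bits = list(range(n))
    # First pass: pairs starting at even positions.
    first_pass = list(zip(bits[0::2], bits[1::2]))
    # Second pass: pairs starting at odd positions, wrapping around the ring.
    second_pass = list(zip(bits[1::2], bits[2::2] + [bits[0]]))
    return first_pass, second_pass


first, second = conv_pairs(8)
print(first)   # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(second)  # [(1, 2), (3, 4), (5, 6), (7, 0)]
```

Every pair receives the same 15 symbols, so the weights are shared across positions. With only two qubits, `conv_pairs(2)` yields `[(0, 1)]` and `[(1, 0)]`: the second pass revisits the same pair in reversed order.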

Display the (very horizontal) circuit:

SVGCircuit(
    quantum_conv_circuit(cirq.GridQubit.rect(1, 8), sympy.symbols('x0:15')))

svg

1.5.2.2 Quantum pooling

A quantum pooling layer pools from \(N\) qubits to \(\frac{N}{2}\) qubits using the two-qubit pool defined above.

def quantum_pool_circuit(source_bits, sink_bits, symbols):
    """A layer that specifies a quantum pooling operation.
    A quantum pool tries to learn to pool the relevant information from two
    qubits onto one.
    """
    circuit = cirq.Circuit()
    for source, sink in zip(source_bits, sink_bits):
        circuit += two_qubit_pool(source, sink, symbols)
    return circuit

Examine a pooling component circuit:

test_bits = cirq.GridQubit.rect(1, 8)

SVGCircuit(
    quantum_pool_circuit(test_bits[:4], test_bits[4:], sympy.symbols('x0:6')))

svg

1.6 Model definition

Now use the defined layers to construct a purely quantum CNN. Start with eight qubits, pool down to one, then measure \(\langle \hat{Z} \rangle\).

def create_model_circuit(qubits):
    """Create sequence of alternating convolution and pooling operators 
    which gradually shrink over time."""
    model_circuit = cirq.Circuit()
    symbols = sympy.symbols('qconv0:63')
    # Cirq uses sympy.Symbols to map learnable variables. TensorFlow Quantum
    # scans incoming circuits and replaces these with TensorFlow variables.
    model_circuit += quantum_conv_circuit(qubits, symbols[0:15])
    model_circuit += quantum_pool_circuit(qubits[:4], qubits[4:],
                                          symbols[15:21])
    model_circuit += quantum_conv_circuit(qubits[4:], symbols[21:36])
    model_circuit += quantum_pool_circuit(qubits[4:6], qubits[6:],
                                          symbols[36:42])
    model_circuit += quantum_conv_circuit(qubits[6:], symbols[42:57])
    model_circuit += quantum_pool_circuit([qubits[6]], [qubits[7]],
                                          symbols[57:63])
    return model_circuit


# Create our qubits and readout operators in Cirq.
cluster_state_bits = cirq.GridQubit.rect(1, 8)
readout_operators = cirq.Z(cluster_state_bits[-1])

# Build a sequential model enacting the logic in 1.3 of this notebook.
# The static cluster state preparation is prepended by the AddCircuit layer,
# and the "quantum datapoints" arrive as excitation circuits.
excitation_input = tf.keras.Input(shape=(), dtype=tf.dtypes.string)
cluster_state = tfq.layers.AddCircuit()(
    excitation_input, prepend=cluster_state_circuit(cluster_state_bits))

quantum_model = tfq.layers.PQC(create_model_circuit(cluster_state_bits),
                               readout_operators)(cluster_state)

qcnn_model = tf.keras.Model(inputs=[excitation_input], outputs=[quantum_model])

# Show the keras plot of the model
tf.keras.utils.plot_model(qcnn_model,
                          show_shapes=True,
                          show_layer_names=False,
                          dpi=70)

png
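The symbol slices and qubit slices in create_model_circuit follow a regular pattern that is easy to verify by hand. The sketch below recomputes both in plain Python, taking the per-block symbol counts from the functions defined earlier (15 for a convolution, 6 for a pool). The helper names are hypothetical.

```python
# Symbol counts of the building blocks defined earlier in this tutorial.
CONV_SYMBOLS = 15
POOL_SYMBOLS = 6


def symbol_layout(n_levels):
    """Return (name, start, stop) slices for alternating conv/pool blocks."""
    layout, start = [], 0
    for level in range(n_levels):
        for name, width in (("conv", CONV_SYMBOLS), ("pool", POOL_SYMBOLS)):
            layout.append((f"{name}{level}", start, start + width))
            start += width
    return layout


def pooling_schedule(n_qubits):
    """Halve the active qubits per level: first half sources, second half sinks."""
    active = list(range(n_qubits))
    schedule = []
    while len(active) > 1:
        half = len(active) // 2
        sources, sinks = active[:half], active[half:]
        schedule.append((sources, sinks))
        active = sinks  # only the sink qubits carry on to the next level
    return schedule, active


layout = symbol_layout(3)
schedule, remaining = pooling_schedule(8)

for name, start, stop in layout:
    print(f"{name}: symbols[{start}:{stop}]")
print("total symbols:", layout[-1][2])
print("pooling schedule:", schedule)
print("readout qubit index:", remaining[0])
```

The total of 63 matches `sympy.symbols('qconv0:63')`, and the single remaining index, 7, matches the readout on `cluster_state_bits[-1]`.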

1.7 Train the model

Train the model in mini-batches of 16 for 25 epochs, validating on the held-out test set after each epoch.

# Generate some training data.
train_excitations, train_labels, test_excitations, test_labels = generate_data(
    cluster_state_bits)


# Custom accuracy metric.
@tf.function
def custom_accuracy(y_true, y_pred):
    y_true = tf.squeeze(y_true)
    y_pred = tf.map_fn(lambda x: 1.0 if x >= 0 else -1.0, y_pred)
    return tf.keras.backend.mean(tf.keras.backend.equal(y_true, y_pred))


qcnn_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.02),
                   loss=tf.losses.mse,
                   metrics=[custom_accuracy])

history = qcnn_model.fit(x=train_excitations,
                         y=train_labels,
                         batch_size=16,
                         epochs=25,
                         verbose=1,
                         validation_data=(test_excitations, test_labels))
Epoch 1/25
7/7 [==============================] - 1s 133ms/step - loss: 0.8493 - custom_accuracy: 0.7500 - val_loss: 0.7326 - val_custom_accuracy: 0.7708
Epoch 2/25
7/7 [==============================] - 1s 100ms/step - loss: 0.8037 - custom_accuracy: 0.7589 - val_loss: 0.6988 - val_custom_accuracy: 0.8333
Epoch 3/25
7/7 [==============================] - 1s 96ms/step - loss: 0.7600 - custom_accuracy: 0.7500 - val_loss: 0.6644 - val_custom_accuracy: 0.7708
Epoch 4/25
7/7 [==============================] - 1s 95ms/step - loss: 0.7503 - custom_accuracy: 0.7500 - val_loss: 0.6600 - val_custom_accuracy: 0.7708
Epoch 5/25
7/7 [==============================] - 1s 94ms/step - loss: 0.7353 - custom_accuracy: 0.7679 - val_loss: 0.6509 - val_custom_accuracy: 0.8333
Epoch 6/25
7/7 [==============================] - 1s 91ms/step - loss: 0.7236 - custom_accuracy: 0.7857 - val_loss: 0.6283 - val_custom_accuracy: 0.8125
Epoch 7/25
7/7 [==============================] - 1s 93ms/step - loss: 0.7237 - custom_accuracy: 0.8125 - val_loss: 0.6414 - val_custom_accuracy: 0.8542
Epoch 8/25
7/7 [==============================] - 1s 93ms/step - loss: 0.7193 - custom_accuracy: 0.7589 - val_loss: 0.6266 - val_custom_accuracy: 0.8542
Epoch 9/25
7/7 [==============================] - 1s 93ms/step - loss: 0.7137 - custom_accuracy: 0.8125 - val_loss: 0.6160 - val_custom_accuracy: 0.8333
Epoch 10/25
7/7 [==============================] - 1s 92ms/step - loss: 0.7157 - custom_accuracy: 0.7946 - val_loss: 0.6174 - val_custom_accuracy: 0.8542
Epoch 11/25
7/7 [==============================] - 1s 91ms/step - loss: 0.7113 - custom_accuracy: 0.8125 - val_loss: 0.6256 - val_custom_accuracy: 0.8542
Epoch 12/25
7/7 [==============================] - 1s 93ms/step - loss: 0.7069 - custom_accuracy: 0.8304 - val_loss: 0.5883 - val_custom_accuracy: 0.8958
Epoch 13/25
7/7 [==============================] - 1s 93ms/step - loss: 0.6894 - custom_accuracy: 0.8214 - val_loss: 0.5804 - val_custom_accuracy: 0.8958
Epoch 14/25
7/7 [==============================] - 1s 95ms/step - loss: 0.6827 - custom_accuracy: 0.8125 - val_loss: 0.5962 - val_custom_accuracy: 0.8958
Epoch 15/25
7/7 [==============================] - 1s 92ms/step - loss: 0.6717 - custom_accuracy: 0.8125 - val_loss: 0.5910 - val_custom_accuracy: 0.8958
Epoch 16/25
7/7 [==============================] - 1s 91ms/step - loss: 0.6631 - custom_accuracy: 0.8482 - val_loss: 0.5789 - val_custom_accuracy: 0.8958
Epoch 17/25
7/7 [==============================] - 1s 90ms/step - loss: 0.6581 - custom_accuracy: 0.8125 - val_loss: 0.5666 - val_custom_accuracy: 0.8958
Epoch 18/25
7/7 [==============================] - 1s 91ms/step - loss: 0.6500 - custom_accuracy: 0.8304 - val_loss: 0.5543 - val_custom_accuracy: 0.9375
Epoch 19/25
7/7 [==============================] - 1s 93ms/step - loss: 0.6251 - custom_accuracy: 0.8304 - val_loss: 0.5248 - val_custom_accuracy: 0.9583
Epoch 20/25
7/7 [==============================] - 1s 91ms/step - loss: 0.5941 - custom_accuracy: 0.8750 - val_loss: 0.4657 - val_custom_accuracy: 0.9583
Epoch 21/25
7/7 [==============================] - 1s 91ms/step - loss: 0.5459 - custom_accuracy: 0.8839 - val_loss: 0.3872 - val_custom_accuracy: 0.9375
Epoch 22/25
7/7 [==============================] - 1s 90ms/step - loss: 0.5098 - custom_accuracy: 0.8929 - val_loss: 0.3442 - val_custom_accuracy: 0.9583
Epoch 23/25
7/7 [==============================] - 1s 91ms/step - loss: 0.4778 - custom_accuracy: 0.8929 - val_loss: 0.3443 - val_custom_accuracy: 0.9375
Epoch 24/25
7/7 [==============================] - 1s 90ms/step - loss: 0.4860 - custom_accuracy: 0.8929 - val_loss: 0.3315 - val_custom_accuracy: 0.9583
Epoch 25/25
7/7 [==============================] - 1s 97ms/step - loss: 0.4579 - custom_accuracy: 0.8929 - val_loss: 0.3236 - val_custom_accuracy: 0.9583
plt.plot(history.history['loss'][1:], label='Training')
plt.plot(history.history['val_loss'][1:], label='Validation')
plt.title('Training a Quantum CNN to Detect Excited Cluster States')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

png
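Two details of the training log can be reproduced with simple arithmetic: the 7 steps per epoch and the granularity of the accuracy values. The sketch below is plain Python; `sign_accuracy` is a hypothetical stand-in for the thresholding idea behind custom_accuracy, not the TensorFlow implementation itself.

```python
import math


def sign_accuracy(y_true, y_pred):
    """Fraction of predictions whose sign (>= 0 maps to +1) matches the label."""
    hits = sum(1 for t, p in zip(y_true, y_pred)
               if (1 if p >= 0 else -1) == t)
    return hits / len(y_true)


# Dataset bookkeeping, mirroring generate_data on 8 qubits.
n_points = 20 * 8                       # n_rounds * number of qubits
n_train = int(n_points * 0.7)           # same split rule as generate_data
n_test = n_points - n_train
steps_per_epoch = math.ceil(n_train / 16)   # batch_size=16

print(n_train, n_test, steps_per_epoch)
print(sign_accuracy([1, -1, 1, -1], [0.3, -0.2, -0.1, -0.9]))
```

With 48 validation points, the validation accuracy can only move in steps of 1/48. The final value in the log, 0.9583, corresponds to 46 of 48 correct.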

2. Hybrid models

You don't have to go from eight qubits to one qubit using quantum convolution—you could have done one or two rounds of quantum convolution and fed the results into a classical neural network. This section explores quantum-classical hybrid models.

2.1 Hybrid model with a single quantum filter

Apply one layer of quantum convolution and one pooling step, then read out \(\langle \hat{Z}_n \rangle\) on the four remaining qubits and feed the results into a densely connected neural network.

2.1.1 Model definition

# 1-local operators to read out
readouts = [cirq.Z(bit) for bit in cluster_state_bits[4:]]


def multi_readout_model_circuit(qubits):
    """Make a model circuit with fewer quantum pool and conv operations."""
    model_circuit = cirq.Circuit()
    symbols = sympy.symbols('qconv0:21')
    model_circuit += quantum_conv_circuit(qubits, symbols[0:15])
    model_circuit += quantum_pool_circuit(qubits[:4], qubits[4:],
                                          symbols[15:21])
    return model_circuit


# Build a model enacting the logic in 2.1 of this notebook.
excitation_input_dual = tf.keras.Input(shape=(), dtype=tf.dtypes.string)

cluster_state_dual = tfq.layers.AddCircuit()(
    excitation_input_dual, prepend=cluster_state_circuit(cluster_state_bits))

quantum_model_dual = tfq.layers.PQC(
    multi_readout_model_circuit(cluster_state_bits),
    readouts)(cluster_state_dual)

d1_dual = tf.keras.layers.Dense(8)(quantum_model_dual)

d2_dual = tf.keras.layers.Dense(1)(d1_dual)

hybrid_model = tf.keras.Model(inputs=[excitation_input_dual], outputs=[d2_dual])

# Display the model architecture
tf.keras.utils.plot_model(hybrid_model,
                          show_shapes=True,
                          show_layer_names=False,
                          dpi=70)

png
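A quick parameter count shows how the trainable weights of this hybrid model split between the quantum and classical parts. The sketch below is plain arithmetic under two assumptions: the quantum layer holds one trainable value per symbol, and each dense layer has a bias term (the Keras default). The helper name `dense_params` is hypothetical.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a fully connected layer with bias."""
    return n_in * n_out + n_out


n_readouts = 4        # one Z readout per qubit in cluster_state_bits[4:]
quantum_params = 21   # symbols 'qconv0:21' in multi_readout_model_circuit

d1 = dense_params(n_readouts, 8)   # Dense(8) on the four expectation values
d2 = dense_params(8, 1)            # Dense(1) producing the final prediction
total = quantum_params + d1 + d2

print(quantum_params, d1, d2, total)  # 21 40 9 70
```

Under these assumptions the classical head holds 49 of the 70 trainable values, compared with 63 quantum parameters and no classical ones in the purely quantum model.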

2.1.2 Train the model

hybrid_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.02),
                     loss=tf.losses.mse,
                     metrics=[custom_accuracy])

hybrid_history = hybrid_model.fit(x=train_excitations,
                                  y=train_labels,
                                  batch_size=16,
                                  epochs=25,
                                  verbose=1,
                                  validation_data=(test_excitations,
                                                   test_labels))
Epoch 1/25
7/7 [==============================] - 1s 92ms/step - loss: 1.0159 - custom_accuracy: 0.5179 - val_loss: 1.0093 - val_custom_accuracy: 0.3958
Epoch 2/25
7/7 [==============================] - 0s 64ms/step - loss: 0.9899 - custom_accuracy: 0.6161 - val_loss: 0.9553 - val_custom_accuracy: 0.8125
Epoch 3/25
7/7 [==============================] - 0s 61ms/step - loss: 0.9317 - custom_accuracy: 0.7411 - val_loss: 0.8141 - val_custom_accuracy: 0.8542
Epoch 4/25
7/7 [==============================] - 0s 61ms/step - loss: 0.7544 - custom_accuracy: 0.8125 - val_loss: 0.5967 - val_custom_accuracy: 0.8750
Epoch 5/25
7/7 [==============================] - 0s 61ms/step - loss: 0.6727 - custom_accuracy: 0.8304 - val_loss: 0.7394 - val_custom_accuracy: 0.7708
Epoch 6/25
7/7 [==============================] - 0s 60ms/step - loss: 0.5776 - custom_accuracy: 0.8839 - val_loss: 0.4036 - val_custom_accuracy: 0.9167
Epoch 7/25
7/7 [==============================] - 0s 59ms/step - loss: 0.4326 - custom_accuracy: 0.8571 - val_loss: 0.3918 - val_custom_accuracy: 0.8958
Epoch 8/25
7/7 [==============================] - 0s 59ms/step - loss: 0.2957 - custom_accuracy: 0.9643 - val_loss: 0.1578 - val_custom_accuracy: 1.0000
Epoch 9/25
7/7 [==============================] - 0s 59ms/step - loss: 0.2192 - custom_accuracy: 0.9554 - val_loss: 0.1255 - val_custom_accuracy: 1.0000
Epoch 10/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1944 - custom_accuracy: 0.9732 - val_loss: 0.1515 - val_custom_accuracy: 1.0000
Epoch 11/25
7/7 [==============================] - 0s 60ms/step - loss: 0.2123 - custom_accuracy: 0.9554 - val_loss: 0.1618 - val_custom_accuracy: 1.0000
Epoch 12/25
7/7 [==============================] - 0s 59ms/step - loss: 0.2083 - custom_accuracy: 0.9554 - val_loss: 0.1303 - val_custom_accuracy: 1.0000
Epoch 13/25
7/7 [==============================] - 0s 60ms/step - loss: 0.2612 - custom_accuracy: 0.9464 - val_loss: 0.1475 - val_custom_accuracy: 1.0000
Epoch 14/25
7/7 [==============================] - 0s 60ms/step - loss: 0.2116 - custom_accuracy: 0.9643 - val_loss: 0.1490 - val_custom_accuracy: 1.0000
Epoch 15/25
7/7 [==============================] - 0s 58ms/step - loss: 0.2025 - custom_accuracy: 0.9643 - val_loss: 0.1660 - val_custom_accuracy: 1.0000
Epoch 16/25
7/7 [==============================] - 0s 58ms/step - loss: 0.2039 - custom_accuracy: 0.9643 - val_loss: 0.1171 - val_custom_accuracy: 1.0000
Epoch 17/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1836 - custom_accuracy: 0.9732 - val_loss: 0.1212 - val_custom_accuracy: 1.0000
Epoch 18/25
7/7 [==============================] - 0s 58ms/step - loss: 0.1774 - custom_accuracy: 0.9732 - val_loss: 0.1192 - val_custom_accuracy: 1.0000
Epoch 19/25
7/7 [==============================] - 0s 58ms/step - loss: 0.1733 - custom_accuracy: 0.9732 - val_loss: 0.1291 - val_custom_accuracy: 1.0000
Epoch 20/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1767 - custom_accuracy: 0.9732 - val_loss: 0.1279 - val_custom_accuracy: 1.0000
Epoch 21/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1748 - custom_accuracy: 0.9732 - val_loss: 0.1294 - val_custom_accuracy: 1.0000
Epoch 22/25
7/7 [==============================] - 0s 58ms/step - loss: 0.1870 - custom_accuracy: 0.9732 - val_loss: 0.1395 - val_custom_accuracy: 1.0000
Epoch 23/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1950 - custom_accuracy: 0.9732 - val_loss: 0.1501 - val_custom_accuracy: 1.0000
Epoch 24/25
7/7 [==============================] - 0s 59ms/step - loss: 0.1866 - custom_accuracy: 0.9643 - val_loss: 0.1474 - val_custom_accuracy: 1.0000
Epoch 25/25
7/7 [==============================] - 0s 58ms/step - loss: 0.1988 - custom_accuracy: 0.9554 - val_loss: 0.1452 - val_custom_accuracy: 1.0000
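The first epoch above logs a loss close to 1.0 and an accuracy close to 0.5. A minimal sketch of why, assuming the labels are +1 and -1 and that `custom_accuracy` (defined earlier in the notebook, not shown in this section) thresholds predictions at zero. The functions `mse` and `sign_accuracy` are plain-Python stand-ins rather than the TensorFlow implementations, and the prediction values are made up for illustration.

```python
def mse(y_true, y_pred):
    """Mean squared error over a batch of scalar predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def sign_accuracy(y_true, y_pred):
    """Fraction of predictions whose sign (>= 0 maps to +1) matches the label."""
    hits = sum(1 for t, p in zip(y_true, y_pred)
               if (1.0 if p >= 0 else -1.0) == t)
    return hits / len(y_true)


labels = [1.0, -1.0, 1.0, -1.0]          # assumed +/-1 labels
untrained = [0.0, 0.0, 0.0, 0.0]         # outputs near zero before training
trained = [0.6, -0.7, 0.5, -0.4]         # made-up outputs after training

print(mse(labels, untrained), sign_accuracy(labels, untrained))
print(mse(labels, trained), sign_accuracy(labels, trained))
```

With outputs near zero, every squared error is about 1 and half of the balanced labels match by chance, which is consistent with the first logged epoch.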
plt.plot(history.history['val_custom_accuracy'], label='QCNN')
plt.plot(hybrid_history.history['val_custom_accuracy'], label='Hybrid CNN')
plt.title('Quantum vs Hybrid CNN performance')
plt.xlabel('Epochs')
plt.legend()
plt.ylabel('Validation Accuracy')
plt.show()

png

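As the plot shows, a small amount of classical post-processing is usually enough for the hybrid model to converge faster than the purely quantum version. In the run logged above, validation accuracy reaches 1.0 at epoch 8 and stays there for the remaining epochs.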
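One way to put a number on "converges faster" is to find the first epoch from which validation accuracy stays at or above a threshold. The sketch below uses the `val_custom_accuracy` values copied from the hybrid training log above; the helper names are hypothetical and not part of Keras.

```python
def first_epoch_at_or_above(values, threshold):
    """1-based epoch of the first value that reaches the threshold, or None."""
    for epoch, value in enumerate(values, start=1):
        if value >= threshold:
            return epoch
    return None


def first_stable_epoch(values, threshold):
    """1-based epoch from which every later value stays at or above the threshold."""
    stable = None
    for epoch, value in enumerate(values, start=1):
        if value >= threshold:
            if stable is None:
                stable = epoch
        else:
            stable = None  # dipped below the threshold, start over
    return stable


# Validation accuracy per epoch, copied from the hybrid model log above.
hybrid_val_acc = [0.3958, 0.8125, 0.8542, 0.8750, 0.7708, 0.9167, 0.8958] + [1.0] * 18

print(first_epoch_at_or_above(hybrid_val_acc, 0.9))  # first touch of 0.9
print(first_stable_epoch(hybrid_val_acc, 0.9))       # stays above 0.9 from here on
```

In the notebook, the same helpers could be applied to `history.history['val_custom_accuracy']` and `hybrid_history.history['val_custom_accuracy']` to compare the two models on equal terms. Results vary between runs because the initial parameters are random.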

2.2 Hybrid convolution with multiple quantum filters

Next, try an architecture that applies several quantum convolutions to the same input state and uses a classical neural network to combine their outputs.

2.2.1 Model definition

excitation_input_multi = tf.keras.Input(shape=(), dtype=tf.dtypes.string)

cluster_state_multi = tfq.layers.AddCircuit()(
    excitation_input_multi, prepend=cluster_state_circuit(cluster_state_bits))

# Apply 3 quantum filters (each PQC layer trains its own parameters) and
# measure expectation values.

quantum_model_multi1 = tfq.layers.PQC(
    multi_readout_model_circuit(cluster_state_bits),
    readouts)(cluster_state_multi)

quantum_model_multi2 = tfq.layers.PQC(
    multi_readout_model_circuit(cluster_state_bits),
    readouts)(cluster_state_multi)

quantum_model_multi3 = tfq.layers.PQC(
    multi_readout_model_circuit(cluster_state_bits),
    readouts)(cluster_state_multi)

# concatenate outputs and feed into a small classical NN
concat_out = tf.keras.layers.concatenate(
    [quantum_model_multi1, quantum_model_multi2, quantum_model_multi3])

dense_1 = tf.keras.layers.Dense(8)(concat_out)

dense_2 = tf.keras.layers.Dense(1)(dense_1)

multi_qconv_model = tf.keras.Model(inputs=[excitation_input_multi],
                                   outputs=[dense_2])

# Display the model architecture
tf.keras.utils.plot_model(multi_qconv_model,
                          show_shapes=True,
                          show_layer_names=True,
                          dpi=70)

png
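The concatenation step joins the readout values of the three filters into one feature vector per example before the dense layers see them. The following is a plain-Python sketch of that shape change, not the Keras layer itself. It assumes each filter returns 4 expectation values per example (the size of `readouts` is defined outside this section), and the numbers are made up.

```python
def concat_last_axis(*batches):
    """Join per-example feature lists from several filters along the last axis."""
    return [[value for row in rows for value in row] for rows in zip(*batches)]


# Made-up expectation values in [-1, 1]: 2 examples, 4 readouts per filter.
filter_1 = [[0.9, -0.2, 0.1, 0.4], [-0.8, 0.3, 0.0, -0.5]]
filter_2 = [[0.7, 0.1, -0.3, 0.2], [-0.6, 0.2, 0.4, -0.1]]
filter_3 = [[0.5, -0.4, 0.6, 0.0], [-0.9, 0.1, -0.2, 0.3]]

features = concat_last_axis(filter_1, filter_2, filter_3)
print(len(features), len(features[0]))  # 2 examples, 12 features each
```

Under that assumption, `Dense(8)` receives 12 inputs per example instead of 4, so the classical head grows only slightly while the quantum part triples.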

2.2.2 Train the model

multi_qconv_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.02),
    loss=tf.losses.mse,
    metrics=[custom_accuracy])

multi_qconv_history = multi_qconv_model.fit(x=train_excitations,
                                            y=train_labels,
                                            batch_size=16,
                                            epochs=25,
                                            verbose=1,
                                            validation_data=(test_excitations,
                                                             test_labels))
Epoch 1/25
7/7 [==============================] - 2s 112ms/step - loss: 1.0004 - custom_accuracy: 0.5982 - val_loss: 0.8883 - val_custom_accuracy: 0.9375
Epoch 2/25
7/7 [==============================] - 1s 76ms/step - loss: 0.7796 - custom_accuracy: 0.7857 - val_loss: 0.4550 - val_custom_accuracy: 0.8750
Epoch 3/25
7/7 [==============================] - 0s 73ms/step - loss: 0.3783 - custom_accuracy: 0.9018 - val_loss: 0.1615 - val_custom_accuracy: 1.0000
Epoch 4/25
7/7 [==============================] - 0s 69ms/step - loss: 0.2220 - custom_accuracy: 0.9554 - val_loss: 0.1605 - val_custom_accuracy: 1.0000
Epoch 5/25
7/7 [==============================] - 1s 76ms/step - loss: 0.1872 - custom_accuracy: 0.9732 - val_loss: 0.1499 - val_custom_accuracy: 1.0000
Epoch 6/25
7/7 [==============================] - 0s 72ms/step - loss: 0.2077 - custom_accuracy: 0.9732 - val_loss: 0.1606 - val_custom_accuracy: 0.9792
Epoch 7/25
7/7 [==============================] - 0s 70ms/step - loss: 0.1833 - custom_accuracy: 0.9732 - val_loss: 0.1602 - val_custom_accuracy: 1.0000
Epoch 8/25
7/7 [==============================] - 1s 79ms/step - loss: 0.1959 - custom_accuracy: 0.9732 - val_loss: 0.1711 - val_custom_accuracy: 0.9583
Epoch 9/25
7/7 [==============================] - 0s 73ms/step - loss: 0.1924 - custom_accuracy: 0.9643 - val_loss: 0.1611 - val_custom_accuracy: 1.0000
Epoch 10/25
7/7 [==============================] - 0s 71ms/step - loss: 0.2019 - custom_accuracy: 0.9375 - val_loss: 0.1662 - val_custom_accuracy: 1.0000
Epoch 11/25
7/7 [==============================] - 0s 68ms/step - loss: 0.1910 - custom_accuracy: 0.9643 - val_loss: 0.1484 - val_custom_accuracy: 1.0000
Epoch 12/25
7/7 [==============================] - 0s 71ms/step - loss: 0.1913 - custom_accuracy: 0.9643 - val_loss: 0.1518 - val_custom_accuracy: 1.0000
Epoch 13/25
7/7 [==============================] - 1s 74ms/step - loss: 0.1864 - custom_accuracy: 0.9643 - val_loss: 0.1262 - val_custom_accuracy: 1.0000
Epoch 14/25
7/7 [==============================] - 1s 74ms/step - loss: 0.1854 - custom_accuracy: 0.9643 - val_loss: 0.1393 - val_custom_accuracy: 1.0000
Epoch 15/25
7/7 [==============================] - 0s 73ms/step - loss: 0.1720 - custom_accuracy: 0.9643 - val_loss: 0.1230 - val_custom_accuracy: 1.0000
Epoch 16/25
7/7 [==============================] - 1s 82ms/step - loss: 0.1712 - custom_accuracy: 0.9643 - val_loss: 0.1307 - val_custom_accuracy: 1.0000
Epoch 17/25
7/7 [==============================] - 0s 70ms/step - loss: 0.1851 - custom_accuracy: 0.9554 - val_loss: 0.1556 - val_custom_accuracy: 1.0000
Epoch 18/25
7/7 [==============================] - 0s 74ms/step - loss: 0.1777 - custom_accuracy: 0.9732 - val_loss: 0.1461 - val_custom_accuracy: 1.0000
Epoch 19/25
7/7 [==============================] - 0s 74ms/step - loss: 0.1747 - custom_accuracy: 0.9554 - val_loss: 0.1293 - val_custom_accuracy: 1.0000
Epoch 20/25
7/7 [==============================] - 0s 72ms/step - loss: 0.1702 - custom_accuracy: 0.9732 - val_loss: 0.1504 - val_custom_accuracy: 1.0000
Epoch 21/25
7/7 [==============================] - 0s 70ms/step - loss: 0.1770 - custom_accuracy: 0.9732 - val_loss: 0.1516 - val_custom_accuracy: 1.0000
Epoch 22/25
7/7 [==============================] - 1s 81ms/step - loss: 0.1871 - custom_accuracy: 0.9732 - val_loss: 0.1490 - val_custom_accuracy: 1.0000
Epoch 23/25
7/7 [==============================] - 0s 70ms/step - loss: 0.1951 - custom_accuracy: 0.9732 - val_loss: 0.1236 - val_custom_accuracy: 1.0000
Epoch 24/25
7/7 [==============================] - 0s 68ms/step - loss: 0.1788 - custom_accuracy: 0.9554 - val_loss: 0.1805 - val_custom_accuracy: 0.9792
Epoch 25/25
7/7 [==============================] - 1s 73ms/step - loss: 0.1938 - custom_accuracy: 0.9643 - val_loss: 0.1533 - val_custom_accuracy: 1.0000
plt.plot(history.history['val_custom_accuracy'][:25], label='QCNN')
plt.plot(hybrid_history.history['val_custom_accuracy'][:25], label='Hybrid CNN')
plt.plot(multi_qconv_history.history['val_custom_accuracy'][:25],
         label='Hybrid CNN \n Multiple Quantum Filters')
plt.title('Quantum vs Hybrid CNN performance')
plt.xlabel('Epochs')
plt.legend()
plt.ylabel('Validation Accuracy')
plt.show()

png