unitary_module

Unitary matrix module.

This module provides functions for generating and working with unitary matrices. A unitary matrix is a complex square matrix whose conjugate transpose is equal to its inverse. For real matrices, this reduces to an orthogonal matrix where the transpose equals the inverse.
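As a quick illustration of the definition above, the following sketch checks the defining property numerically. The helper `is_unitary` and the example matrices are hypothetical and are not part of this module. The tolerance is an assumption chosen for JAX's default single precision.

```python
import math

import jax.numpy as jnp


def is_unitary(m, atol=1e-5):
    """Return True if m is square and m^dagger m is close to the identity."""
    m = jnp.asarray(m)
    if m.ndim != 2 or m.shape[0] != m.shape[1]:
        return False
    identity = jnp.eye(m.shape[0], dtype=m.dtype)
    return bool(jnp.allclose(m.conj().T @ m, identity, atol=atol))


# A complex unitary matrix: (1/sqrt(2)) * [[1, i], [i, 1]].
u = jnp.array([[1.0, 1.0j], [1.0j, 1.0]]) / math.sqrt(2.0)

# A real rotation matrix is orthogonal, hence also unitary.
theta = 0.3
r = jnp.array([[math.cos(theta), -math.sin(theta)],
               [math.sin(theta), math.cos(theta)]])

# A shear matrix is invertible but not unitary.
s = jnp.array([[1.0, 1.0], [0.0, 1.0]])

print(is_unitary(u), is_unitary(r), is_unitary(s))
```

For the real matrix `r`, the conjugate has no effect, so the same check reduces to the orthogonality condition on the transpose.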

jrystal._src.unitary_module.unitary_matrix(params: Dict[str, Array], complex: bool = False) → Array

Generate a unitary (or orthogonal) matrix from its parameters.

This function constructs a unitary matrix (or orthogonal matrix in the real case) using the QR decomposition method. The resulting matrix has orthonormal columns. In the real case its determinant is ±1; in the complex case its determinant is a complex number of unit modulus.

Example:

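import jax
import jax.numpy as jnp
from jrystal._src.unitary_module import unitary_matrix, unitary_matrix_param_init

key = jax.random.PRNGKey(0)

# Generate a 3x3 complex unitary matrix
params = unitary_matrix_param_init(key, (3, 3), complex=True)
U = unitary_matrix(params, complex=True)

# Verify unitarity: U†U should be close to identity
# (atol is loosened to allow for single-precision round-off)
jnp.allclose(U.conj().T @ U, jnp.eye(3), atol=1e-5)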
Parameters:
  • params (Dict[str, Array]) –

    Dictionary containing the matrix parameters with keys:

    • 'w_re': Real part of the weight matrix

    • 'w_im': Imaginary part of the weight matrix (used if complex=True)

    params can be generated by unitary_matrix_param_init().

  • complex (bool) – If True, generates a complex unitary matrix. If False, generates a real orthogonal matrix. Defaults to False.

Returns:

A unitary matrix U where \(U^\dagger U = UU^\dagger = I\) (complex case) or an orthogonal matrix O where \(O^T O = O O^T = I\) (real case).

Return type:

Array
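The QR-based construction described for unitary_matrix can be sketched as follows. This is a minimal illustration under stated assumptions, not the module's actual implementation: the helper name `unitary_from_weights` is hypothetical, and it assumes the weights are combined as `w_re + 1j * w_im` before taking the Q factor of `jnp.linalg.qr`.

```python
import jax
import jax.numpy as jnp


def unitary_from_weights(w_re, w_im=None):
    """Hypothetical helper: orthonormalize a square weight matrix via QR."""
    # Assumption: the complex weight matrix is w_re + i * w_im.
    w = w_re if w_im is None else w_re + 1j * w_im
    # The Q factor of a QR decomposition has orthonormal columns.
    q, _ = jnp.linalg.qr(w)
    return q


key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
w_re = jax.random.uniform(key_re, (3, 3))
w_im = jax.random.uniform(key_im, (3, 3))

o = unitary_from_weights(w_re)        # real case: orthogonal matrix
u = unitary_from_weights(w_re, w_im)  # complex case: unitary matrix

# Deviation from the identity; small in single precision.
orth_err = jnp.max(jnp.abs(o.T @ o - jnp.eye(3)))
unit_err = jnp.max(jnp.abs(u.conj().T @ u - jnp.eye(3)))

# The determinant is +1 or -1 in the real case and has modulus 1 in the complex case.
det_o_abs = jnp.abs(jnp.linalg.det(o))
det_u_abs = jnp.abs(jnp.linalg.det(u))
```

Because the Q factor is computed with differentiable operations, a parameterization of this kind lets unconstrained weights be optimized while the resulting matrix stays unitary.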

jrystal._src.unitary_module.unitary_matrix_param_init(key: Array, shape: tuple | List[int], complex: bool = True) → Dict[str, Array]

Initialize parameters for generating a unitary matrix.

This function creates the parameters needed to construct a unitary matrix of the specified shape. It initializes the real and, optionally, imaginary parts using uniform random distributions.

Example:

import jax
from jrystal._src.unitary_module import unitary_matrix_param_init

key = jax.random.PRNGKey(42)
params = unitary_matrix_param_init(key, (2, 2))
sorted(params.keys())  # ['w_im', 'w_re']
Parameters:
  • key (Array) – JAX PRNG key for random number generation.

  • shape (tuple or List[int]) – Tuple or list of ints specifying the dimensions of the matrix.

  • complex (bool) – If True, initializes both real and imaginary components. If False, only initializes real components. Defaults to True.

Returns:

Dictionary containing:

  • 'w_re': Real weights, with dimensions given by shape

  • 'w_im': Imaginary weights, with dimensions given by shape (if complex=True)

Return type:

dict
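The initialization described for unitary_matrix_param_init can be sketched along these lines. The function `init_params_sketch` is a hypothetical stand-in rather than jrystal's code: the key splitting and the default [0, 1) range of `jax.random.uniform` are assumptions, since the documentation above only states that uniform random distributions are used.

```python
import jax
import jax.numpy as jnp


def init_params_sketch(key, shape, complex=True):
    """Hypothetical stand-in for a parameter initializer (not jrystal's code)."""
    # Use independent subkeys so the real and imaginary parts are uncorrelated.
    key_re, key_im = jax.random.split(key)
    params = {"w_re": jax.random.uniform(key_re, tuple(shape))}
    if complex:
        params["w_im"] = jax.random.uniform(key_im, tuple(shape))
    return params


key = jax.random.PRNGKey(42)
complex_params = init_params_sketch(key, (2, 2))             # 'w_re' and 'w_im'
real_params = init_params_sketch(key, [2, 2], complex=False)  # 'w_re' only

# JAX PRNG keys are explicit, so the same key reproduces the same parameters.
same_params = init_params_sketch(key, (2, 2))
reproducible = bool(jnp.array_equal(complex_params["w_re"], same_params["w_re"]))
```

Passing the key explicitly, as the documented signature does, keeps initialization reproducible and compatible with JAX transformations.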