unitary_module
Unitary matrix module.
This module provides functions for generating and working with unitary matrices. A unitary matrix is a complex square matrix whose conjugate transpose is equal to its inverse. For real matrices, this reduces to an orthogonal matrix where the transpose equals the inverse.
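The definition can be checked numerically. The helper below is a small illustrative sketch. is_unitary is a hypothetical name and is not part of jrystal. It compares both products with the identity and uses the conjugate transpose throughout, so the same check covers the real orthogonal case. The tolerance atol=1e-6 is an assumption chosen to absorb float32 round-off.

```python
import math

import jax.numpy as jnp


def is_unitary(m, atol=1e-6):
    """Check U^H U = U U^H = I (reduces to O^T O = O O^T = I when m is real)."""
    n_rows, n_cols = m.shape
    if n_rows != n_cols:
        return False  # unitarity is only defined for square matrices
    eye = jnp.eye(n_rows, dtype=m.dtype)
    m_h = m.conj().T  # conjugate transpose; for a real m this is just the transpose
    return bool(jnp.allclose(m_h @ m, eye, atol=atol)) and bool(
        jnp.allclose(m @ m_h, eye, atol=atol)
    )


theta = 0.3
rotation = jnp.array([[math.cos(theta), -math.sin(theta)],
                      [math.sin(theta), math.cos(theta)]])  # real orthogonal
phase_mix = jnp.array([[1, 1j], [1j, 1]]) / math.sqrt(2)    # complex unitary
shear = jnp.array([[1.0, 1.0], [0.0, 1.0]])                 # invertible, but not unitary

print(is_unitary(rotation), is_unitary(phase_mix), is_unitary(shear))  # True True False
```

The shear matrix has an inverse, but that inverse is not its transpose. This is why invertibility alone is not enough.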
- jrystal._src.unitary_module.unitary_matrix(params: Dict[str, Array], complex: bool = False) -> Array
Generate a unitary (or orthogonal) matrix from its parameters.
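This function constructs a unitary matrix (or an orthogonal matrix in the real case) using QR decomposition. The resulting matrix has orthonormal columns, so the absolute value of its determinant is 1. The determinant is ±1 in the real case and a unit-modulus complex number in the complex case.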
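The QR construction can be sketched as follows. This is a minimal, hypothetical illustration and not jrystal's source. It assumes the weights are combined as W = w_re + i·w_im (just w_re in the real case) and that the Q factor of W = QR is returned as is. The real function may differ in details, such as sign or phase normalization. The helper name unitary_from_params is made up for this example. Its complex argument mirrors the documented signature, even though it shadows the Python builtin.

```python
import jax
import jax.numpy as jnp


def unitary_from_params(params, complex=False):
    """Hypothetical sketch (not jrystal's code): weights -> unitary/orthogonal matrix via QR."""
    # Assumption: the complex weight matrix is W = w_re + i * w_im.
    w = params["w_re"]
    if complex:
        w = w + 1j * params["w_im"]
    # For a square, full-rank W, the Q factor of W = QR has orthonormal columns,
    # so Q^H Q = Q Q^H = I (Q^T Q = Q Q^T = I in the real case).
    q, _ = jnp.linalg.qr(w)
    return q


k_re, k_im = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w_re": jax.random.uniform(k_re, (3, 3)),
    "w_im": jax.random.uniform(k_im, (3, 3)),
}
U = unitary_from_params(params, complex=True)   # complex unitary
O = unitary_from_params(params, complex=False)  # real orthogonal
```

In this sketch, any nonsingular choice of w_re and w_im yields a valid unitary matrix. The weights themselves therefore need no constraint.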
Example:
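```python
import jax
import jax.numpy as jnp
from jrystal._src.unitary_module import unitary_matrix, unitary_matrix_param_init

key = jax.random.PRNGKey(0)
# Generate a 3x3 complex unitary matrix
params = unitary_matrix_param_init(key, (3, 3), complex=True)
U = unitary_matrix(params, complex=True)
# Verify unitarity: U†U should be close to the identity
# (atol loosened to absorb float32 round-off in the off-diagonal entries)
jnp.allclose(U.conj().T @ U, jnp.eye(3), atol=1e-6)
```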
- Parameters:
params (Dict[str, Array]) – Dictionary containing the matrix parameters with keys:
'w_re': Real part of the weight matrix.
'w_im': Imaginary part of the weight matrix (used only if complex=True).
params can be generated by unitary_matrix_param_init().
complex (bool) – If True, generates a complex unitary matrix. If False, generates a real orthogonal matrix. Defaults to False.
- Returns:
A unitary matrix U where \(U^\dagger U = UU^\dagger = I\) (complex case) or an orthogonal matrix O where \(O^T O = O O^T = I\) (real case).
- Return type:
Array
- jrystal._src.unitary_module.unitary_matrix_param_init(key: Array, shape: tuple | List[int], complex: bool = True) -> Dict[str, Array]
Initialize parameters for generating a unitary matrix.
This function creates the parameters needed to construct a unitary matrix of the specified shape. It initializes the real part and, optionally, the imaginary part using uniform random distributions.
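The initialization step can be sketched as follows. This is a minimal illustration under stated assumptions and not jrystal's code. It assumes jax.random.uniform with its default [0, 1) range, because the actual range jrystal uses is not given here. The name init_unitary_params is hypothetical.

```python
import jax


def init_unitary_params(key, shape, complex=True):
    """Hypothetical uniform initializer returning {'w_re': ..., 'w_im': ...}."""
    # Split the key so the real and imaginary weights are drawn independently;
    # reusing a single key for both would make w_re and w_im identical.
    k_re, k_im = jax.random.split(key)
    shape = tuple(shape)  # accept either a tuple or a list of ints
    params = {"w_re": jax.random.uniform(k_re, shape)}
    if complex:
        params["w_im"] = jax.random.uniform(k_im, shape)
    return params


params = init_unitary_params(jax.random.PRNGKey(42), (2, 2))
print(sorted(params.keys()))  # ['w_im', 'w_re']
```

Note that this sketch follows the documented default complex=True, while unitary_matrix defaults to complex=False. It is therefore safest to pass the same complex flag to both calls.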
Example:
```python
import jax
from jrystal._src.unitary_module import unitary_matrix_param_init

key = jax.random.PRNGKey(42)
params = unitary_matrix_param_init(key, (2, 2))
sorted(params.keys())
# ['w_im', 'w_re']
```
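- Parameters:
key (Array) – JAX PRNG key used to draw the random weights.
shape (tuple | List[int]) – Shape of the weight arrays 'w_re' and 'w_im'.
complex (bool) – If True, also initializes the imaginary weights 'w_im'. Defaults to True.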
- Returns:
Dictionary containing:
'w_re': Real weights with the given shape.
'w_im': Imaginary weights with the given shape (only if complex=True).
- Return type: