Skip to content

Finding Neighbors

lahuta.core.neighbor_finder.NeighborSearch

Handle atom-related operations, including finding neighbors and preparing data for computation.

The class provides methods to find the neighbors of each atom in the universe and to remove pairs of atoms whose residues are adjacent in the sequence.

Parameters:

Name Type Description Default
mda AtomGroupType

The AtomGroup containing the atoms.

required

Attributes:

Name Type Description
ag_no_h AtomGroup

Atom group of a universe excluding hydrogen atoms.

og_resids ndarray

The residue IDs of each atom in the universe.

Source code in lahuta/core/neighbor_finder.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class NeighborSearch:
    """Handle atom-related operations, including finding neighbors and preparing data for computation.

    The class finds neighboring pairs among the non-hydrogen atoms of the input and can drop pairs whose
    residue ids are close in sequence (intra-residue and sequence-adjacent pairs with the default settings).

    Args:
        mda (AtomGroupType): The AtomGroup containing the atoms.

    Attributes:
        ag_no_h (AtomGroup): Atoms of ``mda`` excluding hydrogens (selected by name via ``"not name H*"``).
        og_resids (np.ndarray): Residue id of every atom in the *whole* universe, looked up with the
            global atom indices (``AtomGroup.ix``) produced by ``get_neighbors``.
    """

    def __init__(self, mda: AtomGroupType) -> None:
        mda_atoms, mda_universe = mda.atoms, mda.universe
        # Hydrogens are excluded by atom name only. NOTE(review): hydrogen names that do not start with
        # "H" (e.g. PDB-style "1HB", or deuterium "D*") are kept, and any non-hydrogen atom whose name
        # starts with "H" is dropped — confirm this is acceptable for the inputs used.
        self.ag_no_h = mda_atoms.select_atoms("not name H*")
        # Resids are taken from the full universe (not just ``mda``) so they can be indexed with the
        # universe-level ``.ix`` indices returned by ``get_neighbors``.
        self.og_resids = mda_universe.atoms.resids

    def compute(self, radius: float = 5.0, res_dif: int = 1) -> PairsDistances:
        """Compute neighboring pairs among the non-hydrogen atoms.

        Args:
            radius (float, optional): The cutoff radius passed to the neighbor search. Default is 5.0.
            res_dif (int, optional): Pairs whose residue ids differ by ``res_dif`` or less are discarded.
                A value ``<= 0`` disables the filter entirely, so intra-residue pairs are kept.
                Default is 1.

        Returns:
            PairsDistances: A tuple ``(pairs, distances)`` of global atom-index pairs and their distances.
        """
        pairs, distances = self.get_neighbors(radius)

        if res_dif > 0:
            keep = self.remove_adjacent_residue_pairs(pairs, res_dif=res_dif)
            pairs = pairs[keep]
            distances = distances[keep]

        return pairs, distances

    def get_neighbors(self, radius: float) -> PairsDistances:
        """Find all pairs of non-hydrogen atoms that lie within ``radius`` of each other.

        Args:
            radius (float): The cutoff radius.

        Returns:
            PairsDistances: A tuple ``(pairs, distances)``. ``pairs`` holds the universe-level atom
            indices (``AtomGroup.ix``) of each neighboring pair, and ``distances`` holds the matching
            pair distances reported by FastNS.
        """
        universe = self.ag_no_h.universe
        if getattr(universe, "dimensions", None) is None:
            # No unit cell available. Presumably ``mda_psuedobox_from_atomgroup`` builds a bounding box
            # and matching coordinates, so FastNS can run without periodic boundaries.
            positions, dimensions = mda_psuedobox_from_atomgroup(self.ag_no_h)
            pbc = False
        else:
            positions = self.ag_no_h.positions
            dimensions = universe.dimensions
            pbc = True

        gridsearch = FastNS(cutoff=radius, coords=positions, box=dimensions, pbc=pbc)
        neighbors = gridsearch.self_search()

        # FastNS pair indices refer to positions within ``ag_no_h``. Map them to universe-level indices
        # by indexing the ``.ix`` array directly, instead of constructing an intermediate AtomGroup
        # from a 2-D index array.
        return (
            self.ag_no_h.ix[neighbors.get_pairs()],
            neighbors.get_pair_distances(),
        )

    def remove_adjacent_residue_pairs(self, pairs: NDArray[np.int32], res_dif: int = 1) -> NDArray[np.bool_]:
        """Build a mask that drops pairs whose residue ids are within ``res_dif`` of each other.

        A pair is kept only when ``|resid_i - resid_j| > res_dif``. Pairs whose residue ids differ by
        ``res_dif`` or less are marked for removal, including pairs within the same residue.

        Args:
            pairs (NDArray[np.int32]): Global atom-index pairs of shape (n_pairs, 2), used to index
                ``self.og_resids``. An empty array is treated as zero pairs.
            res_dif (int, optional): Largest residue-id separation that is still removed. Default is 1.

        Returns:
            NDArray[np.bool_]: Boolean mask of shape (n_pairs,). ``True`` marks a pair to keep.
        """
        pair_idx = np.asarray(pairs)
        if pair_idx.size == 0:
            # Without this guard, an empty 1-D input would make the column indexing below fail.
            return np.zeros(0, dtype=np.bool_)

        resids = self.og_resids[pair_idx]
        mask: NDArray[np.bool_] = np.abs(resids[:, 0] - resids[:, 1]) > res_dif
        return mask

compute

compute(radius=5.0, res_dif=1)

Compute the neighbors of each non-hydrogen atom in the Universe.

Parameters:

Name Type Description Default
radius float

The cutoff radius. Default is 5.0.

5.0
res_dif int

The residue difference to consider. Default is 1.

1

Returns:

Name Type Description
PairsDistances PairsDistances

A tuple containing the pairs of atom indices and the distances.

Source code in lahuta/core/neighbor_finder.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def compute(self, radius: float = 5.0, res_dif: int = 1) -> PairsDistances:
    """Find neighboring atom pairs and optionally discard sequence-close pairs.

    Args:
        radius (float, optional): Cutoff radius forwarded to ``get_neighbors``. Default is 5.0.
        res_dif (int, optional): When positive, forwarded to ``remove_adjacent_residue_pairs`` to filter
            the pairs. A value ``<= 0`` returns the unfiltered neighbor list. Default is 1.

    Returns:
        PairsDistances: A tuple ``(pairs, distances)`` of atom-index pairs and their distances.
    """
    neighbor_pairs, neighbor_distances = self.get_neighbors(radius)

    # Guard clause: a non-positive separation means "no residue-based filtering".
    if res_dif <= 0:
        return neighbor_pairs, neighbor_distances

    keep = self.remove_adjacent_residue_pairs(neighbor_pairs, res_dif=res_dif)
    return neighbor_pairs[keep], neighbor_distances[keep]

get_neighbors

get_neighbors(radius)

Get the neighbors of an AtomGroup.

Parameters:

Name Type Description Default
radius float

The cutoff radius.

required

Returns:

Name Type Description
PairsDistances PairsDistances

A tuple containing the pairs of atom indices and the distances.

Source code in lahuta/core/neighbor_finder.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def get_neighbors(self, radius: float) -> PairsDistances:
    """Find all pairs of non-hydrogen atoms that lie within ``radius`` of each other.

    Args:
        radius (float): The cutoff radius.

    Returns:
        PairsDistances: A tuple ``(pairs, distances)``. ``pairs`` holds the universe-level atom indices
        (``AtomGroup.ix``) of each neighboring pair, and ``distances`` holds the matching pair distances
        reported by FastNS.
    """
    universe = self.ag_no_h.universe
    if getattr(universe, "dimensions", None) is None:
        # No unit cell available. Presumably ``mda_psuedobox_from_atomgroup`` builds a bounding box and
        # matching coordinates, so FastNS can run without periodic boundaries.
        positions, dimensions = mda_psuedobox_from_atomgroup(self.ag_no_h)
        pbc = False
    else:
        positions = self.ag_no_h.positions
        dimensions = universe.dimensions
        pbc = True

    gridsearch = FastNS(cutoff=radius, coords=positions, box=dimensions, pbc=pbc)
    neighbors = gridsearch.self_search()

    # FastNS pair indices refer to positions within ``ag_no_h``. Map them to universe-level indices by
    # indexing the ``.ix`` array directly, instead of constructing an intermediate AtomGroup from a
    # 2-D index array.
    return (
        self.ag_no_h.ix[neighbors.get_pairs()],
        neighbors.get_pair_distances(),
    )

remove_adjacent_residue_pairs

remove_adjacent_residue_pairs(pairs, res_dif=1)

Remove pairs where the difference in residue ids is less than or equal to res_dif.

Parameters:

Name Type Description Default
pairs NDArray[int32]

An array of shape (n_pairs, 2) where each row is a pair of atom indices.

required
res_dif int

The difference in residue ids to remove. Default is 1.

1

Returns:

Type Description
NDArray[bool_]

NDArray[np.bool_]: A boolean mask of shape (n_pairs,) that is True for each pair to keep.

Source code in lahuta/core/neighbor_finder.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def remove_adjacent_residue_pairs(self, pairs: NDArray[np.int32], res_dif: int = 1) -> NDArray[np.bool_]:
    """Build a mask that drops pairs whose residue ids are within ``res_dif`` of each other.

    A pair is kept only when ``|resid_i - resid_j| > res_dif``. Pairs whose residue ids differ by
    ``res_dif`` or less are marked for removal, including pairs within the same residue.

    Args:
        pairs (NDArray[np.int32]): Atom-index pairs of shape (n_pairs, 2), used to index
            ``self.og_resids``. An empty array is treated as zero pairs.
        res_dif (int, optional): Largest residue-id separation that is still removed. Default is 1.

    Returns:
        NDArray[np.bool_]: Boolean mask of shape (n_pairs,). ``True`` marks a pair to keep.
    """
    pair_idx = np.asarray(pairs)
    if pair_idx.size == 0:
        # Without this guard, an empty 1-D input would make the column indexing below fail.
        return np.zeros(0, dtype=np.bool_)

    resids = self.og_resids[pair_idx]
    mask: NDArray[np.bool_] = np.abs(resids[:, 0] - resids[:, 1]) > res_dif
    return mask