Propositional requirements

The backend for propositional requirements: boolean logic written as Horn rules (head :- body) or disjunctive clauses (y_0 or not y_1). This backend also provides the Memory-efficient Loss, a t-norm loss inspired by Logic Tensor Networks (LTN).

Shield Layer

shield_layer

The propositional Shield Layer.

Defines ShieldLayer, a differentiable PyTorch layer that corrects a model's predictions so they provably satisfy a set of propositional requirements. The requirements are normalised into clauses, stratified into ordered layers, and applied in order by per-stratum ConstraintsModule modules.

ShieldLayer

ShieldLayer(num_classes: int, requirements: Union[str, List[ConstraintsGroup]] = None, ordering_choice: str = None, custom_ordering: str = None)

Bases: Module

Differentiable layer that corrects predictions so they satisfy a set of propositional requirements over binary variables (the num_classes outputs, interpreted as probabilities).

The requirements are normalised into clauses and then stratified: each clause is assigned to a layer (stratum) such that a clause's head is only corrected after the variables in its body, with the centrality/ordering deciding which variable plays the head when there is a choice. Each stratum becomes a ConstraintsModule, and forward applies the modules in order, so correcting one stratum never violates an earlier one — guaranteeing all requirements hold on the output.
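
The ordering idea can be sketched in plain Python, without PyTorch. This is only an illustration under stated assumptions: rules contain positive literals only, and a head is corrected by raising it to the minimum of its body values. That correction rule and the names `apply_strata` and `strata` are hypothetical and are not taken from `ConstraintsModule`.

```python
def apply_strata(preds, strata):
    """Apply rule strata in order; each rule is (head, [body atoms])."""
    out = list(preds)
    for stratum in strata:
        # Read body values from the state before this stratum, so rules in
        # the same stratum do not depend on each other's corrections.
        snapshot = list(out)
        for head, body in stratum:
            implied = min(snapshot[b] for b in body)
            out[head] = max(out[head], implied)
    return out


# Variable 0 is never a head; rule 1 :- 0 is in the first stratum,
# rule 2 :- 1 in the second, so 2 is corrected after 1.
strata = [[(1, [0])], [(2, [1])]]
corrected = apply_strata([0.9, 0.2, 0.1], strata)
```

Because the second stratum reads variable 1 only after the first stratum has fixed it, the later correction cannot undo the earlier one.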

requirements may be either a path to a constraints file or a precomputed list of strata.

Build the layer by stratifying the requirements into correction modules.

Parameters:

Name Type Description Default
num_classes int

The number of output variables (labels) the layer operates on.

required
requirements Union[str, List[ConstraintsGroup]]

Either a path to a constraints file or a precomputed list of strata (ConstraintsGroup objects).

None
ordering_choice str

The variable-ordering / centrality strategy used to decide which variable plays the head during stratification.

None
custom_ordering str

An optional explicit comma-separated variable order; with a 'given' ordering choice and no value, defaults to 0,1,...,n-1.

None

Raises:

Type Description
Exception

If requirements is neither a str nor a list.

Example

>>> layer = ShieldLayer(num_classes=5,
...                     requirements='constraints.txt',
...                     ordering_choice='given')
>>> corrected = layer(predictions)  # predictions: (batch, 5)

Source code in pishield/propositional_requirements/shield_layer.py
def __init__(self, num_classes: int,
             requirements: Union[str, List[ConstraintsGroup]] = None,
             ordering_choice: str = None,
             custom_ordering: str = None):
    """Build the layer by stratifying the requirements into correction modules.

    Args:
        num_classes: The number of output variables (labels) the layer operates on.
        requirements: Either a path to a constraints file or a precomputed list of
            strata (:class:`ConstraintsGroup` objects).
        ordering_choice: The variable-ordering / centrality strategy used to decide
            which variable plays the head during stratification.
        custom_ordering: An optional explicit comma-separated variable order; with a
            ``'given'`` ordering choice and no value, defaults to ``0,1,...,n-1``.

    Raises:
        Exception: If ``requirements`` is neither a str nor a list.

    Example:
        >>> layer = ShieldLayer(num_classes=5,
        ...                     requirements='constraints.txt',
        ...                     ordering_choice='given')
        >>> corrected = layer(predictions)  # predictions: (batch, 5)
    """
    super(ShieldLayer, self).__init__()

    self.num_classes = num_classes
    self.ordering_choice = ordering_choice
    # With the 'given' ordering and no explicit ordering, default to the ascending order 0,1,...,n-1.
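    # Guard against ordering_choice=None (the default), which would otherwise raise a TypeError.
    if ordering_choice is not None and 'given' in ordering_choice and custom_ordering is None:
        custom_ordering = ",".join(str(e) for e in range(num_classes))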
    self.custom_ordering = custom_ordering
    self.constraints = requirements

    if isinstance(requirements, str):
        # Load constraints from file and convert them to a normalised set of clauses.
        constraints_filepath = requirements
        constraints_group = ConstraintsGroup(constraints_filepath)
        clauses_group = ClausesGroup.from_constraints_group(constraints_group)

        # forced = False  # TODO: what's this for?
        # clauses = clauses.add_detection_label(forced)
        # print(f"Shifted atoms and added n0 to all clauses (forced {forced})")

        # Stratify the clauses into ordered layers according to the chosen variable centrality.
        centrality = get_order_and_centrality(ordering_choice, custom_ordering)
        strata = clauses_group.stratify(centrality)
        self.stratified_constraints = strata

        print(f"Generated {len(strata)} strata of constraints with {centrality} centrality")
    elif isinstance(requirements, list):
        # Strata were supplied directly (e.g. from `from_clauses_group`), so use them as-is.
        strata = requirements
        centrality = None
    else:
        raise Exception(
            'requirements argument should be either str (i.e. filepath of the constraints) or list (i.e. strata)')

    # ConstraintsLayer([ConstraintsGroup], int)
    self.atoms = nn.Parameter(torch.tensor(list(range(num_classes))), requires_grad=False)

    # One correction module per stratum; forward() applies them in order.
    modules = [ConstraintsModule(stratum, num_classes) for stratum in strata]
    self.module_list = nn.ModuleList(modules)

    # The "core" is the set of variables that are never a clause head, i.e. never corrected;
    # they are passed through unchanged and seed the corrections of the strata above them.
    core = set(range(num_classes))
    strata = [stratum.heads() for stratum in strata]

    for stratum in strata:
        core = core.difference(stratum)

    assert len(core) > 0
    self.core = core
    self.strata = strata
    self.centrality = centrality
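
The core computation at the end of the constructor can be reproduced with plain sets. The sketch below mirrors those lines; `compute_core` and the example strata are illustrative names and values, not part of the library.

```python
def compute_core(num_classes, strata_heads):
    """Return the variables that are never the head of any stratum."""
    core = set(range(num_classes))
    for heads in strata_heads:
        core = core.difference(heads)
    return core


# Five variables; the first stratum corrects {2, 3}, the second corrects {4}.
core = compute_core(5, [{2, 3}, {4}])
```

The constructor asserts that this set is non-empty, so at least one variable must never appear as a head.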

from_clauses_group classmethod

from_clauses_group(num_classes, clauses_group, centrality)

Build a Shield Layer from a clauses group and centrality ordering.

Parameters:

Name Type Description Default
num_classes

The number of output variables.

required
clauses_group

The ClausesGroup of requirements to stratify.

required
centrality

The centrality / ordering used to stratify the clauses.

required

Returns:

Type Description

A configured ShieldLayer.

Source code in pishield/propositional_requirements/shield_layer.py
@classmethod
def from_clauses_group(cls, num_classes, clauses_group, centrality):
    """Build a Shield Layer from a clauses group and centrality ordering.

    Args:
        num_classes: The number of output variables.
        clauses_group: The :class:`ClausesGroup` of requirements to stratify.
        centrality: The centrality / ordering used to stratify the clauses.

    Returns:
        A configured ShieldLayer.
    """
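    layer = cls(num_classes=num_classes, requirements=clauses_group.stratify(centrality))
    # Record the centrality on the instance: __init__ leaves it as None for precomputed
    # strata, and setting it on the class would leak into every other instance.
    layer.centrality = centrality
    return layer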

gradual_prefix

gradual_prefix(ratio)

Select the atoms and number of strata covered by a fraction of variables.

Grows the never-corrected core with whole strata until roughly ratio of all variables are covered; used to gradually enable the requirements in training.

Parameters:

Name Type Description Default
ratio

The target fraction of variables to cover, in [0, 1].

required

Returns:

Type Description

A tuple (atoms, num_modules) of the covered atom set and the number of leading strata included.

Source code in pishield/propositional_requirements/shield_layer.py
def gradual_prefix(self, ratio):
    """Select the atoms and number of strata covered by a fraction of variables.

    Grows the never-corrected core with whole strata until roughly ``ratio`` of all
    variables are covered; used to gradually enable the requirements in training.

    Args:
        ratio: The target fraction of variables to cover, in [0, 1].

    Returns:
        A tuple ``(atoms, num_modules)`` of the covered atom set and the number of
        leading strata included.
    """
    atoms = self.core
    remaining = math.floor(ratio * self.num_classes) - len(atoms)
    if (remaining <= 0): return atoms, 0

    for i, stratum in enumerate(self.strata):
        remaining -= len(stratum)
        if (remaining < 0): return atoms, i
        atoms = atoms.union(stratum)

    return atoms, len(self.strata)
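
To see how the prefix grows, the same logic can be run standalone on plain sets. The function below copies the method body but takes `core`, `strata` and `num_classes` as arguments; the example values are made up for illustration.

```python
import math


def gradual_prefix(core, strata, num_classes, ratio):
    """Standalone copy of the method logic, operating on plain sets."""
    atoms = set(core)
    remaining = math.floor(ratio * num_classes) - len(atoms)
    if remaining <= 0:
        return atoms, 0

    for i, stratum in enumerate(strata):
        remaining -= len(stratum)
        if remaining < 0:
            # The next stratum does not fit entirely, so stop before it.
            return atoms, i
        atoms = atoms.union(stratum)

    return atoms, len(strata)


core = {0, 1}
strata = [{2, 3}, {4}]
half = gradual_prefix(core, strata, 5, 0.5)
most = gradual_prefix(core, strata, 5, 0.9)
full = gradual_prefix(core, strata, 5, 1.0)
```

Strata are only ever added whole: a ratio of 0.7 here still returns just the core, because the first stratum needs two more variables and only one is left in the budget.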

slicer

slicer(ratio)

Build a Slicer covering a fraction of the requirements.

Parameters:

Name Type Description Default
ratio

The target fraction of variables to cover, in [0, 1].

required

Returns:

Type Description

A Slicer over the corresponding atoms and leading modules.

Source code in pishield/propositional_requirements/shield_layer.py
def slicer(self, ratio):
    """Build a :class:`Slicer` covering a fraction of the requirements.

    Args:
        ratio: The target fraction of variables to cover, in [0, 1].

    Returns:
        A Slicer over the corresponding atoms and leading modules.
    """
    atoms, modules = self.gradual_prefix(ratio)
    return Slicer(atoms, modules)

to_minimal

to_minimal(tensor)

Restrict a tensor to the layer's atom columns.

Parameters:

Name Type Description Default
tensor

A tensor of shape (batch, num_classes).

required

Returns:

Type Description

The tensor restricted to the layer's atoms.

Source code in pishield/propositional_requirements/shield_layer.py
def to_minimal(self, tensor):
    """Restrict a tensor to the layer's atom columns.

    Args:
        tensor: A tensor of shape (batch, num_classes).

    Returns:
        The tensor restricted to the layer's atoms.
    """
    return tensor[:, self.atoms].reshape(tensor.shape[0], len(self.atoms))

from_minimal

from_minimal(tensor, init)

Scatter a minimal-atom tensor back into a full tensor.

Parameters:

Name Type Description Default
tensor

The minimal tensor over the layer's atoms.

required
init

The full tensor to write into.

required

Returns:

Type Description

init with the atom columns overwritten by tensor.

Source code in pishield/propositional_requirements/shield_layer.py
def from_minimal(self, tensor, init):
    """Scatter a minimal-atom tensor back into a full tensor.

    Args:
        tensor: The minimal tensor over the layer's atoms.
        init: The full tensor to write into.

    Returns:
        ``init`` with the atom columns overwritten by ``tensor``.
    """
    return init.index_copy(1, self.atoms, tensor)
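
The restrict-then-scatter round trip of to_minimal and from_minimal can be illustrated with nested lists. This is a plain-Python analogue, not the tensor code: the helper names reuse the method names for readability, and `atoms` is passed explicitly. In the constructor shown earlier, `self.atoms` holds every index from 0 to num_classes - 1, so the sketch uses a strict subset only to make the scatter visible.

```python
def to_minimal(rows, atoms):
    """Keep only the columns listed in `atoms`, in that order."""
    return [[row[a] for a in atoms] for row in rows]


def from_minimal(minimal, init, atoms):
    """Copy `init` and overwrite its `atoms` columns with `minimal`."""
    out = [list(row) for row in init]
    for r, row in enumerate(minimal):
        for j, a in enumerate(atoms):
            out[r][a] = row[j]
    return out


preds = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
atoms = [0, 2]
minimal = to_minimal(preds, atoms)
edited = [[1.0, 1.0], [1.0, 1.0]]          # stand-in for corrected values
restored = from_minimal(edited, preds, atoms)
```

Like the out-of-place `index_copy` call, the scatter returns a new object and leaves `preds` untouched.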

forward

forward(preds, goal=None, iterative=True, slicer=None)

Correct predictions so they satisfy all requirements.

Restricts the predictions to the constrained atoms, applies each stratum's correction module in order, and scatters the corrected values back into the full prediction tensor.

Parameters:

Name Type Description Default
preds

The model predictions, shape (batch, num_classes).

required
goal

Optional goal (ground-truth) assignment to keep corrections consistent with during training.

None
iterative

If True use the iterative correction implementation.

True
slicer

Optional Slicer restricting how many strata are applied.

None

Returns:

Type Description

The corrected predictions, same shape as preds.

Example

>>> corrected = layer(predictions)
>>> # every requirement now holds on `corrected`

Source code in pishield/propositional_requirements/shield_layer.py
def forward(self, preds, goal=None, iterative=True, slicer=None):
    """Correct predictions so they satisfy all requirements.

    Restricts the predictions to the constrained atoms, applies each stratum's
    correction module in order, and scatters the corrected values back into the
    full prediction tensor.

    Args:
        preds: The model predictions, shape (batch, num_classes).
        goal: Optional goal (ground-truth) assignment to keep corrections
            consistent with during training.
        iterative: If True use the iterative correction implementation.
        slicer: Optional :class:`Slicer` restricting how many strata are applied.

    Returns:
        The corrected predictions, same shape as ``preds``.

    Example:
        >>> corrected = layer(predictions)
        >>> # every requirement now holds on `corrected`
    """
    # Restrict to the atoms involved in the constraints, apply each stratum's correction in order,
    # then scatter the corrected values back into the full prediction tensor.
    updated = self.to_minimal(preds)
    goal = None if goal is None else self.to_minimal(goal)

    modules = self.module_list if slicer is None else slicer.slice_modules(self.module_list)
    for module in modules:
        updated = module(updated, goal=goal, iterative=iterative)

    return self.from_minimal(updated, preds)

run_layer

run_layer(layer, preds, backward=False)

Run a Shield Layer both ways and assert the results agree.

Runs the layer with the iterative and tensor implementations, checks they are numerically close, and optionally exercises a backward pass; mainly a testing/debugging helper.

Parameters:

Name Type Description Default
layer

The ShieldLayer to run.

required
preds

The prediction tensor.

required
backward

If True, add a random differentiable perturbation and backpropagate.

False

Returns:

Type Description

The corrected predictions (from the iterative implementation), detached.

Raises:

Type Description
AssertionError

If the two implementations disagree.

Source code in pishield/propositional_requirements/shield_layer.py
def run_layer(layer, preds, backward=False):
    """Run a Shield Layer both ways and assert the results agree.

    Runs the layer with the iterative and tensor implementations, checks they are
    numerically close, and optionally exercises a backward pass; mainly a
    testing/debugging helper.

    Args:
        layer: The :class:`ShieldLayer` to run.
        preds: The prediction tensor.
        backward: If True, add a random differentiable perturbation and backpropagate.

    Returns:
        The corrected predictions (from the iterative implementation), detached.

    Raises:
        AssertionError: If the two implementations disagree.
    """
    if backward:
        extra = torch.rand_like(preds, requires_grad=True)
        preds = preds + extra

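    # Avoid shadowing the built-ins `iter` and `sum`.
    iterative_out = layer(preds, iterative=True)
    tensor_out = layer(preds, iterative=False)
    assert torch.isclose(iterative_out, tensor_out).all()

    if backward:
        total = iterative_out.sum() + tensor_out.sum()
        total.backward()

    return iterative_out.detach()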

Memory-efficient Loss

shield_loss

The propositional Memory-efficient Loss.

Defines ShieldLoss, a t-norm based penalty term that encourages (but does not enforce) the satisfaction of propositional requirements. It is a memory-efficient t-norm loss inspired by Logic Tensor Networks (LTN). Each requirement is read as a disjunction (clause) and its degree of satisfaction is computed under one of three t-norms (Goedel, Lukasiewicz or product) using sparse matrix representations of the requirements.

ShieldLoss

ShieldLoss(num_variables: int, requirements_filepath: str, tnorm_choice: str = 'godel')

Bases: Module

The Memory-efficient Loss: a t-norm based loss term that encourages the satisfaction of propositional requirements. It is a memory-efficient t-norm loss inspired by Logic Tensor Networks (LTN).

Unlike the Shield Layer, the Memory-efficient Loss does not correct the predictions; it returns a scalar penalty (in [0, 1]) which is minimised when the requirements are satisfied. The penalty is computed using one of three t-norms: 'godel', 'product' or 'lukasiewicz'.

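The requirements are read from a file with one rule per line, written as head :- body, where head is a single literal and body is a (possibly empty) list of whitespace-separated literals. The parsing code (createIs and createMs) reads the head from the second whitespace-separated token and expects :- as the third, so each line must start with one extra leading token, which the parser skips. A literal is the index of a variable (e.g. 3) for a positive literal, or that index prefixed with n (e.g. n3) for a negative literal.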
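
A rule line can be parsed as in the sketch below. It follows the token positions used by createIs and createMs further down this page, which read the head from the second token and assert that the third is `:-`; the leading `0` in the example line is therefore a placeholder whose meaning is not specified here. `parse_rule` is a hypothetical helper, not a library function.

```python
def parse_rule(line):
    """Split a rule line into (head, body); literals are (index, positive)."""
    tokens = line.split()
    # Token 0 is skipped, token 1 is the head, token 2 must be ':-'.
    assert tokens[2] == ':-', "Instead of :- found: %s" % tokens[2]

    def literal(text):
        if text.startswith('n'):
            return int(text[1:]), False   # 'n3' -> variable 3, negated
        return int(text), True            # '3'  -> variable 3, positive

    return literal(tokens[1]), [literal(t) for t in tokens[3:]]


head, body = parse_rule("0 3 :- 1 n2")
fact_head, fact_body = parse_rule("0 n4 :-")   # a rule with an empty body
```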

Load the requirements and precompute the t-norm matrices.

Parameters:

Name Type Description Default
num_variables int

The number of variables (labels) the predictions cover.

required
requirements_filepath str

Path to the requirements file (one head :- body rule per line).

required
tnorm_choice str

The t-norm to use: 'godel', 'lukasiewicz' or 'product'.

'godel'

Example

>>> loss_fn = ShieldLoss(num_variables=10,
...                      requirements_filepath='constraints.txt',
...                      tnorm_choice='product')
>>> penalty = loss_fn(predictions)  # predictions: (batch, 10)

Source code in pishield/propositional_requirements/shield_loss.py
def __init__(self, num_variables: int, requirements_filepath: str, tnorm_choice: str = 'godel'):
    """Load the requirements and precompute the t-norm matrices.

    Args:
        num_variables: The number of variables (labels) the predictions cover.
        requirements_filepath: Path to the requirements file (one ``head :- body``
            rule per line).
        tnorm_choice: The t-norm to use: ``'godel'``, ``'lukasiewicz'`` or
            ``'product'``.

    Example:
        >>> loss_fn = ShieldLoss(num_variables=10,
        ...                      requirements_filepath='constraints.txt',
        ...                      tnorm_choice='product')
        >>> penalty = loss_fn(predictions)  # predictions: (batch, 10)
    """
    super().__init__()
    self.num_variables = num_variables
    self.requirements_filepath = requirements_filepath
    self.tnorm_choice = tnorm_choice
    self.create_matrices()

create_matrices

create_matrices()

Build the positive/negative literal matrices used by the t-norm losses.

Combines the body (I) and head (M) encodings into Cplus and Cminus matrices marking the positive and negative literal appearances in each requirement's disjunction (the exact combination depends on the chosen t-norm), and records the number of requirements.

Source code in pishield/propositional_requirements/shield_loss.py
def create_matrices(self):
    """Build the positive/negative literal matrices used by the t-norm losses.

    Combines the body (``I``) and head (``M``) encodings into ``Cplus`` and
    ``Cminus`` matrices marking the positive and negative literal appearances in
    each requirement's disjunction (the exact combination depends on the chosen
    t-norm), and records the number of requirements.
    """
    Iplus_np, Iminus_np = self.createIs()
    Mplus_np, Mminus_np = self.createMs()

    Iplus, Iminus = torch.from_numpy(Iplus_np).float(), torch.from_numpy(Iminus_np).float()
    Mplus, Mminus = torch.from_numpy(Mplus_np).float(), torch.from_numpy(Mminus_np).float()

    if self.tnorm_choice == "product":
        # The product t-norm works on the conjunction of the negated literals,
        # so every literal of the disjunction is stored with its sign flipped.
        # Matrix of negative appearances in the conjunction
        Cminus = Iminus + torch.transpose(Mplus, 0, 1)
        # Matrix of positive appearances in the conjunction
        Cplus = Iplus + torch.transpose(Mminus, 0, 1)
    else:  # 'godel' or 'lukasiewicz'
        # These are the literals as they appear in the disjunction
        # Matrix of positive appearances in the disjunction
        Cplus = Iminus + torch.transpose(Mplus, 0, 1)
        # Matrix of negative appearances in the disjunction
        Cminus = Iplus + torch.transpose(Mminus, 0, 1)

    self.Cplus = Cplus
    self.Cminus = Cminus
    self.NUM_REQ = Iplus.shape[0]
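
The sign bookkeeping is easier to check on a single rule. The sketch below rebuilds the two matrices with nested lists instead of tensors; `clause_matrices` and the `(index, positive)` literal encoding are illustrative assumptions, chosen so the result matches the sums computed above.

```python
def clause_matrices(rules, num_variables, tnorm_choice="godel"):
    """rules: list of (head, body); literals are (index, positive) pairs."""
    cplus = [[0] * num_variables for _ in rules]
    cminus = [[0] * num_variables for _ in rules]
    for r, (head, body) in enumerate(rules):
        # head :- body is the disjunction: head OR (negation of each body literal).
        disjunction = [head] + [(i, not pos) for i, pos in body]
        if tnorm_choice == "product":
            # The product branch stores the negated literals instead.
            literals = [(i, not pos) for i, pos in disjunction]
        else:
            literals = disjunction
        for i, pos in literals:
            (cplus if pos else cminus)[r][i] += 1
    return cplus, cminus


rules = [((3, True), [(1, True), (2, False)])]   # 3 :- 1 n2
cplus, cminus = clause_matrices(rules, num_variables=4)
p_plus, p_minus = clause_matrices(rules, num_variables=4, tnorm_choice="product")
```

For the rule `3 :- 1 n2` the disjunction is "3 or not 1 or 2", so variables 2 and 3 land in `cplus` and variable 1 in `cminus`; the product branch swaps the two.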

createIs

createIs()

Encode the body literals of each requirement into indicator matrices.

Reads the requirements file and, for each rule's body, marks which variables appear as positive (Iplus) and negative (Iminus) literals.

Returns:

Type Description

A tuple (Iplus, Iminus) of arrays of shape (num_requirements, num_variables).

Source code in pishield/propositional_requirements/shield_loss.py
def createIs(self):
    """Encode the body literals of each requirement into indicator matrices.

    Reads the requirements file and, for each rule's body, marks which variables
    appear as positive (``Iplus``) and negative (``Iminus``) literals.

    Returns:
        A tuple ``(Iplus, Iminus)`` of arrays of shape (num_requirements,
        num_variables).
    """
    # Matrix with indices for positive literals
    Iplus = []
    # Matrix with indices for negative literals
    Iminus = []
    with open(self.requirements_filepath, 'r') as f:
        for line in f:
            split_line = line.split()
            assert split_line[2] == ':-', "Instead of :- found: %s" % split_line[2]
            iplus = np.zeros(self.num_variables)
            iminus = np.zeros(self.num_variables)
            for item in split_line[3:]:
                if 'n' in item:
                    index = int(item[1:])
                    iminus[index] = 1
                else:
                    index = int(item)
                    iplus[index] = 1
            Iplus.append(iplus)
            Iminus.append(iminus)
    Iplus = np.array(Iplus)
    Iminus = np.array(Iminus)
    return Iplus, Iminus

createMs

createMs()

Encode the head literal of each requirement into indicator matrices.

Reads the requirements file and marks, per requirement, whether its head is a positive (Mplus) or negative (Mminus) literal at the head's variable.

Returns:

Type Description

A tuple (Mplus, Mminus) of arrays of shape (num_variables, num_requirements); each column corresponds to a requirement and carries a 1 at the head variable's row for the matching polarity.

Source code in pishield/propositional_requirements/shield_loss.py
def createMs(self):
    """Encode the head literal of each requirement into indicator matrices.

    Reads the requirements file and marks, per requirement, whether its head is a
    positive (``Mplus``) or negative (``Mminus``) literal at the head's variable.

    Returns:
        A tuple ``(Mplus, Mminus)`` of arrays of shape (num_variables,
        num_requirements); each column corresponds to a requirement and carries a 1
        at the head variable's row for the matching polarity.
    """
    Mplus, Mminus = [], []
    with open(self.requirements_filepath, 'r') as f:
        for line in f:
            split_line = line.split()
            assert split_line[2] == ':-'
            mplus = np.zeros(self.num_variables)
            mminus = np.zeros(self.num_variables)
            if 'n' in split_line[1]:
                # the head is a negative literal; drop the 'n' prefix to get the index
                index = int(split_line[1][1:])
                mminus[index] = 1
            else:
                index = int(split_line[1])
                mplus[index] = 1
            Mplus.append(mplus)
            Mminus.append(mminus)
    Mplus = np.array(Mplus).transpose()
    Mminus = np.array(Mminus).transpose()

    return Mplus, Mminus

get_sparse_representation

get_sparse_representation(req_matrix)

Return the sparse indices and values of a requirement matrix.

Parameters:

Name Type Description Default
req_matrix

A dense requirement matrix.

required

Returns:

Type Description

A tuple (indices, values) of the matrix's non-zero coordinates and their values.

Source code in pishield/propositional_requirements/shield_loss.py
def get_sparse_representation(self, req_matrix):
    """Return the sparse indices and values of a requirement matrix.

    Args:
        req_matrix: A dense requirement matrix.

    Returns:
        A tuple ``(indices, values)`` of the matrix's non-zero coordinates and
        their values.
    """
    req_matrix = req_matrix.to_sparse()
    return req_matrix.indices(), req_matrix.values()
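
The layout of the returned pair matters for the loss functions below, which use `indices[0]` as the requirement index and `indices[1]` as the variable index. The plain-Python sketch here imitates that coordinate layout for a small dense matrix; `sparse_representation` is an illustrative stand-in and assumes the row-major ordering of non-zero entries.

```python
def sparse_representation(matrix):
    """Return ([row_indices, col_indices], values) of the non-zero entries."""
    rows, cols, values = [], [], []
    for r, row in enumerate(matrix):
        for c, v in enumerate(row):
            if v != 0:
                rows.append(r)
                cols.append(c)
                values.append(v)
    return [rows, cols], values


# Two requirements over four variables.
indices, values = sparse_representation([[0, 0, 1, 1],
                                         [0, 1, 0, 0]])
# Requirements that mention variable 1:
reqs_with_var_1 = [r for r, c in zip(indices[0], indices[1]) if c == 1]
```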

godel_disjunctions_sparse

godel_disjunctions_sparse(preds, weighted_literals=False)

Compute the Goedel-t-norm requirement penalty.

Each requirement's satisfaction degree is the maximum over its literals' truth values; the penalty is one minus the mean satisfaction degree.

Parameters:

Name Type Description Default
preds

The predicted probabilities, shape (batch, num_variables).

required
weighted_literals

If True, weight each literal by its matrix value.

False

Returns:

Type Description

A scalar penalty in [0, 1], minimised when the requirements are satisfied.

Source code in pishield/propositional_requirements/shield_loss.py
def godel_disjunctions_sparse(self, preds, weighted_literals=False):
    """Compute the Goedel-t-norm requirement penalty.

    Each requirement's satisfaction degree is the maximum over its literals' truth
    values; the penalty is one minus the mean satisfaction degree.

    Args:
        preds: The predicted probabilities, shape (batch, num_variables).
        weighted_literals: If True, weight each literal by its matrix value.

    Returns:
        A scalar penalty in [0, 1], minimised when the requirements are satisfied.
    """
    constr_values = torch.zeros(preds.shape[0], self.NUM_REQ).to(preds.device)

    indices_nnz_plus, values_nnz_plus = self.get_sparse_representation(self.Cplus)
    indices_nnz_minus, values_nnz_minus = self.get_sparse_representation(self.Cminus)

    # predictions_at_nnz_values_plus is a matrix [batch, num_nonzero_vals_in_Cplus] which contains
    # the predicted value associated with each label (ordered as they appear in the columns of Cplus)
    predictions_at_nnz_values_plus = preds[:, indices_nnz_plus[1, :]]
    predictions_at_nnz_values_minus = (1. - preds[:, indices_nnz_minus[1, :]])
    if weighted_literals:
        predictions_at_nnz_values_plus *= values_nnz_plus
        predictions_at_nnz_values_minus *= values_nnz_minus

    # For each label k, fold its truth value into every requirement that contains it,
    # keeping the running maximum over the literals of each disjunction.
    for k in range(self.num_variables):
        # indices_nnz_plus[0, indices_nnz_plus[1] == k] lists the requirements in which the kth
        # label appears positively in the disjunction; indices_nnz_plus[1] == k is a boolean
        # mask over the non-zero entries of Cplus. The Cminus lines do the same for
        # negative appearances.
        constr_values[:, indices_nnz_plus[0, indices_nnz_plus[1] == k]] = torch.maximum(
            constr_values[:, indices_nnz_plus[0, indices_nnz_plus[1] == k]],
            predictions_at_nnz_values_plus[:, indices_nnz_plus[1] == k])
        constr_values[:, indices_nnz_minus[0, indices_nnz_minus[1] == k]] = torch.maximum(
            constr_values[:, indices_nnz_minus[0, indices_nnz_minus[1] == k]],
            predictions_at_nnz_values_minus[:, indices_nnz_minus[1] == k])

    req_loss = torch.mean(constr_values)
    # We need to do 1-req_loss because we want to maximise the probability p of satisfying our requirements,
    # and hence we want to minimize the 1-p
    return 1 - req_loss
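
As a cross-check on the description, the unweighted Goedel penalty can be written directly over clauses in plain Python, without sparse matrices. The `(index, positive)` clause encoding and the name `godel_penalty` are illustrative assumptions.

```python
def godel_penalty(preds, clauses):
    """preds: rows of probabilities; clauses: lists of (index, positive) pairs."""
    satisfaction = [
        max(row[i] if pos else 1.0 - row[i] for i, pos in clause)
        for row in preds
        for clause in clauses
    ]
    return 1.0 - sum(satisfaction) / len(satisfaction)


clauses = [[(0, True), (1, False)]]            # y_0 or not y_1
violated = godel_penalty([[0.0, 1.0]], clauses)
satisfied = godel_penalty([[1.0, 1.0]], clauses)
halfway = godel_penalty([[0.25, 0.5]], clauses)
```

Only the most-satisfied literal of each clause contributes, so the gradient of this penalty flows through a single literal per clause.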

lukasiewicz_disjunctions_sparse

lukasiewicz_disjunctions_sparse(preds, weighted_literals=False)

Compute the Lukasiewicz-t-norm requirement penalty.

Each requirement's satisfaction degree is the sum of its literals' truth values clamped to 1; the penalty is one minus the mean satisfaction degree.

Parameters:

Name Type Description Default
preds

The predicted probabilities, shape (batch, num_variables).

required
weighted_literals

If True, weight each literal by its matrix value.

False

Returns:

Type Description

A scalar penalty in [0, 1], minimised when the requirements are satisfied.

Source code in pishield/propositional_requirements/shield_loss.py
def lukasiewicz_disjunctions_sparse(self, preds, weighted_literals=False):
    """Compute the Lukasiewicz-t-norm requirement penalty.

    Each requirement's satisfaction degree is the sum of its literals' truth
    values clamped to 1; the penalty is one minus the mean satisfaction degree.

    Args:
        preds: The predicted probabilities, shape (batch, num_variables).
        weighted_literals: If True, weight each literal by its matrix value.

    Returns:
        A scalar penalty in [0, 1], minimised when the requirements are satisfied.
    """
    constr_values_unbounded = torch.zeros(preds.shape[0], self.NUM_REQ).to(preds.device)

    indices_nnz_plus, values_nnz_plus = self.get_sparse_representation(self.Cplus)
    indices_nnz_minus, values_nnz_minus = self.get_sparse_representation(self.Cminus)

    # predictions_at_nnz_values_plus is a matrix [batch, num_nonzero_vals_in_Cplus] which contains
    # the predicted value associated with each label (ordered as they appear in the columns of Cplus)
    predictions_at_nnz_values_plus = preds[:, indices_nnz_plus[1, :]]
    predictions_at_nnz_values_minus = (1. - preds[:, indices_nnz_minus[1, :]])
    if weighted_literals:
        predictions_at_nnz_values_plus *= values_nnz_plus
        predictions_at_nnz_values_minus *= values_nnz_minus

    # For each label k, add its truth value to every requirement that contains it;
    # the sums are clamped to 1 after the loop.
    for k in range(self.num_variables):
        # indices_nnz_plus[0, indices_nnz_plus[1] == k] lists the requirements in which the kth
        # label appears positively in the disjunction; indices_nnz_plus[1] == k is a boolean
        # mask over the non-zero entries of Cplus. The Cminus lines do the same for
        # negative appearances.
        constr_values_unbounded[:, indices_nnz_plus[0, indices_nnz_plus[1] == k]] += \
            predictions_at_nnz_values_plus[:, indices_nnz_plus[1] == k]
        constr_values_unbounded[:, indices_nnz_minus[0, indices_nnz_minus[1] == k]] += \
            predictions_at_nnz_values_minus[:, indices_nnz_minus[1] == k]

    constr_values = torch.min(torch.ones_like(constr_values_unbounded), constr_values_unbounded)
    req_loss = torch.mean(constr_values)

    # We need to do 1-req_loss because we want to maximise the probability p of satisfying our requirements,
    # and hence we want to minimize the 1-p
    return 1 - req_loss
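
The clamping step is the only difference from a plain sum, which a small plain-Python version makes visible. As before, the `(index, positive)` clause encoding and the function name are illustrative assumptions, and literal weights are left out.

```python
def lukasiewicz_penalty(preds, clauses):
    """preds: rows of probabilities; clauses: lists of (index, positive) pairs."""
    satisfaction = [
        min(1.0, sum(row[i] if pos else 1.0 - row[i] for i, pos in clause))
        for row in preds
        for clause in clauses
    ]
    return 1.0 - sum(satisfaction) / len(satisfaction)


clauses = [[(0, True), (1, False)]]                      # y_0 or not y_1
partial = lukasiewicz_penalty([[0.25, 0.5]], clauses)    # literal sum 0.75
clamped = lukasiewicz_penalty([[0.75, 0.5]], clauses)    # literal sum 1.25, clamped to 1
```

Once the literal truth values of a clause sum to 1 or more, that clause contributes no penalty and no gradient.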

product_disjunctions_sparse

product_disjunctions_sparse(preds, weighted_literals=False)

Compute the product-t-norm requirement penalty.

The disjunction is computed as the negation of the product of the negated literals; the penalty is one minus the mean satisfaction degree.

Parameters:

Name Type Description Default
preds

The predicted probabilities, shape (batch, num_variables).

required
weighted_literals

If True, weight each literal by its matrix value.

False

Returns:

Type Description

A scalar penalty in [0, 1], minimised when the requirements are satisfied.

Source code in pishield/propositional_requirements/shield_loss.py
def product_disjunctions_sparse(self, preds, weighted_literals=False):
    """Compute the product-t-norm requirement penalty.

    The disjunction is computed as the negation of the product of the negated
    literals; the penalty is one minus the mean satisfaction degree.

    Args:
        preds: The predicted probabilities, shape (batch, num_variables).
        weighted_literals: If True, weight each literal by its matrix value.

    Returns:
        A scalar penalty in [0, 1], minimised when the requirements are satisfied.
    """
    # The disjunction is more complex to implement than the conjunction
    # e.g., A and B --> A*B while A or B --> A + B - A*B
    # Thus we see the disjunction as the negation of the conjunction of the negations of all its
    # literals (i.e., A or B = neg (neg A and neg B))

    constr_values = torch.ones(preds.shape[0], self.NUM_REQ).to(preds.device)

    indices_nnz_plus, values_nnz_plus = self.get_sparse_representation(self.Cplus)
    indices_nnz_minus, values_nnz_minus = self.get_sparse_representation(self.Cminus)

    # predictions_at_nonzero_values is a matrix [num bboxes, num_nonzero_vals_in_Cplus] which contains
    # the predicted value associated with each label (ordered as they appear in the columns of Cplus)
    predictions_at_nnz_values_plus = preds[:, indices_nnz_plus[1, :]]
    predictions_at_nnz_values_minus = (1. - preds[:, indices_nnz_minus[1, :]])
    if weighted_literals:
        predictions_at_nnz_values_plus *= values_nnz_plus
        predictions_at_nnz_values_minus *= values_nnz_minus

    # the line inside the loop below essentially means that:
    # the constraints containing label k are each multiplied by the value of the prediction for label k
    for k in range(self.num_variables):
        # indices[0, indices[1] == k] returns the indices of the requirements in which the kth label
        # appears positively in the conjunction
        # indices[1] == k is a boolean mask over the nonzero entries, True where the entry's
        # column (label) index is equal to k
        constr_values[:, indices_nnz_plus[0, indices_nnz_plus[1] == k]] *= \
            predictions_at_nnz_values_plus[:, indices_nnz_plus[1] == k]
        constr_values[:, indices_nnz_minus[0, indices_nnz_minus[1] == k]] *= \
            predictions_at_nnz_values_minus[:, indices_nnz_minus[1] == k]

    # Negate the conjunction to obtain each disjunction's satisfaction degree, then average
    req_loss = torch.mean(1. - constr_values)

    # Despite its name, req_loss holds the mean satisfaction degree p. We want to maximise p,
    # and hence we return 1 - p as the quantity to minimise
    return 1 - req_loss
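
The formula described above can be checked by hand on a tiny input. The following is a minimal pure-Python sketch, not the library implementation: `product_tnorm_penalty` and the `(atom, positive)` clause encoding are hypothetical names introduced for illustration, and the sparse `Cplus`/`Cminus` matrices are replaced by plain lists.

```python
from math import prod


def product_tnorm_penalty(preds, clauses):
    """Return 1 - mean satisfaction degree over all samples and clauses.

    preds:   list of rows, each a list of probabilities in [0, 1].
    clauses: list of clauses, each a list of (atom, positive) pairs.
    """
    degrees = []
    for row in preds:
        for clause in clauses:
            # Literal value: p for a positive literal, 1 - p for a negated one.
            values = [row[atom] if positive else 1.0 - row[atom]
                      for atom, positive in clause]
            # A or B = not (not A and not B): multiply the negated literal
            # values, then negate the product.
            degrees.append(1.0 - prod(1.0 - v for v in values))
    return 1.0 - sum(degrees) / len(degrees)


clauses = [[(0, True), (1, False)]]                        # y_0 or not y_1
violating = product_tnorm_penalty([[0.0, 1.0]], clauses)   # clause is false
satisfying = product_tnorm_penalty([[1.0, 1.0]], clauses)  # clause is true
soft = product_tnorm_penalty([[0.9, 0.2]], clauses)        # roughly 0.1 * 0.2
```

On crisp inputs the penalty is 1 for a violated clause and 0 for a satisfied one; in between it shrinks as any literal of the clause becomes more likely.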

Literal

literal

Propositional literals.

A literal is a single variable (atom) together with a sign: either the variable asserted positively or its negation. Literals are the atomic building blocks of clauses and constraints in the propositional requirements subpackage.

Literal

Literal(*args, reversed_sign=False)

A propositional literal: a variable (atom) and a polarity.

Attributes:

Name Type Description
atom

The integer index of the variable referenced by the literal.

positive

True if the literal asserts the variable, False if it negates it.

Build a literal from either an (atom, polarity) pair or a string.

Two calling conventions are supported:
  • Literal(atom: int, positive: bool) builds the literal directly.
  • Literal(text: str) parses a textual literal. Supported forms are 'y_3'/'not y_3' (label notation) and '3'/'n3' (index notation), where the n prefix denotes a negative literal.

Parameters:

Name Type Description Default
*args

Either two positional arguments (atom, positive) or a single string to parse.

()
reversed_sign

When parsing 'y_' style strings, flips the polarity so the sign matches the head :- body convention used elsewhere.

False
Source code in pishield/propositional_requirements/literal.py
def __init__(self, *args, reversed_sign=False):
    """Build a literal from either an (atom, polarity) pair or a string.

    Two calling conventions are supported:
      * ``Literal(atom: int, positive: bool)`` builds the literal directly.
      * ``Literal(text: str)`` parses a textual literal. Supported forms are
        ``'y_3'``/``'not y_3'`` (label notation) and ``'3'``/``'n3'`` (index
        notation), where the ``n`` prefix denotes a negative literal.

    Args:
        *args: Either two positional arguments ``(atom, positive)`` or a single
            string to parse.
        reversed_sign: When parsing ``'y_'`` style strings, flips the polarity so
            the sign matches the ``head :- body`` convention used elsewhere.
    """
    if len(args) == 2:
        # Literal(int, bool)
        self.atom = args[0]
        self.positive = args[1]
    else:
        # Literal(string)
        plain = args[0]
        if 'y_' in plain:
            if 'not ' in plain:
                self.atom = int(plain[6:])
                if reversed_sign:
                    self.positive = True   # flipped to True, to account for the :- format, which the code uses
                else:
                    self.positive = False  # a negated literal is negative by default
            else:
                self.atom = int(plain[2:])
                if reversed_sign:
                    self.positive = False  # set to False, to account for the :- format, which the code uses
                else:
                    self.positive = True
        else:
            if 'n' in plain:
                self.atom = int(plain[1:])
                self.positive = False
            else:
                self.atom = int(plain)
                self.positive = True
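
The parsing rules can be summarised with a small standalone function. This is a simplified sketch under the assumption that label-notation literals look like `y_3` and `not y_3`; `parse_literal` is a hypothetical helper that returns an `(atom, positive)` tuple rather than a `Literal`, and it splits on `y_` instead of using the fixed character offsets of the source above.

```python
def parse_literal(text, reversed_sign=False):
    """Return (atom, positive) for a textual literal (illustrative helper)."""
    if 'y_' in text:
        # Label notation: 'y_3' is positive, 'not y_3' is negative.
        atom = int(text.split('y_')[1])
        positive = not text.startswith('not ')
        # reversed_sign only affects the label notation.
        return atom, (not positive) if reversed_sign else positive
    if text.startswith('n'):
        # Index notation: the 'n' prefix marks a negative literal.
        return int(text[1:]), False
    return int(text), True


examples = {
    'y_3': parse_literal('y_3'),
    'not y_3': parse_literal('not y_3'),
    '3': parse_literal('3'),
    'n3': parse_literal('n3'),
}
flipped = parse_literal('not y_3', reversed_sign=True)
```

With `reversed_sign=True` the literal `not y_3` comes back as positive, which is what a disjunct needs when it is moved into the body of a `head :- body` rule.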

neg

neg()

Return a new literal with the opposite polarity.

Returns:

Type Description

A Literal over the same atom with flipped sign.

Source code in pishield/propositional_requirements/literal.py
def neg(self):
    """Return a new literal with the opposite polarity.

    Returns:
        A Literal over the same atom with flipped sign.
    """
    return Literal(self.atom, not self.positive)

Clause

clause

Propositional clauses.

A clause is a disjunction of literals (e.g. y_0 or not y_1), represented as an unordered set of :class:Literal objects. Clauses are the normalised form into which requirements are converted before stratification, and they support the logical operations (resolution, subsumption, coherency checks) used to build the Shield Layer.

Clause

Clause(literals)

A disjunction of literals.

Attributes:

Name Type Description
literals

A frozenset of the :class:Literal objects in the disjunction.

Build a clause from literals.

Parameters:

Name Type Description Default
literals

Either a whitespace-separated string of literals (e.g. '0 n1 2') or an iterable of :class:Literal objects.

required
Source code in pishield/propositional_requirements/clause.py
def __init__(self, literals):
    """Build a clause from literals.

    Args:
        literals: Either a whitespace-separated string of literals (e.g.
            ``'0 n1 2'``) or an iterable of :class:`Literal` objects.
    """
    if isinstance(literals, str):
        # Clause(string)
        literals = [Literal(lit) for lit in literals.split(' ')]
        self.literals = frozenset(literals)
    else:
        # Clause([Literals])
        self.literals = frozenset(literals)

from_constraint classmethod

from_constraint(constraint)

Build the clause equivalent to a Horn constraint head :- body.

The constraint head :- b1, b2 is logically head or not b1 or not b2, so each body literal is negated and the head kept as-is.

Parameters:

Name Type Description Default
constraint

The :class:Constraint to convert.

required

Returns:

Type Description

The equivalent Clause.

Source code in pishield/propositional_requirements/clause.py
@classmethod
def from_constraint(cls, constraint):
    """Build the clause equivalent to a Horn constraint ``head :- body``.

    The constraint ``head :- b1, b2`` is logically ``head or not b1 or not b2``,
    so each body literal is negated and the head kept as-is.

    Args:
        constraint: The :class:`Constraint` to convert.

    Returns:
        The equivalent Clause.
    """
    body = [lit.neg() for lit in constraint.body]
    return cls([constraint.head] + body)

random classmethod

random(num_classes)

Build a random clause over num_classes variables.

Parameters:

Name Type Description Default
num_classes

The number of available variables (atom indices).

required

Returns:

Type Description

A randomly generated Clause.

Source code in pishield/propositional_requirements/clause.py
@classmethod
def random(cls, num_classes):
    """Build a random clause over ``num_classes`` variables.

    Args:
        num_classes: The number of available variables (atom indices).

    Returns:
        A randomly generated Clause.
    """
    atoms_count = np.random.randint(low=1, high=num_classes, size=1)
    atoms = np.random.randint(num_classes, size=atoms_count)

    pos = atoms[np.random.randint(2, size=atoms_count) == 1]
    literals = [Literal(atom, atom in pos) for atom in atoms]
    return cls(literals)

shift_add_n0

shift_add_n0()

Shift every atom up by one and add the negative literal n0.

Used to make room for a detection variable at index 0.

Returns:

Type Description

A new Clause with shifted atoms plus the n0 literal.

Source code in pishield/propositional_requirements/clause.py
def shift_add_n0(self):
    """Shift every atom up by one and add the negative literal ``n0``.

    Used to make room for a detection variable at index 0.

    Returns:
        A new Clause with shifted atoms plus the ``n0`` literal.
    """
    n0 = Literal(0, False)
    return Clause([Literal(lit.atom + 1, lit.positive) for lit in self] + [n0])

fix_head

fix_head(head)

Turn the clause into a constraint by designating one literal as the head.

The remaining literals become the (negated) body, inverting :meth:from_constraint.

Parameters:

Name Type Description Default
head

The literal of this clause to use as the constraint head.

required

Returns:

Type Description

A :class:Constraint with the given head.

Raises:

Type Description
Exception

If head is not a literal of this clause.

Source code in pishield/propositional_requirements/clause.py
def fix_head(self, head):
    """Turn the clause into a constraint by designating one literal as the head.

    The remaining literals become the (negated) body, inverting
    :meth:`from_constraint`.

    Args:
        head: The literal of this clause to use as the constraint head.

    Returns:
        A :class:`Constraint` with the given head.

    Raises:
        Exception: If ``head`` is not a literal of this clause.
    """
    if head not in self.literals:
        raise Exception('Head not in clause')
    body = [lit.neg() for lit in self.literals if lit != head]
    return Constraint(head, body)
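
The relationship between from_constraint and fix_head can be illustrated without the library classes. In this sketch literals are `(atom, positive)` tuples and clauses are frozensets; the function names `clause_from_constraint` and `constraint_from_clause` are hypothetical stand-ins, not part of the package.

```python
def neg(lit):
    """Flip the polarity of an (atom, positive) literal."""
    atom, positive = lit
    return (atom, not positive)


def clause_from_constraint(head, body):
    # head :- b1, b2  is equivalent to  head or not b1 or not b2
    return frozenset([head] + [neg(lit) for lit in body])


def constraint_from_clause(clause, head):
    # Every literal other than the chosen head is negated and moved to the body.
    if head not in clause:
        raise ValueError('Head not in clause')
    return head, frozenset(neg(lit) for lit in clause if lit != head)


head, body = (0, True), [(1, True), (2, False)]    # y_0 :- y_1, not y_2
clause = clause_from_constraint(head, body)        # y_0 or not y_1 or y_2
round_trip = constraint_from_clause(clause, head)
other = constraint_from_clause(clause, (2, True))  # y_2 :- not y_0, y_1
```

The same clause yields a different but logically equivalent rule for each literal chosen as head, which is the freedom the stratification step relies on.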

always_true

always_true()

Return True if the clause is a tautology.

A clause is always true when it contains both a literal and its negation.

Returns:

Type Description

True if the clause is tautological, False otherwise.

Source code in pishield/propositional_requirements/clause.py
def always_true(self):
    """Return True if the clause is a tautology.

    A clause is always true when it contains both a literal and its negation.

    Returns:
        True if the clause is tautological, False otherwise.
    """
    for literal in self.literals:
        if literal.neg() in self.literals:
            return True
    return False

resolution_on

resolution_on(other, literal)

Resolve this clause with another on a given literal.

Parameters:

Name Type Description Default
other

The other clause to resolve with.

required
literal

The literal to resolve on; it and its negation are removed from the union of the two clauses.

required

Returns:

Type Description

The resolvent Clause, or None if the resolvent is a tautology.

Source code in pishield/propositional_requirements/clause.py
def resolution_on(self, other, literal):
    """Resolve this clause with another on a given literal.

    Args:
        other: The other clause to resolve with.
        literal: The literal to resolve on; it and its negation are removed
            from the union of the two clauses.

    Returns:
        The resolvent Clause, or None if the resolvent is a tautology.
    """
    result = self.literals.union(other.literals).difference({literal, literal.neg()})
    result = Clause(result)
    return None if result.always_true() else result
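
A short worked example of a resolution step, including the tautology case. This is a sketch on plain frozensets of `(atom, positive)` tuples; `resolve_on` and `is_tautology` are illustrative names mirroring resolution_on and always_true, not the library methods themselves.

```python
def neg(lit):
    atom, positive = lit
    return (atom, not positive)


def is_tautology(clause):
    # A clause is always true if it contains a literal and its negation.
    return any(neg(lit) in clause for lit in clause)


def resolve_on(c1, c2, pivot):
    # Union of both clauses, minus the pivot literal and its negation.
    resolvent = (c1 | c2) - {pivot, neg(pivot)}
    return None if is_tautology(resolvent) else resolvent


c1 = frozenset({(0, True), (1, False)})     # y_0 or not y_1
c2 = frozenset({(1, True), (2, True)})      # y_1 or y_2
resolvent = resolve_on(c1, c2, (1, True))   # y_0 or y_2

c3 = frozenset({(1, True), (0, False)})     # y_1 or not y_0
dropped = resolve_on(c1, c3, (1, True))     # y_0 or not y_0 is a tautology
```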

resolution

resolution(other, literal=None)

Resolve this clause with another, finding a pivot literal if needed.

Parameters:

Name Type Description Default
other

The other clause to resolve with.

required
literal

The literal to resolve on. If None, the first literal of this clause whose negation appears in other is used.

None

Returns:

Type Description

The resolvent Clause, or None if no pivot is found or the resolvent is a tautology.

Source code in pishield/propositional_requirements/clause.py
def resolution(self, other, literal=None):
    """Resolve this clause with another, finding a pivot literal if needed.

    Args:
        other: The other clause to resolve with.
        literal: The literal to resolve on. If None, the first literal of this
            clause whose negation appears in ``other`` is used.

    Returns:
        The resolvent Clause, or None if no pivot is found or the resolvent is
        a tautology.
    """
    if literal is not None:
        return self.resolution_on(other, literal)

    for lit in self.literals:
        if lit.neg() in other.literals:
            return self.resolution_on(other, lit)

    return None

always_false

always_false()

Return True if the clause is empty (and hence unsatisfiable).

Source code in pishield/propositional_requirements/clause.py
def always_false(self):
    """Return True if the clause is empty (and hence unsatisfiable)."""
    return len(self) == 0

coherent_with

coherent_with(preds)

Check whether predictions satisfy this clause.

A clause is satisfied when at least one of its literals is true; here this is relaxed to a probability above 0.5 for any literal.

Parameters:

Name Type Description Default
preds

A 2D array of predicted probabilities, shape (batch, num_classes).

required

Returns:

Type Description

A boolean array of length batch, True where the clause is satisfied.

Source code in pishield/propositional_requirements/clause.py
def coherent_with(self, preds):
    """Check whether predictions satisfy this clause.

    A clause is satisfied when at least one of its literals is true; here this is
    relaxed to a probability above 0.5 for any literal.

    Args:
        preds: A 2D array of predicted probabilities, shape (batch, num_classes).

    Returns:
        A boolean array of length batch, True where the clause is satisfied.
    """
    pos = [lit.atom for lit in self.literals if lit.positive]
    neg = [lit.atom for lit in self.literals if not lit.positive]

    preds = np.concatenate((preds[:, pos], 1 - preds[:, neg]), axis=1)
    preds = preds.max(axis=1)
    return preds > 0.5
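
The thresholded check can be reproduced on plain lists. The sketch below assumes the same rule as the source above (maximum literal value strictly greater than 0.5); `clause_satisfied` is a hypothetical helper and uses Python lists instead of NumPy arrays.

```python
def clause_satisfied(preds, clause, threshold=0.5):
    """Return one boolean per row: is any literal value above the threshold?"""
    flags = []
    for row in preds:
        values = [row[atom] if positive else 1.0 - row[atom]
                  for atom, positive in clause]
        flags.append(max(values) > threshold)
    return flags


clause = [(0, True), (1, False)]   # y_0 or not y_1
preds = [
    [0.9, 0.9],   # y_0 is likely: satisfied
    [0.1, 0.9],   # y_0 unlikely and y_1 likely: violated
    [0.1, 0.2],   # y_1 unlikely, so "not y_1" holds: satisfied
    [0.5, 0.5],   # exactly 0.5 does not count as satisfied
]
flags = clause_satisfied(preds, clause)
```

Note the strict comparison: a row whose best literal sits at exactly 0.5 is reported as not satisfying the clause.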

is_subset

is_subset(other)

Return True if this clause's literals are a subset of other's.

Parameters:

Name Type Description Default
other

The clause to test against.

required

Returns:

Type Description

True if every literal of this clause is in other.

Source code in pishield/propositional_requirements/clause.py
def is_subset(self, other):
    """Return True if this clause's literals are a subset of ``other``'s.

    Args:
        other: The clause to test against.

    Returns:
        True if every literal of this clause is in ``other``.
    """
    return self.literals.issubset(other.literals)

atoms

atoms()

Return the set of variable indices appearing in the clause.

Source code in pishield/propositional_requirements/clause.py
def atoms(self):
    """Return the set of variable indices appearing in the clause."""
    return {lit.atom for lit in self.literals}

Clauses group

clauses_group

Collections of propositional clauses.

A :class:ClausesGroup holds a set of :class:Clause disjunctions and provides the machinery that turns requirements into the Shield Layer: variable elimination by resolution, clause compaction, centrality-based ordering, and stratification into ordered :class:ConstraintsGroup layers (optionally enforcing strong coherency).

ClausesGroup

ClausesGroup(clauses)

A set of clauses supporting resolution and stratification.

Attributes:

Name Type Description
clauses

A frozenset of the :class:Clause objects in the group.

clauses_list

The clauses in their original order, kept so that coherency results have a stable column order.

Build a clauses group from an iterable of clauses.

Parameters:

Name Type Description Default
clauses

An iterable of :class:Clause objects.

required
Source code in pishield/propositional_requirements/clauses_group.py
def __init__(self, clauses):
    """Build a clauses group from an iterable of clauses.

    Args:
        clauses: An iterable of :class:`Clause` objects.
    """
    # ClausesGroup([Clause])
    self.clauses = frozenset(clauses)
    self.clauses_list = clauses

from_constraints_group classmethod

from_constraints_group(group)

Build a clauses group from a :class:ConstraintsGroup.

Parameters:

Name Type Description Default
group

The constraints group to convert (each constraint becomes a clause).

required

Returns:

Type Description

The equivalent ClausesGroup.

Source code in pishield/propositional_requirements/clauses_group.py
@classmethod
def from_constraints_group(cls, group):
    """Build a clauses group from a :class:`ConstraintsGroup`.

    Args:
        group: The constraints group to convert (each constraint becomes a clause).

    Returns:
        The equivalent ClausesGroup.
    """
    return cls([Clause.from_constraint(cons) for cons in group])

random classmethod

random(max_clauses, num_classes, coherent_with=np.array([]), min_clauses=0)

Build a random clauses group, optionally filtered by coherency.

Generates up to max_clauses random clauses; if coherent_with is given, only clauses satisfied by every one of those predictions are kept, and generation is repeated until at least min_clauses survive.

Parameters:

Name Type Description Default
max_clauses

The number of random clauses to generate per attempt.

required
num_classes

The number of variables available.

required
coherent_with

Optional array of predictions the clauses must satisfy.

array([])
min_clauses

The minimum number of clauses required in the result.

0

Returns:

Type Description

A randomly generated ClausesGroup.

Source code in pishield/propositional_requirements/clauses_group.py
@classmethod
def random(cls, max_clauses, num_classes, coherent_with=np.array([]), min_clauses=0):
    """Build a random clauses group, optionally filtered by coherency.

    Generates up to ``max_clauses`` random clauses; if ``coherent_with`` is given,
    only clauses satisfied by every one of those predictions are kept, and
    generation is repeated until at least ``min_clauses`` survive.

    Args:
        max_clauses: The number of random clauses to generate per attempt.
        num_classes: The number of variables available.
        coherent_with: Optional array of predictions the clauses must satisfy.
        min_clauses: The minimum number of clauses required in the result.

    Returns:
        A randomly generated ClausesGroup.
    """
    assert min_clauses <= max_clauses
    clauses = [Clause.random(num_classes) for i in range(max_clauses)]

    if len(coherent_with) > 0:
        keep = cls(clauses).coherent_with(coherent_with).all(axis=0)
        clauses = np.array(clauses)[keep].tolist()

    found = len(clauses)
    if found < min_clauses:
        other = cls.random(max_clauses - found, num_classes, coherent_with=coherent_with,
                           min_clauses=min_clauses - found)
        return cls(clauses) + other
    else:
        return cls(clauses)

add_detection_label

add_detection_label(forced=False)

Insert a detection variable at index 0 and shift the others up.

Each clause's atoms are shifted up by one and the negative literal n0 added (so the clause is vacuously satisfied when the detection variable is off). When forced, a clause 0 n{x+1} is also added for every atom x, so that any active label implies the detection variable.

Parameters:

Name Type Description Default
forced

If True, also add 0 n{x+1} clauses for every atom.

False

Returns:

Type Description

A new ClausesGroup including the detection variable.

Source code in pishield/propositional_requirements/clauses_group.py
def add_detection_label(self, forced=False):
    """Insert a detection variable at index 0 and shift the others up.

    Each clause's atoms are shifted up by one and the negative literal ``n0`` added
    (so the clause is vacuously satisfied when the detection variable is off).
    When ``forced``, a clause ``0 n{x+1}`` is also added for every atom ``x``, so
    that any active label implies the detection variable.

    Args:
        forced: If True, also add ``0 n{x+1}`` clauses for every atom.

    Returns:
        A new ClausesGroup including the detection variable.
    """
    clauses = [clause.shift_add_n0() for clause in self]
    forced = [Clause(f"0 n{x + 1}") for x in self.atoms()] if forced else []
    return ClausesGroup(clauses + forced)
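
The effect of the shift and of the `forced` flag is easiest to see on a single clause. This is a standalone sketch on frozensets of `(atom, positive)` tuples; the free function `add_detection_label` below only imitates the method of the same name and is not the library code.

```python
def add_detection_label(clauses, forced=False):
    n0 = (0, False)  # "detection variable is off"
    # Shift every atom up by one and add n0 to each clause.
    shifted = [
        frozenset({(atom + 1, positive) for atom, positive in clause} | {n0})
        for clause in clauses
    ]
    extra = []
    if forced:
        atoms = sorted({atom for clause in clauses for atom, _ in clause})
        # "0 n{x+1}": y_0 or not y_{x+1}, i.e. label x+1 implies detection.
        extra = [frozenset({(0, True), (atom + 1, False)}) for atom in atoms]
    return shifted + extra


clauses = [frozenset({(0, True), (1, False)})]          # y_0 or not y_1
plain = add_detection_label(clauses)
with_forced = add_detection_label(clauses, forced=True)
```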

compacted

compacted()

Remove clauses subsumed by a smaller clause.

Processes clauses from longest to shortest; whenever a clause is a subset of one kept so far, the longer clause is dropped because the shorter one subsumes it.

Returns:

Type Description

A new, subsumption-free ClausesGroup.

Source code in pishield/propositional_requirements/clauses_group.py
def compacted(self):
    """Remove clauses subsumed by a smaller clause.

    Processes clauses from longest to shortest; whenever a clause is a subset of
    one kept so far, the longer clause is dropped because the shorter one subsumes it.

    Returns:
        A new, subsumption-free ClausesGroup.
    """
    clauses = list(self.clauses)
    clauses.sort(reverse=True, key=len)
    compacted = []

    for clause in clauses:
        compacted = [c for c in compacted if not clause.is_subset(c)]
        compacted.append(clause)

    # print(f"compacted {len(clauses) - len(compacted)} out of {len(clauses)}")
    return ClausesGroup(compacted)
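
A small example of subsumption removal, written as a sketch on frozensets of `(atom, positive)` tuples. The function `compact` follows the same longest-first loop as the source above but is an illustrative stand-in, not the library method.

```python
def compact(clauses):
    kept = []
    # Longest clauses first, so a shorter clause can evict the longer
    # clauses it subsumes when its turn comes.
    for clause in sorted(set(clauses), key=len, reverse=True):
        kept = [c for c in kept if not clause <= c]
        kept.append(clause)
    return set(kept)


a = frozenset({(0, True)})                 # y_0
b = frozenset({(0, True), (1, False)})     # y_0 or not y_1 (implied by a)
c = frozenset({(1, True), (2, True)})      # y_1 or y_2
result = compact([a, b, c])
```

Clause `b` disappears because any assignment satisfying `a` already satisfies `b`; `c` shares no subset relation with the others and is kept.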

resolution

resolution(atom)

Eliminate a variable by resolution, returning its constraints.

Splits the clauses into those containing the positive literal, the negative literal, and neither; resolves each positive against each negative clause to produce the remaining clauses (compacted), and turns the positive and negative clauses into constraints whose head is the eliminated literal.

Parameters:

Name Type Description Default
atom

The variable index to eliminate.

required

Returns:

Type Description

A tuple (constraints, next_clauses) where constraints is a :class:ConstraintsGroup defining the eliminated atom and next_clauses is the remaining :class:ClausesGroup over the other atoms.

Source code in pishield/propositional_requirements/clauses_group.py
def resolution(self, atom):
    """Eliminate a variable by resolution, returning its constraints.

    Splits the clauses into those containing the positive literal, the negative
    literal, and neither; resolves each positive against each negative clause to
    produce the remaining clauses (compacted), and turns the positive and negative
    clauses into constraints whose head is the eliminated literal.

    Args:
        atom: The variable index to eliminate.

    Returns:
        A tuple ``(constraints, next_clauses)`` where ``constraints`` is a
        :class:`ConstraintsGroup` defining the eliminated atom and ``next_clauses``
        is the remaining :class:`ClausesGroup` over the other atoms.
    """
    pos = Literal(atom, True)
    neg = Literal(atom, False)

    # Split clauses in three categories
    pos_clauses, neg_clauses, other_clauses = set(), set(), set()
    for clause in self.clauses:
        if pos in clause:
            pos_clauses.add(clause)
        elif neg in clause:
            neg_clauses.add(clause)
        else:
            other_clauses.add(clause)

    # Apply resolution on positive and negative clauses
    resolution_clauses = [c1.resolution(c2, literal=pos) for c1 in pos_clauses for c2 in neg_clauses]
    resolution_clauses = {clause for clause in resolution_clauses if clause is not None}
    next_clauses = ClausesGroup(other_clauses.union(resolution_clauses)).compacted()

    # Compute constraints 
    pos_constraints = [clause.fix_head(pos) for clause in pos_clauses]
    neg_constraints = [clause.fix_head(neg) for clause in neg_clauses]
    constraints = ConstraintsGroup(pos_constraints + neg_constraints)

    return constraints, next_clauses
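
One elimination step can be traced on two clauses. The sketch below uses `(atom, positive)` tuples and frozensets, represents each constraint as a `(head, body)` pair, and omits the compaction of the remaining clauses; `eliminate` is a hypothetical name for illustration.

```python
def neg(lit):
    atom, positive = lit
    return (atom, not positive)


def eliminate(clauses, atom):
    pos, negl = (atom, True), (atom, False)
    # Split the clauses in three categories.
    pos_clauses = [c for c in clauses if pos in c]
    neg_clauses = [c for c in clauses if pos not in c and negl in c]
    others = {c for c in clauses if pos not in c and negl not in c}

    # Resolve every positive clause against every negative clause.
    resolvents = set()
    for c1 in pos_clauses:
        for c2 in neg_clauses:
            r = (c1 | c2) - {pos, negl}
            if not any(neg(lit) in r for lit in r):   # skip tautologies
                resolvents.add(r)

    # The clauses mentioning the atom become rules with that atom as head.
    constraints = [(pos, frozenset(neg(l) for l in c if l != pos))
                   for c in pos_clauses]
    constraints += [(negl, frozenset(neg(l) for l in c if l != negl))
                    for c in neg_clauses]
    return constraints, others | resolvents


clauses = [
    frozenset({(2, True), (0, False)}),    # y_2 or not y_0
    frozenset({(2, False), (1, True)}),    # not y_2 or y_1
]
constraints, remaining = eliminate(clauses, atom=2)
```

Eliminating atom 2 produces the rules `y_2 :- y_0` and `not y_2 :- not y_1`, and leaves the single resolvent `not y_0 or y_1` over the other atoms.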

graph

graph()

Build a bipartite graph linking clauses to their atoms.

Returns:

Type Description

A networkx.Graph with kind='atom' and kind='clause' nodes and an edge between each clause and every atom it mentions.

Source code in pishield/propositional_requirements/clauses_group.py
def graph(self):
    """Build a bipartite graph linking clauses to their atoms.

    Returns:
        A ``networkx.Graph`` with ``kind='atom'`` and ``kind='clause'`` nodes and
        an edge between each clause and every atom it mentions.
    """
    G = nx.Graph()
    G.add_nodes_from(self.atoms(), kind='atom')
    G.add_nodes_from(self.clauses, kind='clause')

    for clause in self.clauses:
        for lit in clause:
            G.add_edge(clause, lit.atom)

    return G

centrality_measures staticmethod

centrality_measures()

Return the names of the supported centrality measures.

Returns:

Type Description

A list of measure names usable with :meth:centrality.

Source code in pishield/propositional_requirements/clauses_group.py
@staticmethod
def centrality_measures():
    """Return the names of the supported centrality measures.

    Returns:
        A list of measure names usable with :meth:`centrality`.
    """
    return ['degree', 'eigenvector', 'katz', 'closeness', 'betweenness']

centrality

centrality(centrality)

Compute a centrality score per atom to guide the elimination order.

Parameters:

Name Type Description Default
centrality

The centrality measure name (one of :meth:centrality_measures), optionally prefixed with 'rev-' to invert the resulting scores.

required

Returns:

Type Description

A dict mapping each graph node to its (possibly reversed) centrality score.

Raises:

Type Description
Exception

If the centrality measure name is unknown.

Source code in pishield/propositional_requirements/clauses_group.py
def centrality(self, centrality):
    """Compute a centrality score per atom to guide the elimination order.

    Args:
        centrality: The centrality measure name (one of
            :meth:`centrality_measures`), optionally prefixed with ``'rev-'`` to
            invert the resulting scores.

    Returns:
        A dict mapping each graph node to its (possibly reversed) centrality score.

    Raises:
        Exception: If the centrality measure name is unknown.
    """
    G = self.graph()

    if centrality.startswith('rev-'):
        centrality = centrality[4:]
        rev = True
    else:
        rev = False

    if centrality == 'degree':
        result = nx.algorithms.centrality.degree_centrality(G)
    elif centrality == 'eigenvector':
        result = nx.algorithms.centrality.eigenvector_centrality_numpy(G)
    elif centrality == 'katz':
        result = nx.algorithms.centrality.katz_centrality_numpy(G)
    elif centrality == 'closeness':
        result = nx.algorithms.centrality.closeness_centrality(G)
    elif centrality == 'betweenness':
        result = nx.algorithms.centrality.betweenness_centrality(G)
    else:
        raise Exception(f"Unknown centrality {centrality}")

    # Reverse the scores (the most central node gets the lowest score and vice versa)
    if rev:
        values = np.array([result[node] for node in result])
        mini, maxi = values.min(), values.max()
        for node in result:
            result[node] = maxi - (result[node] - mini)

    return result
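
To make the `'rev-'` prefix concrete, the following sketch computes degree centrality by hand (degree divided by the number of nodes minus one, the usual normalisation) and then mirrors the scores. It is an assumption-laden simplification: it avoids networkx, scores only the atom nodes, and takes the minimum and maximum over those atoms, whereas the source above works on every node of the graph. Both function names are hypothetical.

```python
def atom_degree_centrality(clauses):
    """Degree centrality of each atom in the clause/atom bipartite graph."""
    atoms = {atom for clause in clauses for atom, _ in clause}
    n_nodes = len(atoms) + len(clauses)          # atom nodes + clause nodes
    degree = {atom: 0 for atom in atoms}
    for clause in clauses:
        for atom in {a for a, _ in clause}:      # one edge per (clause, atom)
            degree[atom] += 1
    return {atom: d / (n_nodes - 1) for atom, d in degree.items()}


def reverse_scores(scores):
    """Mirror the scores inside [min, max], as the 'rev-' prefix does."""
    mini, maxi = min(scores.values()), max(scores.values())
    return {node: maxi - (value - mini) for node, value in scores.items()}


clauses = [
    frozenset({(0, True), (1, False)}),    # y_0 or not y_1
    frozenset({(0, True), (2, True)}),     # y_0 or y_2
    frozenset({(0, False), (2, False)}),   # not y_0 or not y_2
]
scores = atom_degree_centrality(clauses)
reversed_scores = reverse_scores(scores)
order = sorted(scores, key=scores.get)                        # least central first
reversed_order = sorted(reversed_scores, key=reversed_scores.get)
```

Atom 0 appears in all three clauses and is the most central; after reversal it receives the lowest score, so the ranking is flipped.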

stratify

stratify(centrality)

Convert the clauses into ordered constraint strata for the Shield Layer.

Repeatedly eliminates atoms (in an order determined by centrality) via resolution, accumulating the resulting constraints and optionally rewriting them to enforce strong coherency, then stratifies the accumulated constraints.

Parameters:

Name Type Description Default
centrality

Either a centrality measure name (str) used to order atoms, an explicit iterable of atom indices giving the order, or None to use the order atoms appear in the constraints.

required

Returns:

Type Description

A list of :class:ConstraintsGroup strata (the output of :meth:ConstraintsGroup.stratify).

Raises:

Type Description
Exception

If the clauses are unsatisfiable (clauses remain after eliminating every atom).

Source code in pishield/propositional_requirements/clauses_group.py
def stratify(self, centrality):
    """Convert the clauses into ordered constraint strata for the Shield Layer.

    Repeatedly eliminates atoms (in an order determined by ``centrality``) via
    resolution, accumulating the resulting constraints and optionally rewriting
    them to enforce strong coherency, then stratifies the accumulated constraints.

    Args:
        centrality: Either a centrality measure name (str) used to order atoms, an
            explicit iterable of atom indices giving the order, or None to use the
            order atoms appear in the constraints.

    Returns:
        A list of :class:`ConstraintsGroup` strata (the output of
        :meth:`ConstraintsGroup.stratify`).

    Raises:
        Exception: If the clauses are unsatisfiable (clauses remain after
            eliminating every atom).
    """
    # Centrality guides the inference order
    if not isinstance(centrality, str):
        atoms = centrality
    else:
        centrality = self.centrality(centrality)
        atoms = list(self.atoms())
        atoms.sort(key=lambda x: centrality[x])
    if centrality is None:  # get atoms in the order they appear in constraints
        atoms = list(self.atoms())

    # Apply resolution repeatedly
    atoms = atoms[::-1]
    group = ConstraintsGroup([])
    clauses = self

    for atom in atoms:
        # print(f"Eliminating %{atom} from %{len(clauses)} clauses\n")
        constraints, clauses = clauses.resolution(atom)
        if len(constraints.constraints_list):
            strongly_coherent_constraints = strong_coherency_constraint_preprocessing(constraints.constraints_list, atoms)
            if strongly_coherent_constraints is not None:
                constraints = strongly_coherent_constraints
        group = group + constraints

    if len(clauses):
        raise Exception("Unsatisfiable set of clauses")

    return group.stratify()
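
The elimination loop and its unsatisfiability check can be sketched in isolation. This example assumes literals are `(atom, positive)` tuples and constraints are `(head, body)` pairs; it takes the elimination order directly and leaves out compaction, the strong-coherency preprocessing and the final `ConstraintsGroup.stratify()` call. `eliminate` and `eliminate_all` are hypothetical names.

```python
def neg(lit):
    atom, positive = lit
    return (atom, not positive)


def eliminate(clauses, atom):
    """One resolution step: returns (constraints, remaining clauses)."""
    pos, negl = (atom, True), (atom, False)
    pos_cl = [c for c in clauses if pos in c]
    neg_cl = [c for c in clauses if pos not in c and negl in c]
    remaining = {c for c in clauses if pos not in c and negl not in c}
    for c1 in pos_cl:
        for c2 in neg_cl:
            resolvent = (c1 | c2) - {pos, negl}
            if not any(neg(lit) in resolvent for lit in resolvent):
                remaining.add(resolvent)
    constraints = [(head, frozenset(neg(l) for l in c if l != head))
                   for head, group in ((pos, pos_cl), (negl, neg_cl))
                   for c in group]
    return constraints, remaining


def eliminate_all(clauses, order):
    collected, remaining = [], set(clauses)
    for atom in order:
        constraints, remaining = eliminate(remaining, atom)
        collected.extend(constraints)
    if remaining:
        # Only the empty clause can survive once every atom is eliminated.
        raise ValueError('Unsatisfiable set of clauses')
    return collected


satisfiable = [frozenset({(1, True), (0, False)})]            # y_1 or not y_0
rules = eliminate_all(satisfiable, order=[1, 0])              # y_1 :- y_0

contradictory = [frozenset({(0, True)}), frozenset({(0, False)})]
try:
    eliminate_all(contradictory, order=[0])
    unsat = False
except ValueError:
    unsat = True
```

Resolving `y_0` against `not y_0` leaves the empty clause, which no later elimination can remove; that leftover is what triggers the error.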

coherent_with

coherent_with(preds)

Check which clauses each prediction satisfies.

Parameters:

Name Type Description Default
preds

A 2D array of predicted probabilities, shape (batch, num_classes).

required

Returns:

Type Description

A boolean array of shape (batch, num_clauses), True where the corresponding clause is satisfied (columns in insertion order).

Source code in pishield/propositional_requirements/clauses_group.py
def coherent_with(self, preds):
    """Check which clauses each prediction satisfies.

    Args:
        preds: A 2D array of predicted probabilities, shape (batch, num_classes).

    Returns:
        A boolean array of shape (batch, num_clauses), True where the corresponding
        clause is satisfied (columns in insertion order).
    """
    answer = [clause.coherent_with(preds) for clause in self.clauses_list]
    answer = np.array(answer).reshape(len(self.clauses_list), preds.shape[0])
    return answer.transpose()

atoms

atoms()

Return the set of all variable indices used across the clauses.

Source code in pishield/propositional_requirements/clauses_group.py
def atoms(self):
    """Return the set of all variable indices used across the clauses."""
    result = set()
    for clause in self.clauses:
        result = result.union(clause.atoms())
    return result

Constraint

constraint

Propositional constraints (Horn rules).

A constraint is a Horn rule of the form head :- body, where head is a single literal and body is a conjunction of literals: the rule states that if every body literal holds then the head must hold. Constraints can be parsed from the head :- body or disjunctive (y_0 or not y_1) textual forms and are the requirement representation that the Shield Layer and Memory-efficient Loss consume.

Constraint

Constraint(*args)

A Horn rule head :- body over propositional literals.

Attributes:

Name Type Description
head

The :class:Literal implied by the constraint.

body

A frozenset of :class:Literal objects forming the conjunction that must hold for the head to be implied.

Build a constraint from a head/body pair or from a textual rule.

Two calling conventions are supported:
  • Constraint(head: Literal, body: Iterable[Literal]) builds it directly.
  • Constraint(text: str) parses a rule in either head :- body form or disjunctive head or lit or ... form (the latter parsed with reversed signs to match the :- convention).

Parameters:

Name Type Description Default
*args

Either two positional arguments (head, body) or a single string.

()
Source code in pishield/propositional_requirements/constraint.py
def __init__(self, *args):
    """Build a constraint from a head/body pair or from a textual rule.

    Two calling conventions are supported:
      * ``Constraint(head: Literal, body: Iterable[Literal])`` builds it directly.
      * ``Constraint(text: str)`` parses a rule in either ``head :- body`` form or
        disjunctive ``head or lit or ...`` form (the latter parsed with reversed
        signs to match the ``:-`` convention).

    Args:
        *args: Either two positional arguments ``(head, body)`` or a single string.
    """
    if len(args) == 2:
        # Constraint(Literal, [Literal])
        self.head = args[0]
        self.body = frozenset(args[1])
    else:
        # Constraint(string)
        if ':-' in args[0]:
            line = args[0].split(' ')
            if line[2] == ':-':
                line = line[1:]
            assert line[1] == ':-'
            self.head = Literal(line[0])
            self.body = frozenset(Literal(lit) for lit in line[2:])
        elif 'or' in args[0]:
            # Disjunctive form: the first literal is the head, the rest form the body with flipped signs
            line = args[0].split(' or ')
            self.head = Literal(line[0])
            self.body = frozenset(Literal(lit, reversed_sign=True) for lit in line[1:])
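
The sign reversal in the disjunctive form can be illustrated with a minimal, self-contained sketch. It does not use the library's Literal class: parse_literal and parse_disjunction are hypothetical helpers, and the y_k / not y_k literal syntax is an assumption taken from the disjunctive example (y_0 or not y_1) given for this backend.

```python
def parse_literal(text, reversed_sign=False):
    """Hypothetical mini-parser: 'y_3' -> (3, True), 'not y_3' -> (3, False)."""
    text = text.strip()
    positive = not text.startswith("not ")
    atom = int(text.split("_")[-1])
    # Flip the polarity when the literal is moved from the clause into the body.
    return atom, positive != reversed_sign


def parse_disjunction(text):
    """Split a clause on ' or ': first literal is the head, the rest the body."""
    first, *rest = text.split(" or ")
    head = parse_literal(first)
    body = frozenset(parse_literal(lit, reversed_sign=True) for lit in rest)
    return head, body


head, body = parse_disjunction("y_0 or not y_1 or y_2")
# head is (0, True); body holds (1, True) and (2, False)
```

Read as a Horn rule, y_0 or not y_1 or y_2 therefore becomes y_0 :- y_1, not y_2.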

head_encoded

head_encoded(num_classes)

One-hot encode the head literal into positive and negative vectors.

Parameters:

Name Type Description Default
num_classes

The number of variables (length of the encoding vectors).

required

Returns:

Type Description

A tuple (pos_head, neg_head) of length-num_classes arrays; exactly one entry across the two arrays is set to 1, at the head's atom index in the array matching its polarity.

Source code in pishield/propositional_requirements/constraint.py
def head_encoded(self, num_classes):
    """One-hot encode the head literal into positive and negative vectors.

    Args:
        num_classes: The number of variables (length of the encoding vectors).

    Returns:
        A tuple ``(pos_head, neg_head)`` of length-``num_classes`` arrays; exactly
        one entry across the two arrays is set to 1, at the head's atom index in
        the array matching its polarity.
    """
    pos_head = np.zeros(num_classes)
    neg_head = np.zeros(num_classes)
    if self.head.positive:
        pos_head[self.head.atom] = 1
    else:
        neg_head[self.head.atom] = 1
    return pos_head, neg_head

body_encoded

body_encoded(num_classes)

Multi-hot encode the body literals into positive and negative vectors.

Parameters:

Name Type Description Default
num_classes

The number of variables (length of the encoding vectors).

required

Returns:

Type Description

A tuple (pos_body, neg_body) of length-num_classes integer arrays marking which atoms appear positively and negatively in the body.

Source code in pishield/propositional_requirements/constraint.py
def body_encoded(self, num_classes):
    """Multi-hot encode the body literals into positive and negative vectors.

    Args:
        num_classes: The number of variables (length of the encoding vectors).

    Returns:
        A tuple ``(pos_body, neg_body)`` of length-``num_classes`` integer arrays
        marking which atoms appear positively and negatively in the body.
    """
    pos_body = np.zeros(num_classes, dtype=int)
    neg_body = np.zeros(num_classes, dtype=int)
    for lit in self.body:
        if lit.positive:
            pos_body[lit.atom] = 1
        else:
            neg_body[lit.atom] = 1
    return pos_body, neg_body
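
As an illustration of the one-hot and multi-hot encodings, the sketch below encodes the rule 0 :- 1, not 2 over four variables. It uses plain lists and (atom, positive) pairs instead of NumPy arrays and Literal objects; the encode helper is hypothetical and not part of the library.

```python
def encode(literals, num_classes):
    """Return (pos, neg) indicator lists for a collection of (atom, positive) pairs."""
    pos = [0] * num_classes
    neg = [0] * num_classes
    for atom, positive in literals:
        if positive:
            pos[atom] = 1
        else:
            neg[atom] = 1
    return pos, neg


head = (0, True)                 # the head literal: variable 0, positive
body = [(1, True), (2, False)]   # the body: variable 1 and the negation of variable 2

pos_head, neg_head = encode([head], 4)   # [1, 0, 0, 0], [0, 0, 0, 0]
pos_body, neg_body = encode(body, 4)     # [0, 1, 0, 0], [0, 0, 1, 0]
```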

where

where(cond, opt1, opt2)

Differentiable select between two options.

Parameters:

Name Type Description Default
cond

A 0/1 mask (or probability) selecting between the options.

required
opt1

The value used where cond is 1.

required
opt2

The value used where cond is 0.

required

Returns:

Type Description

opt2 + cond * (opt1 - opt2).

Source code in pishield/propositional_requirements/constraint.py
def where(self, cond, opt1, opt2):
    """Differentiable select between two options.

    Args:
        cond: A 0/1 mask (or probability) selecting between the options.
        opt1: The value used where ``cond`` is 1.
        opt2: The value used where ``cond`` is 0.

    Returns:
        ``opt2 + cond * (opt1 - opt2)``.
    """
    return opt2 + cond * (opt1 - opt2)
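
A short worked example of the formula, written as a standalone function rather than a method. The input values are chosen purely for illustration.

```python
def where(cond, opt1, opt2):
    # Linear blend: cond = 1 yields opt1, cond = 0 yields opt2,
    # and values in between interpolate, which keeps the select differentiable.
    return opt2 + cond * (opt1 - opt2)


picked_first = where(1.0, 0.75, 0.25)    # 0.75
picked_second = where(0.0, 0.75, 0.25)   # 0.25
blended = where(0.5, 0.75, 0.25)         # 0.5
```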

coherent_with

coherent_with(preds)

Check whether predictions satisfy this constraint.

The body truth value is the minimum over its literals (a Goedel-style conjunction); the constraint is satisfied when body <= head.

Parameters:

Name Type Description Default
preds

A 2D array of predicted probabilities, shape (batch, num_classes).

required

Returns:

Type Description

A boolean array of length batch, True where the constraint is satisfied.

Source code in pishield/propositional_requirements/constraint.py
def coherent_with(self, preds):
    """Check whether predictions satisfy this constraint.

    The body truth value is the minimum over its literals (a Goedel-style
    conjunction); the constraint is satisfied when ``body <= head``.

    Args:
        preds: A 2D array of predicted probabilities, shape (batch, num_classes).

    Returns:
        A boolean array of length batch, True where the constraint is satisfied.
    """
    num_classes = preds.shape[1]
    pos_body, neg_body = self.body_encoded(num_classes)
    pos_body = preds[:, pos_body.astype(bool)]
    neg_body = 1 - preds[:, neg_body.astype(bool)]
    body = np.min(np.concatenate((pos_body, neg_body), axis=1), axis=1)

    head = preds[:, self.head.atom]
    if not self.head.positive:
        head = 1 - head

    return body <= head
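
The check can be traced by hand with a small pure-Python sketch. It assumes literals are written as (atom, positive) pairs and processes one row at a time instead of a NumPy batch; the function below is an illustration, not the library method.

```python
def coherent_with(row, head, body):
    """row: list of probabilities; head and body literals are (atom, positive) pairs."""
    def value(lit):
        atom, positive = lit
        return row[atom] if positive else 1 - row[atom]

    body_value = min(value(lit) for lit in body)  # Goedel (min) conjunction
    return body_value <= value(head)


head, body = (0, True), [(1, True), (2, False)]   # the rule 0 :- 1, not 2
preds = [
    [0.9, 0.8, 0.1],   # body = min(0.8, 0.9) = 0.8 <= 0.9, satisfied
    [0.2, 0.8, 0.1],   # body = 0.8 > 0.2, violated
    [0.2, 0.8, 0.9],   # body is about 0.1 <= 0.2, satisfied
]
results = [coherent_with(row, head, body) for row in preds]
```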

atoms

atoms()

Return the set of variable indices in the head and body.

Source code in pishield/propositional_requirements/constraint.py
def atoms(self):
    """Return the set of variable indices in the head and body."""
    return {lit.atom for lit in self.body.union({self.head})}

Constraints group

constraints_group

Collections of propositional constraints.

A :class:ConstraintsGroup bundles a set of :class:Constraint Horn rules, supports encoding them into matrices, checking coherency of predictions, building dependency graphs over the variables, and stratifying the constraints into ordered layers so that each head is only resolved after the variables it depends on.

ConstraintsGroup

ConstraintsGroup(arg)

A set of Horn constraints with encoding and stratification utilities.

Attributes:

Name Type Description
constraints

A frozenset of the :class:Constraint objects in the group.

constraints_list

The constraints in their original insertion order, kept so that coherency results have a stable column order.

Build a constraints group from a file or a list of constraints.

Parameters:

Name Type Description Default
arg

Either a path to a constraints file (one head :- body rule per line) or an iterable of :class:Constraint objects.

required
Source code in pishield/propositional_requirements/constraints_group.py
def __init__(self, arg):
    """Build a constraints group from a file or a list of constraints.

    Args:
        arg: Either a path to a constraints file (one ``head :- body`` rule per
            line) or an iterable of :class:`Constraint` objects.
    """
    if isinstance(arg, str):
        # ConstraintGroup(string)
        with open(arg, 'r') as f:
            self.constraints = [Constraint(line) for line in f]
    else:
        # ConstraintGroup([Constraint])
        self.constraints = arg

    # Keep the initial order of constraints for coherent_with
    self.constraints_list = self.constraints
    self.constraints = frozenset(self.constraints_list)

head_encoded

head_encoded(num_classes)

Encode all constraint heads into stacked positive/negative matrices.

Parameters:

Name Type Description Default
num_classes

The number of variables (row width of the encodings).

required

Returns:

Type Description

A tuple (pos_head, neg_head) of arrays, each row encoding one constraint's head (see :meth:Constraint.head_encoded).

Source code in pishield/propositional_requirements/constraints_group.py
def head_encoded(self, num_classes):
    """Encode all constraint heads into stacked positive/negative matrices.

    Args:
        num_classes: The number of variables (row width of the encodings).

    Returns:
        A tuple ``(pos_head, neg_head)`` of arrays, each row encoding one
        constraint's head (see :meth:`Constraint.head_encoded`).
    """
    pos_head = []
    neg_head = []

    for constraint in self.constraints:
        pos, neg = constraint.head_encoded(num_classes)
        pos_head.append(pos)
        neg_head.append(neg)

    return np.array(pos_head), np.array(neg_head)

body_encoded

body_encoded(num_classes)

Encode all constraint bodies into stacked positive/negative matrices.

Parameters:

Name Type Description Default
num_classes

The number of variables (row width of the encodings).

required

Returns:

Type Description

A tuple (pos_body, neg_body) of arrays, each row encoding one constraint's body (see :meth:Constraint.body_encoded).

Source code in pishield/propositional_requirements/constraints_group.py
def body_encoded(self, num_classes):
    """Encode all constraint bodies into stacked positive/negative matrices.

    Args:
        num_classes: The number of variables (row width of the encodings).

    Returns:
        A tuple ``(pos_body, neg_body)`` of arrays, each row encoding one
        constraint's body (see :meth:`Constraint.body_encoded`).
    """
    pos_body = []
    neg_body = []

    for constraint in self.constraints:
        pos, neg = constraint.body_encoded(num_classes)
        pos_body.append(pos)
        neg_body.append(neg)

    return np.array(pos_body), np.array(neg_body)

encoded

encoded(num_classes)

Encode both heads and bodies of all constraints.

Parameters:

Name Type Description Default
num_classes

The number of variables (encoding width).

required

Returns:

Type Description

A tuple (head, body) where each element is the (pos, neg) pair returned by :meth:head_encoded and :meth:body_encoded.

Source code in pishield/propositional_requirements/constraints_group.py
def encoded(self, num_classes):
    """Encode both heads and bodies of all constraints.

    Args:
        num_classes: The number of variables (encoding width).

    Returns:
        A tuple ``(head, body)`` where each element is the ``(pos, neg)`` pair
        returned by :meth:`head_encoded` and :meth:`body_encoded`.
    """
    head = self.head_encoded(num_classes)
    body = self.body_encoded(num_classes)
    return head, body

coherent_with

coherent_with(preds)

Check which constraints each prediction satisfies.

Parameters:

Name Type Description Default
preds

A 2D array of predicted probabilities, shape (batch, num_classes).

required

Returns:

Type Description

A boolean array of shape (batch, num_constraints), True where the corresponding constraint is satisfied (columns in insertion order).

Source code in pishield/propositional_requirements/constraints_group.py
def coherent_with(self, preds):
    """Check which constraints each prediction satisfies.

    Args:
        preds: A 2D array of predicted probabilities, shape (batch, num_classes).

    Returns:
        A boolean array of shape (batch, num_constraints), True where the
        corresponding constraint is satisfied (columns in insertion order).
    """
    coherent = [constraint.coherent_with(preds) for constraint in self.constraints_list]
    return np.array(coherent).transpose()

atoms

atoms()

Return the set of all variable indices used across the constraints.

Source code in pishield/propositional_requirements/constraints_group.py
def atoms(self):
    """Return the set of all variable indices used across the constraints."""
    atoms = set()
    for constraint in self.constraints:
        atoms = atoms.union(constraint.atoms())
    return atoms

heads

heads()

Return the set of variable indices that appear as a constraint head.

Source code in pishield/propositional_requirements/constraints_group.py
def heads(self):
    """Return the set of variable indices that appear as a constraint head."""
    heads = set()
    for constraint in self.constraints:
        heads.add(constraint.head.atom)
    return heads

graph

graph()

Build a directed dependency graph over the variables.

Adds an edge from each body atom to the head atom of every constraint, annotating it with the body and head polarities.

Returns:

Type Description

A networkx.DiGraph whose nodes are atom indices and whose edges carry 'body' and 'head' polarity attributes.

Source code in pishield/propositional_requirements/constraints_group.py
def graph(self):
    """Build a directed dependency graph over the variables.

    Adds an edge from each body atom to the head atom of every constraint,
    annotating it with the body and head polarities.

    Returns:
        A ``networkx.DiGraph`` whose nodes are atom indices and whose edges carry
        ``'body'`` and ``'head'`` polarity attributes.
    """
    G = nx.DiGraph()
    G.add_nodes_from(self.atoms())

    for constraint in self.constraints:
        for lit in constraint.body:
            x = lit.atom
            y = constraint.head.atom
            G.add_edge(x, y)
            G[x][y]['body'] = lit.positive
            G[x][y]['head'] = constraint.head.positive

    return G
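
To show what the resulting edges look like without depending on networkx, the following sketch builds the same edge set in a plain dictionary. The (atom, positive) literal pairs and the dependency_edges helper are assumptions made for this illustration.

```python
def dependency_edges(constraints):
    """constraints: list of (head, body), with literals as (atom, positive) pairs."""
    edges = {}
    for (head_atom, head_pos), body in constraints:
        for body_atom, body_pos in body:
            # One entry per (body atom, head atom) pair; a later constraint with
            # the same pair overwrites the stored polarities.
            edges[(body_atom, head_atom)] = {"body": body_pos, "head": head_pos}
    return edges


constraints = [
    ((0, True), [(1, True), (2, False)]),   # 0 :- 1, not 2
    ((3, False), [(0, True)]),              # not 3 :- 0
]
edges = dependency_edges(constraints)
```

Here the edge from 2 to 0 records a negative body literal and a positive head, while the edge from 0 to 3 records a positive body literal and a negative head.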

duograph

duograph()

Build a directed graph over signed literals (atom and its negation).

Unlike :meth:graph, each atom appears as two nodes (positive and negative literal), and an edge is added from each body literal to the head literal.

Returns:

Type Description

A networkx.DiGraph whose nodes are literal strings.

Source code in pishield/propositional_requirements/constraints_group.py
def duograph(self):
    """Build a directed graph over signed literals (atom and its negation).

    Unlike :meth:`graph`, each atom appears as two nodes (positive and negative
    literal), and an edge is added from each body literal to the head literal.

    Returns:
        A ``networkx.DiGraph`` whose nodes are literal strings.
    """
    atoms = self.atoms()
    pos_atoms = [str(Literal(atom, True)) for atom in atoms]
    neg_atoms = [str(Literal(atom, False)) for atom in atoms]

    G = nx.DiGraph()
    G.add_nodes_from(pos_atoms + neg_atoms)

    for constraint in self.constraints:
        for lit in constraint.body:
            G.add_edge(str(lit), str(constraint.head))

    return G

stratify

stratify()

Partition the constraints into ordered dependency layers (strata).

Performs a topological-style sweep of the dependency graph: variables with no unresolved dependencies are processed first, and the constraints headed by those variables form a stratum. The sweep repeats until no variable is ready, so each stratum can be corrected without violating an earlier one. Variables on or downstream of a dependency cycle never become ready, and the constraints they head are left out of the result.

Returns:

Type Description

A list of :class:ConstraintsGroup objects, ordered from the constraints that should be applied first to those applied last.

Source code in pishield/propositional_requirements/constraints_group.py
def stratify(self):
    """Partition the constraints into ordered dependency layers (strata).

    Performs a topological-style sweep of the dependency graph: variables with no
    unresolved dependencies are processed first, and the constraints headed by
    those variables form a stratum. The sweep repeats until no variable is ready,
    so each stratum can be corrected without violating an earlier one. Variables
    on or downstream of a dependency cycle never become ready, and the constraints
    they head are left out of the result.

    Returns:
        A list of :class:`ConstraintsGroup` objects, ordered from the constraints
        that should be applied first to those applied last.
    """
    G = self.graph()

    for node in G.nodes():
        G.nodes[node]['deps'] = 0
        G.nodes[node]['constraints'] = []

    for x, y in G.edges():
        G.nodes[y]['deps'] += 1

    for constraint in self.constraints:
        G.nodes[constraint.head.atom]['constraints'].append(constraint)

    result = []
    ready = [node for node in G.nodes() if G.nodes[node]['deps'] == 0]
    while len(ready) > 0:
        resolved = [cons for node in ready for cons in G.nodes[node]['constraints']]
        if len(resolved) > 0:
            result.append(ConstraintsGroup(resolved))

        # Release the variables whose last unresolved dependency was just processed
        next_ready = []
        for node in ready:
            for other in G[node]:
                G.nodes[other]['deps'] -= 1
                if G.nodes[other]['deps'] == 0:
                    next_ready.append(other)

        ready = next_ready

    return result
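
The sweep can be reproduced in a few lines of pure Python. The sketch below is a simplified stand-in, not the library method: it ignores literal polarity, represents each constraint as a hypothetical (head_atom, body_atoms) tuple, and sorts the ready variables only to make the output deterministic.

```python
def stratify(constraints):
    """constraints: list of (head_atom, body_atoms); returns a list of strata."""
    atoms = {h for h, _ in constraints} | {b for _, body in constraints for b in body}
    edges = {(b, h) for h, body in constraints for b in body}

    deps = {a: 0 for a in atoms}        # number of unresolved incoming edges
    for _, h in edges:
        deps[h] += 1
    by_head = {a: [] for a in atoms}    # constraints grouped by head variable
    for c in constraints:
        by_head[c[0]].append(c)

    strata = []
    ready = sorted(a for a in atoms if deps[a] == 0)
    while ready:
        resolved = [c for a in ready for c in by_head[a]]
        if resolved:
            strata.append(resolved)
        upcoming = []
        for a in ready:
            for b, h in sorted(edges):
                if b == a:
                    deps[h] -= 1
                    if deps[h] == 0:
                        upcoming.append(h)
        ready = upcoming
    return strata


rules = [(1, (0,)), (2, (1,)), (3, (1, 2)), (4, (0,))]
strata = stratify(rules)
```

Variable 0 heads no rule, so the first sweep produces no stratum. Variables 1 and 4 depend only on 0 and share the first stratum, 2 follows, and 3 comes last because it waits for both 1 and 2.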

Constraints module

constraints_module

Differentiable correction module for one stratum of constraints.

A :class:ConstraintsModule holds the (precomputed, non-trainable) tensor encoding of a single :class:ConstraintsGroup and corrects predictions so they satisfy that group. Each constraint contributes a lower bound (for positive heads) or an upper bound (for negative heads) on its head variable, derived from the truth value of its body; the prediction is then clamped into those bounds. Two equivalent implementations are provided: a vectorised 3D-tensor version (:meth:apply_tensor) and an iterative 2D version (:meth:apply_iter). The Shield Layer stacks one such module per stratum.

ConstraintsModule

ConstraintsModule(constraints_group, num_classes)

Bases: Module

Corrects predictions to satisfy a single stratum of constraints.

The module restricts attention to the atoms occurring in its constraints (re-indexed to a minimal range) and stores the encoded heads/bodies as non-trainable parameters. Bodies are evaluated with Goedel (min) semantics and used to bound the head predictions.

Attributes:

Name Type Description
atoms

The original variable indices covered by this module's constraints.

heads

The re-indexed head :class:Literal of each constraint.

pos_head, neg_head

Encoded positive/negative head indicators per constraint.

pos_body, neg_body

Encoded positive/negative body indicators per constraint.

symm_body, symm_head

Precomputed symmetric (-1/+1) body/head encodings.

literals_count

The number of literals in each constraint's body.

Encode a constraints group into reusable correction tensors.

Parameters:

Name Type Description Default
constraints_group

The :class:ConstraintsGroup (one stratum) to encode.

required
num_classes

The total number of variables in the full prediction space.

required
Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def __init__(self, constraints_group, num_classes):
    """Encode a constraints group into reusable correction tensors.

    Args:
        constraints_group: The :class:`ConstraintsGroup` (one stratum) to encode.
        num_classes: The total number of variables in the full prediction space.
    """
    super(ConstraintsModule, self).__init__()
    head, body = constraints_group.encoded(num_classes)
    pos_head, neg_head = head
    pos_body, neg_body = body

    # Compute necessary atoms
    self.atoms = nn.Parameter(torch.tensor(list(constraints_group.atoms())), requires_grad=False)
    reindexed = {float(atom): i for i, atom in enumerate(self.atoms)}
    if len(self.atoms) == 0: return

    # Reduce tensors to minimal size and reindex heads
    pos_head, neg_head = self.to_minimal(pos_head), self.to_minimal(neg_head)
    pos_body, neg_body = self.to_minimal(pos_body), self.to_minimal(neg_body)

    heads = [constraint.head for constraint in constraints_group]
    self.heads = [Literal(reindexed[head.atom], head.positive) for head in heads]

    # Module parameters
    self.pos_head = nn.Parameter(torch.from_numpy(pos_head).float(), requires_grad=False)
    self.neg_head = nn.Parameter(torch.from_numpy(neg_head).float(), requires_grad=False)
    self.pos_body = nn.Parameter(torch.from_numpy(pos_body).float(), requires_grad=False)
    self.neg_body = nn.Parameter(torch.from_numpy(neg_body).float(), requires_grad=False)

    # Precomputed parameters
    self.symm_body = nn.Parameter((self.pos_body - self.neg_body).t(), requires_grad=False)
    self.symm_head = nn.Parameter((self.pos_head - self.neg_head).t(), requires_grad=False)
    self.literals_count = nn.Parameter(self.pos_body.sum(dim=1) + self.neg_body.sum(dim=1), requires_grad=False)

dimensions

dimensions(pred)

Return the (batch, num_atoms, num_constraints) dimensions for a tensor.

Parameters:

Name Type Description Default
pred

A prediction tensor of shape (batch, num_atoms).

required

Returns:

Type Description

A tuple (batch, num, cons).

Source code in pishield/propositional_requirements/constraints_module.py
def dimensions(self, pred):
    """Return the (batch, num_atoms, num_constraints) dimensions for a tensor.

    Args:
        pred: A prediction tensor of shape (batch, num_atoms).

    Returns:
        A tuple ``(batch, num, cons)``.
    """
    batch, num = pred.shape[0], pred.shape[1]
    cons = self.pos_head.shape[0]
    return batch, num, cons

from_symmetric staticmethod

from_symmetric(preds)

Map symmetric values in [-1, 1] back to probabilities in [0, 1].

Parameters:

Name Type Description Default
preds

A tensor of symmetric values.

required

Returns:

Type Description

The corresponding probabilities.

Source code in pishield/propositional_requirements/constraints_module.py
@staticmethod
@profiler.wrap
def from_symmetric(preds):
    """Map symmetric values in [-1, 1] back to probabilities in [0, 1].

    Args:
        preds: A tensor of symmetric values.

    Returns:
        The corresponding probabilities.
    """
    return (preds + 1) / 2

to_symmetric staticmethod

to_symmetric(preds)

Map probabilities in [0, 1] to symmetric values in [-1, 1].

Parameters:

Name Type Description Default
preds

A tensor of probabilities.

required

Returns:

Type Description

The corresponding symmetric values.

Source code in pishield/propositional_requirements/constraints_module.py
@staticmethod
@profiler.wrap
def to_symmetric(preds):
    """Map probabilities in [0, 1] to symmetric values in [-1, 1].

    Args:
        preds: A tensor of probabilities.

    Returns:
        The corresponding symmetric values.
    """
    return 2 * preds - 1
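
The two mappings are inverses of each other, which a quick round trip confirms. The sketch uses plain floats instead of tensors.

```python
def to_symmetric(p):
    return 2 * p - 1        # [0, 1] -> [-1, 1]


def from_symmetric(s):
    return (s + 1) / 2      # [-1, 1] -> [0, 1]


probs = [0.0, 0.25, 1.0]
symm = [to_symmetric(p) for p in probs]      # [-1.0, -0.5, 1.0]
back = [from_symmetric(s) for s in symm]     # [0.0, 0.25, 1.0]
```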

to_minimal

to_minimal(tensor, atoms=None)

Restrict a full tensor to the module's atom columns.

Parameters:

Name Type Description Default
tensor

A 2D tensor or array with one column per variable, e.g. of shape (batch, num_classes).

required
atoms

The atom indices to keep; defaults to self.atoms.

None

Returns:

Type Description

The tensor restricted to the selected atom columns.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def to_minimal(self, tensor, atoms=None):
    """Restrict a full tensor to the module's atom columns.

    Args:
        tensor: A 2D tensor or array with one column per variable, e.g. of shape (batch, num_classes).
        atoms: The atom indices to keep; defaults to ``self.atoms``.

    Returns:
        The tensor restricted to the selected atom columns.
    """
    if atoms is None: atoms = self.atoms
    return tensor[:, atoms].reshape(tensor.shape[0], len(atoms))

from_minimal

from_minimal(tensor, init, atoms=None)

Scatter a minimal-atom tensor back into a full tensor.

Parameters:

Name Type Description Default
tensor

The minimal tensor over the module's atoms.

required
init

The full tensor to copy the values into.

required
atoms

The atom indices the values belong to; defaults to self.atoms.

None

Returns:

Type Description

init with the atom columns overwritten by tensor.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def from_minimal(self, tensor, init, atoms=None):
    """Scatter a minimal-atom tensor back into a full tensor.

    Args:
        tensor: The minimal tensor over the module's atoms.
        init: The full tensor to copy the values into.
        atoms: The atom indices the values belong to; defaults to ``self.atoms``.

    Returns:
        ``init`` with the atom columns overwritten by ``tensor``.
    """
    if atoms is None: atoms = self.atoms
    return init.index_copy(1, atoms, tensor)
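
The gather and scatter pair can be pictured with nested lists. This is a sketch under the assumption that rows are plain Python lists; the real methods operate on tensors, and the function names are reused here only for readability.

```python
def to_minimal(rows, atoms):
    """Keep only the columns listed in atoms, in that order."""
    return [[row[a] for a in atoms] for row in rows]


def from_minimal(minimal, init, atoms):
    """Write the minimal columns back into a copy of init."""
    out = [list(row) for row in init]   # copy, so init itself is left untouched
    for r, values in enumerate(minimal):
        for a, v in zip(atoms, values):
            out[r][a] = v
    return out


full = [[0.1, 0.2, 0.3, 0.4]]
atoms = [1, 3]
small = to_minimal(full, atoms)                       # [[0.2, 0.4]]
restored = from_minimal([[0.9, 0.8]], full, atoms)    # [[0.1, 0.9, 0.3, 0.8]]
```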

active_constraints

active_constraints(goal)

Identify which constraints are activated by a goal assignment.

Parameters:

Name Type Description Default
goal

A ground-truth assignment over the module's atoms.

required

Returns:

Type Description

A tuple (full_body, unsat_head) of boolean masks marking the constraints whose body is fully satisfied by goal and those whose head is unsatisfied by goal.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def active_constraints(self, goal):
    """Identify which constraints are activated by a goal assignment.

    Args:
        goal: A ground-truth assignment over the module's atoms.

    Returns:
        A tuple ``(full_body, unsat_head)`` of boolean masks marking the
        constraints whose body is fully satisfied by ``goal`` and those whose head
        is unsatisfied by ``goal``.
    """
    symm_goal = ConstraintsModule.to_symmetric(goal)
    full_body = torch.matmul(symm_goal, self.symm_body) == self.literals_count
    unsat_head = torch.matmul(symm_goal, self.symm_head) == -1
    return full_body, unsat_head
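
The dot-product test relies on the symmetric encoding: with a 0/1 goal mapped to -1/+1, each satisfied body literal contributes +1 and each violated one contributes -1, so the sum equals the literal count only when the whole body holds. The sketch below checks this for a single constraint, 0 :- 1, not 2, using plain lists; it assumes a binary goal and is not the library method.

```python
def active_constraints(goal, symm_body, symm_head, literals_count):
    symm_goal = [2 * g - 1 for g in goal]               # 0/1 -> -1/+1
    body_score = sum(g * b for g, b in zip(symm_goal, symm_body))
    head_score = sum(g * h for g, h in zip(symm_goal, symm_head))
    return body_score == literals_count, head_score == -1


symm_body = [0, 1, -1]   # +1 for the positive literal 1, -1 for the negated literal 2
symm_head = [1, 0, 0]    # positive head on variable 0
count = 2                # two body literals

fired = active_constraints([0, 1, 0], symm_body, symm_head, count)   # (True, True)
idle = active_constraints([1, 1, 1], symm_body, symm_head, count)    # (False, False)
```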

apply_tensor

apply_tensor(preds, active_constraints=None, body_mask=None)

Correct predictions with a vectorised 3D-tensor computation.

Computes, for every constraint, the Goedel (min) truth value of its body and uses it as a lower bound on positive heads and an upper bound on negative heads, then clamps each prediction into the resulting bounds.

Parameters:

Name Type Description Default
preds

The predictions over the module's atoms, shape (batch, num_atoms).

required
active_constraints

Optional per-constraint mask zeroing inactive constraints' contributions.

None
body_mask

Optional per-atom mask used to drop body literals that the goal already determines.

None

Returns:

Type Description

The corrected predictions, same shape as preds.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def apply_tensor(self, preds, active_constraints=None, body_mask=None):
    """Correct predictions with a vectorised 3D-tensor computation.

    Computes, for every constraint, the Goedel (min) truth value of its body and
    uses it as a lower bound on positive heads and an upper bound on negative heads,
    then clamps each prediction into the resulting bounds.

    Args:
        preds: The predictions over the module's atoms, shape (batch, num_atoms).
        active_constraints: Optional per-constraint mask zeroing inactive
            constraints' contributions.
        body_mask: Optional per-atom mask used to drop body literals that the goal
            already determines.

    Returns:
        The corrected predictions, same shape as ``preds``.
    """
    batch, num, cons = self.dimensions(preds)

    # batch x cons x num: prepare (preds x body)
    exp_preds = preds.unsqueeze(1).expand(batch, cons, num)
    pos_body = self.pos_body.unsqueeze(0).expand(batch, cons, num)
    neg_body = self.neg_body.unsqueeze(0).expand(batch, cons, num)

    # batch x cons x num: ignore literals from constraints
    if body_mask is not None:
        body_mask = body_mask.unsqueeze(1).expand(batch, cons, num)
        pos_body = pos_body * (1 - body_mask)
        neg_body = neg_body * body_mask

    # batch x cons: compute body minima
    body_rev = pos_body + exp_preds * (neg_body - pos_body)
    body_min = 1. - torch.max(body_rev, dim=2).values

    # batch x cons: ignore constraints
    if active_constraints is not None:
        body_min = body_min * active_constraints.float()

    # batch x cons x num: prepare (body_min x head)
    body_min = body_min.unsqueeze(2).expand(batch, cons, num)
    pos_head = self.pos_head.unsqueeze(0).expand(batch, cons, num)
    neg_head = self.neg_head.unsqueeze(0).expand(batch, cons, num)

    # batch x num: compute head lower and upper bounds
    lb = torch.max(body_min * pos_head, dim=1).values
    ub = 1 - torch.max(body_min * neg_head, dim=1).values
    lb, ub = torch.minimum(lb, ub), torch.maximum(lb, ub)

    preds = torch.maximum(lb, torch.minimum(ub, preds))
    return preds
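
Ignoring the optional masks, the computation in the listing can be summarised by the formulas below. The notation is introduced here for illustration only: p_j is the prediction for variable y_j, v_ℓ(p) is the truth value of literal ℓ, and b_c is the body value of constraint c.

```latex
\begin{aligned}
v_\ell(p) &= \begin{cases} p_i & \text{if } \ell = y_i \\ 1 - p_i & \text{if } \ell = \neg y_i \end{cases} \\
b_c &= 1 - \max\Bigl(0,\ \max_{\ell \in \mathrm{body}(c)} \bigl(1 - v_\ell(p)\bigr)\Bigr)
     = \min\Bigl(1,\ \min_{\ell \in \mathrm{body}(c)} v_\ell(p)\Bigr) \\
\mathrm{lb}_j &= \max\Bigl(0,\ \max_{c \,:\, \mathrm{head}(c) = y_j} b_c\Bigr), \qquad
\mathrm{ub}_j = 1 - \max\Bigl(0,\ \max_{c \,:\, \mathrm{head}(c) = \neg y_j} b_c\Bigr) \\
p'_j &= \max\Bigl(\min(\mathrm{lb}_j, \mathrm{ub}_j),\ \min\bigl(\max(\mathrm{lb}_j, \mathrm{ub}_j),\ p_j\bigr)\Bigr)
\end{aligned}
```

The second line is why the code takes a maximum over reversed literals: one minus the largest reversed value equals the smallest literal value, which is the Goedel conjunction of the body.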

apply_iter

apply_iter(preds, active_constraints=None, body_mask=None, in_bounds=None, out_bounds=False)

Correct predictions by iterating constraint by constraint with 2D matrices.

Equivalent to :meth:apply_tensor but loops over constraints, accumulating per-atom lower/upper bounds; this avoids materialising the large 3D tensors. It can also accept incoming bounds and/or return the raw bounds instead of corrected predictions, which the goal-conditioned path uses to chain two passes.

Parameters:

Name Type Description Default
preds

The predictions over the module's atoms, shape (batch, num_atoms).

required
active_constraints

Optional per-constraint mask zeroing inactive constraints' contributions.

None
body_mask

Optional per-atom mask used to drop body literals.

None
in_bounds

Optional (lb, ub) lists of incoming per-atom bounds to start from; defaults to [0, 1] per atom.

None
out_bounds

If True, return the raw (lb, ub) bounds rather than the clamped predictions.

False

Returns:

Type Description

Either the corrected predictions, or the (lb, ub) bounds when out_bounds is True.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def apply_iter(self, preds, active_constraints=None, body_mask=None, in_bounds=None, out_bounds=False):
    """Correct predictions by iterating constraint by constraint with 2D matrices.

    Equivalent to :meth:`apply_tensor` but loops over constraints, accumulating
    per-atom lower/upper bounds; this avoids materialising the large 3D tensors. Can
    accept incoming bounds and/or return the raw bounds instead of corrected
    predictions, which the goal-conditioned path uses to chain two passes.

    Args:
        preds: The predictions over the module's atoms, shape (batch, num_atoms).
        active_constraints: Optional per-constraint mask zeroing inactive
            constraints' contributions.
        body_mask: Optional per-atom mask used to drop body literals.
        in_bounds: Optional ``(lb, ub)`` lists of incoming per-atom bounds to start
            from; defaults to ``[0, 1]`` per atom.
        out_bounds: If True, return the raw ``(lb, ub)`` bounds rather than the
            clamped predictions.

    Returns:
        Either the corrected predictions, or the ``(lb, ub)`` bounds when
        ``out_bounds`` is True.
    """
    batch, num, cons = self.dimensions(preds)
    device = preds.device

    if active_constraints is not None:
        active_constraints = active_constraints.float()
    zeros = torch.zeros(batch, 1, device=device)

    profiler = ConstraintsModule.profiler.branch('iter')

    with profiler.watch('init'):
        if in_bounds is None:
            lb = [torch.zeros(preds.shape[0], device=device) for i in range(preds.shape[1])]
            ub = [torch.ones(preds.shape[0], device=device) for i in range(preds.shape[1])]
        else:
            lb, ub = in_bounds

    with profiler.watch('precompute'):
        bool_pos_body = self.pos_body.bool()
        bool_neg_body = self.neg_body.bool()

        full_pos_body = 1 - preds
        full_neg_body = preds

        if body_mask is not None:
            full_pos_body = (1 - preds) * (1 - body_mask)
            full_neg_body = preds * body_mask

    for c, lit in enumerate(self.heads):
        # slice positive and negative body preds
        with profiler.watch('where'):
            pos_where = bool_pos_body[c]
            neg_where = bool_neg_body[c]

        # body predictions (possibly masked) 
        with profiler.watch('body'):
            pos_body = full_pos_body[:, pos_where]
            neg_body = full_neg_body[:, neg_where]

        # compute maximal inverted values
        with profiler.watch('candidate'):
            candidate = torch.cat((zeros, pos_body, neg_body), dim=1)
            candidate = 1 - candidate.max(dim=1).values

        # clear inactive constraints
        with profiler.watch('active_cons'):
            if active_constraints is not None:
                candidate = candidate * active_constraints[:, c]

        # update preds
        with profiler.watch('min_max'):
            if lit.positive:
                lb[lit.atom] = torch.maximum(lb[lit.atom], candidate)
            else:
                ub[lit.atom] = torch.minimum(ub[lit.atom], 1 - candidate)

    with profiler.watch('lb_ub'):
        if out_bounds:
            return lb, ub

        lb, ub = torch.stack(lb, dim=1), torch.stack(ub, dim=1)
        lb, ub = torch.minimum(lb, ub), torch.maximum(lb, ub)
        updated = torch.maximum(lb, torch.minimum(ub, preds))

        return updated
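
The bound update in the loop above can be illustrated on a single example and a single rule. The following is a minimal pure-Python sketch, assuming the min/max semantics visible in the listing; the helper names `body_value` and `correct_head` and the `(atom, positive)` tuple encoding are illustrative assumptions, not part of the library.

```python
def body_value(preds, body):
    """Degree to which a rule body holds (min over its literals).

    `body` is a list of (atom, positive) pairs. This mirrors the listing:
    1 - max(0, 1 - p for positive literals, p for negative literals).
    """
    inverted = [0.0]
    for atom, positive in body:
        inverted.append(1 - preds[atom] if positive else preds[atom])
    return 1 - max(inverted)


def correct_head(preds, head, body):
    """Clamp the head atom so that the rule `head :- body` holds."""
    head_atom, head_positive = head
    lb, ub = 0.0, 1.0
    candidate = body_value(preds, body)
    if head_positive:
        lb = max(lb, candidate)       # head must be at least the body value
    else:
        ub = min(ub, 1 - candidate)   # head must be at most 1 - body value
    lb, ub = min(lb, ub), max(lb, ub)
    updated = list(preds)
    updated[head_atom] = max(lb, min(ub, preds[head_atom]))
    return updated


# Rule: y0 :- y1, not y2
preds = [0.2, 0.9, 0.3]
corrected = correct_head(preds, head=(0, True), body=[(1, True), (2, False)])

# Rule: not y0 :- y1, not y2
corrected_neg = correct_head([0.8, 0.9, 0.3], head=(0, False),
                             body=[(1, True), (2, False)])
```

In the first case the body holds to degree min(0.9, 1 - 0.3), so the head is raised to that value; in the second case the head is lowered to one minus that value. Only the head atom changes.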

apply

apply(preds, iterative)

Correct predictions using the chosen implementation.

Parameters:

Name Type Description Default
preds

The predictions over the module's atoms.

required
iterative

If True use :meth:apply_iter, else :meth:apply_tensor.

required

Returns:

Type Description

The corrected predictions.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def apply(self, preds, iterative):
    """Correct predictions using the chosen implementation.

    Args:
        preds: The predictions over the module's atoms.
        iterative: If True use :meth:`apply_iter`, else :meth:`apply_tensor`.

    Returns:
        The corrected predictions.
    """
    if iterative:
        return self.apply_iter(preds)
    else:
        return self.apply_tensor(preds)

apply_goal

apply_goal(preds, goal, iterative)

Correct predictions consistently with a known goal assignment.

Applies the constraints in two passes. The first pass uses only the constraints whose body is fully satisfied by the goal, which propagates firm consequences. The second pass uses the constraints whose head is not satisfied by the goal, with the goal-determined body literals masked out.

Parameters:

Name Type Description Default
preds

The predictions over the module's atoms.

required
goal

The goal (ground-truth) assignment over the module's atoms.

required
iterative

If True use the iterative implementation, else the tensor one.

required

Returns:

Type Description

The corrected predictions.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def apply_goal(self, preds, goal, iterative):
    """Correct predictions consistently with a known goal assignment.

    Applies the constraints in two passes. The first pass uses only the
    constraints whose body is fully satisfied by the goal, which propagates firm
    consequences. The second pass uses the constraints whose head is not satisfied
    by the goal, with the goal-determined body literals masked out.

    Args:
        preds: The predictions over the module's atoms.
        goal: The goal (ground-truth) assignment over the module's atoms.
        iterative: If True use the iterative implementation, else the tensor one.

    Returns:
        The corrected predictions.
    """
    full_body, unsat_head = self.active_constraints(goal)
    body_mask = goal

    if iterative:
        bounds = self.apply_iter(preds, active_constraints=full_body, out_bounds=True)
        updated = self.apply_iter(preds, active_constraints=unsat_head, body_mask=body_mask, in_bounds=bounds)
    else:
        updated = self.apply_tensor(preds, active_constraints=full_body)
        updated = self.apply_tensor(updated, active_constraints=unsat_head, body_mask=body_mask)

    return updated

forward

forward(preds, goal=None, iterative=True)

Correct a full prediction tensor against this stratum's constraints.

Restricts the predictions to the module's atoms, applies the correction (optionally goal-conditioned), and scatters the result back into the full tensor. Returns the input unchanged when there are no predictions or no atoms.

Parameters:

Name Type Description Default
preds

The full prediction tensor, shape (batch, num_classes).

required
goal

Optional goal assignment over the full variable space; when given, corrections are made consistent with it.

None
iterative

If True use the iterative implementation, else the tensor one.

True

Returns:

Type Description

The full prediction tensor with this stratum's atoms corrected.

Source code in pishield/propositional_requirements/constraints_module.py
@profiler.wrap
def forward(self, preds, goal=None, iterative=True):
    """Correct a full prediction tensor against this stratum's constraints.

    Restricts the predictions to the module's atoms, applies the correction
    (optionally goal-conditioned), and scatters the result back into the full
    tensor. Returns the input unchanged when there are no predictions or no atoms.

    Args:
        preds: The full prediction tensor, shape (batch, num_classes).
        goal: Optional goal assignment over the full variable space; when given,
            corrections are made consistent with it.
        iterative: If True use the iterative implementation, else the tensor one.

    Returns:
        The full prediction tensor with this stratum's atoms corrected.
    """
    if len(preds) == 0 or len(self.atoms) == 0:
        return preds

    updated = self.to_minimal(preds)

    if goal is None:
        updated = self.apply(updated, iterative=iterative)
        return self.from_minimal(updated, preds)
    else:
        goal = self.to_minimal(goal)
        updated = self.apply_goal(updated, goal=goal, iterative=iterative)
        return self.from_minimal(updated, preds)
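
The restrict, correct, scatter pattern described above can be sketched on a single row. The implementations of `to_minimal` and `from_minimal` are not shown on this page, so the two functions below are simplified stand-ins written under that assumption, and the "correction" is a placeholder.

```python
def to_minimal(row, atoms):
    """Keep only the columns listed in `atoms` (hypothetical stand-in)."""
    return [row[a] for a in atoms]


def from_minimal(updated, row, atoms):
    """Write the corrected values back into a copy of the full row."""
    full = list(row)
    for value, a in zip(updated, atoms):
        full[a] = value
    return full


atoms = [1, 3]                          # atoms this stratum touches
row = [0.1, 0.2, 0.3, 0.4]              # one example over four variables
small = to_minimal(row, atoms)          # values of atoms 1 and 3 only
small = [max(v, 0.5) for v in small]    # placeholder for the real correction
result = from_minimal(small, row, atoms)
```

Columns outside `atoms` pass through unchanged, which is what lets successive strata operate on the same full tensor.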

run_cm

run_cm(cm, preds, goal=None, device='cpu')

Run a constraints module both ways and assert they agree.

Runs the module with the iterative and tensor implementations and checks the results are numerically close; mainly a testing/debugging helper.

Parameters:

Name Type Description Default
cm

The :class:ConstraintsModule to run.

required
preds

The prediction tensor.

required
goal

Optional goal assignment.

None
device

The torch device to run on.

'cpu'

Returns:

Type Description

The corrected predictions (from the iterative implementation), on CPU.

Raises:

Type Description
AssertionError

If the two implementations disagree.

Source code in pishield/propositional_requirements/constraints_module.py
def run_cm(cm, preds, goal=None, device='cpu'):
    """Run a constraints module both ways and assert they agree.

    Runs the module with the iterative and tensor implementations and checks the
    results are numerically close; mainly a testing/debugging helper.

    Args:
        cm: The :class:`ConstraintsModule` to run.
        preds: The prediction tensor.
        goal: Optional goal assignment.
        device: The torch device to run on.

    Returns:
        The corrected predictions (from the iterative implementation), on CPU.

    Raises:
        AssertionError: If the two implementations disagree.
    """
    cm, preds = cm.to(device), preds.to(device)
    if goal is not None:
        goal = goal.to(device)

    # Named to avoid shadowing the built-in ``iter``.
    iter_out = cm(preds, goal=goal, iterative=True)
    tens_out = cm(preds, goal=goal, iterative=False)
    assert torch.isclose(iter_out, tens_out).all()
    return iter_out.cpu()

Strong coherency

strong_coherency

Strong-coherency preprocessing of constraints.

When several constraints share the same head (some asserting it positively, some negatively), the Shield Layer must correct that head consistently. This module rewrites such constraint sets so that the positive and negative rules branch on a common body literal, which guarantees a strongly coherent correction. The shared literal is chosen as the highest-ranking atom (per the variable ordering) appearing in the rule bodies.

get_max_ranking_atom

get_max_ranking_atom(atoms_list, ranking)

Find the atom with the smallest position in the ranking.

Parameters:

Name Type Description Default
atoms_list

The candidate atom indices.

required
ranking

An ordered sequence of atoms; earlier means higher priority.

required

Returns:

Type Description

A tuple (atom, rank) of the highest-ranked atom and its index in ranking, or (None, None) if atoms_list is empty.

Source code in pishield/propositional_requirements/strong_coherency.py
def get_max_ranking_atom(atoms_list, ranking):
    """Find the atom with the smallest position in the ranking.

    Args:
        atoms_list: The candidate atom indices.
        ranking: An ordered sequence of atoms; earlier means higher priority.

    Returns:
        A tuple ``(atom, rank)`` of the highest-ranked atom and its index in
        ``ranking``, or ``(None, None)`` if ``atoms_list`` is empty.
    """
    ranks = [list(ranking).index(atom) for atom in atoms_list]
    if len(ranks) == 0:
        return None, None
    return atoms_list[np.argmin(ranks)], np.min(ranks)
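
As a quick illustration of the selection rule, here is a standard-library sketch that should behave like the function above on small inputs. The name `max_ranking_atom` is hypothetical and the NumPy calls are replaced by `min` with a key function.

```python
def max_ranking_atom(atoms_list, ranking):
    """Return (atom, rank) for the atom that appears earliest in `ranking`."""
    ranking = list(ranking)
    if not atoms_list:
        return None, None
    best = min(atoms_list, key=ranking.index)
    return best, ranking.index(best)


ranking = [3, 1, 2, 0]   # atom 3 has the highest priority, atom 0 the lowest
best_atom, best_rank = max_ranking_atom([0, 2, 1], ranking)
empty = max_ranking_atom([], ranking)
```

Here atom 1 wins because it sits at position 1 of the ranking, ahead of atoms 2 and 0.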

create_new_rule

create_new_rule(rule, extra_literal, positive)

Return a copy of a rule with one extra literal added to its body.

Parameters:

Name Type Description Default
rule

The :class:Constraint to extend.

required
extra_literal

The atom index of the literal to add.

required
positive

The polarity of the added literal.

required

Returns:

Type Description

A new :class:Constraint with the same head and the extended body.

Source code in pishield/propositional_requirements/strong_coherency.py
def create_new_rule(rule, extra_literal, positive):
    """Return a copy of a rule with one extra literal added to its body.

    Args:
        rule: The :class:`Constraint` to extend.
        extra_literal: The atom index of the literal to add.
        positive: The polarity of the added literal.

    Returns:
        A new :class:`Constraint` with the same head and the extended body.
    """
    new_literal = Literal(extra_literal, positive)
    new_rule = Constraint(rule.head, rule.body.union([new_literal]))
    # rule.body = rule.body.union([new_literal])
    return new_rule

get_max_ranking_eligible_atom_from_sets_of_rules

get_max_ranking_eligible_atom_from_sets_of_rules(R, R_other, literal_ranking)

Pick the highest-ranked body atom across two rule sets.

Parameters:

Name Type Description Default
R

A list of constraints (e.g. the positive-head rules).

required
R_other

Another list of constraints (e.g. the negative-head rules).

required
literal_ranking

The ordering used to rank atoms.

required

Returns:

Type Description

The atom index of the highest-ranked body literal occurring in either set, or None if neither set has any body literals.

Source code in pishield/propositional_requirements/strong_coherency.py
def get_max_ranking_eligible_atom_from_sets_of_rules(R, R_other, literal_ranking):
    """Pick the highest-ranked body atom across two rule sets.

    Args:
        R: A list of constraints (e.g. the positive-head rules).
        R_other: Another list of constraints (e.g. the negative-head rules).
        literal_ranking: The ordering used to rank atoms.

    Returns:
        The atom index of the highest-ranked body literal occurring in either set, or
        None if neither set has any body literals.
    """
    all_literals = set([])

    for r in R:
        all_literals = all_literals.union([lit.atom for lit in r.body])
    for r in R_other:
        all_literals = all_literals.union([lit.atom for lit in r.body])

    max_ranking_atom, max_rank = get_max_ranking_atom(list(all_literals), literal_ranking)
    return max_ranking_atom

extend_rules_set

extend_rules_set(R, max_ranking_body_literal)

Branch every rule on the chosen literal unless it already contains it.

For each rule that does not already mention max_ranking_body_literal, two new rules are produced (with the literal added positively and negatively); rules that already contain the literal (in either polarity) are kept unchanged.

Parameters:

Name Type Description Default
R

The list of constraints (all sharing the same head) to extend.

required
max_ranking_body_literal

The atom index to branch the rules on.

required

Returns:

Type Description

A list of the extended constraints.

Source code in pishield/propositional_requirements/strong_coherency.py
def extend_rules_set(R, max_ranking_body_literal):
    """Branch every rule on the chosen literal unless it already contains it.

    For each rule that does not already mention ``max_ranking_body_literal``, two new
    rules are produced (with the literal added positively and negatively); rules that
    already contain the literal (in either polarity) are kept unchanged.

    Args:
        R: The list of constraints (all sharing the same head) to extend.
        max_ranking_body_literal: The atom index to branch the rules on.

    Returns:
        A list of the extended constraints.
    """
    new_rules_R_hat = set([])
    print("head", R[0].head, "with old rules", len(R))

    for r in R:
        literals_in_r = [lit.atom for lit in r.body]

        # if l is not in body(r), add new rules:
        if max_ranking_body_literal not in literals_in_r:
            new_rules_R_hat.add(create_new_rule(r, max_ranking_body_literal, positive=True))
            new_rules_R_hat.add(create_new_rule(r, max_ranking_body_literal, positive=False))
        else:
            # r already mentions l (positively or negatively), so keep it unchanged;
            # this ensures no rule is dropped when nothing can be added for it.
            new_rules_R_hat.add(r)

    print("head", R[0].head, "with new rules", len(new_rules_R_hat))
    return list(new_rules_R_hat)
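
The branching step can be tried out without the library. The sketch below uses simplified stand-ins for `Literal` and `Constraint` (named tuples, an assumption made only for this example) and a hypothetical helper `branch_on` that follows the same logic as the listing, minus the progress prints.

```python
from typing import FrozenSet, NamedTuple


class Literal(NamedTuple):
    atom: int
    positive: bool


class Constraint(NamedTuple):
    head: Literal
    body: FrozenSet[Literal]


def branch_on(rules, atom):
    """Split every rule that does not mention `atom` into two rules."""
    out = set()
    for rule in rules:
        if atom in {lit.atom for lit in rule.body}:
            out.add(rule)  # already mentions the atom: keep unchanged
        else:
            for polarity in (True, False):
                new_body = rule.body | {Literal(atom, polarity)}
                out.add(Constraint(rule.head, new_body))
    return out


rule_a = Constraint(Literal(0, True), frozenset({Literal(1, True)}))   # y0 :- y1
rule_b = Constraint(Literal(0, True), frozenset({Literal(2, False)}))  # y0 :- not y2
extended = branch_on([rule_a, rule_b], atom=2)
```

Branching does not change the meaning of a rule: `y0 :- y1` is logically equivalent to the pair `y0 :- y1, y2` and `y0 :- y1, not y2`.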

strong_coherency_constraint_preprocessing

strong_coherency_constraint_preprocessing(R_atom, literal_ranking)

Rewrite a head's constraints to guarantee a strongly coherent correction.

Splits the constraints for one head into those asserting it positively and those asserting it negatively, picks the highest-ranked atom occurring in the body of any rule in either group, and branches both groups on that atom so the head can be corrected consistently.

Parameters:

Name Type Description Default
R_atom

The list of constraints sharing the same head atom.

required
literal_ranking

The variable ordering used to choose the branching atom.

required

Returns:

Type Description

A :class:ConstraintsGroup of the rewritten constraints, or None if none of the rules has a body atom to branch on.

Source code in pishield/propositional_requirements/strong_coherency.py
def strong_coherency_constraint_preprocessing(R_atom, literal_ranking):
    """Rewrite a head's constraints to guarantee a strongly coherent correction.

    Splits the constraints for one head into those asserting it positively and those
    asserting it negatively, picks the highest-ranked atom occurring in the body of
    any rule in either group, and branches both groups on that atom so the head can
    be corrected consistently.

    Args:
        R_atom: The list of constraints sharing the same head atom.
        literal_ranking: The variable ordering used to choose the branching atom.

    Returns:
        A :class:`ConstraintsGroup` of the rewritten constraints, or None if none
        of the rules has a body atom to branch on.
    """
    literal_ranking = list(literal_ranking)

    R_atom_plus, R_atom_minus = [], []
    for constr in R_atom:
        if constr.head.positive:
            R_atom_plus.append(constr)
        else:
            R_atom_minus.append(constr)
    print(R_atom_minus, R_atom_plus)

    # l = max_lambda over literals in R+ and R-
    max_ranking_body_literal = get_max_ranking_eligible_atom_from_sets_of_rules(R_atom_plus, R_atom_minus,
                                                                                literal_ranking)
    if max_ranking_body_literal is None:
        return None

    if len(R_atom_plus):
        R_atom_plus = extend_rules_set(R=R_atom_plus, max_ranking_body_literal=max_ranking_body_literal)
    if len(R_atom_minus):
        R_atom_minus = extend_rules_set(R=R_atom_minus, max_ranking_body_literal=max_ranking_body_literal)

    new_R_atom = ConstraintsGroup(R_atom_plus.copy() + R_atom_minus.copy())
    return new_R_atom

Detection threshold

detection_threshold

Detection-threshold based slicing of predictions.

A detection variable (at column 0) gates whether an example is processed: only rows whose detection probability exceeds a threshold are kept (cut) before the requirements are applied, and the corrected rows are then scattered back into the original tensor (uncut).

DetectionThreshold

DetectionThreshold(threshold)

Restrict prediction rows to those passing a detection threshold.

Attributes:

Name Type Description
threshold

The value that the detection variable (column 0) must strictly exceed for a row to be kept.

Store the detection threshold.

Parameters:

Name Type Description Default
threshold

The detection probability that a row must strictly exceed to be kept.

required
Source code in pishield/propositional_requirements/detection_threshold.py
def __init__(self, threshold):
    """Store the detection threshold.

    Args:
        threshold: The detection probability that a row must strictly exceed to be kept.
    """
    self.threshold = threshold

cut

cut(preds, mask)

Drop the detection column and keep only masked rows.

Parameters:

Name Type Description Default
preds

The full prediction tensor, shape (batch, num_classes).

required
mask

A boolean row mask selecting which examples to keep.

required

Returns:

Type Description

A tuple of the sliced predictions (kept rows, columns from index 1 onward) and a callable that scatters updated values back via :meth:uncut.

Source code in pishield/propositional_requirements/detection_threshold.py
def cut(self, preds, mask):
    """Drop the detection column and keep only masked rows.

    Args:
        preds: The full prediction tensor, shape (batch, num_classes).
        mask: A boolean row mask selecting which examples to keep.

    Returns:
        A tuple of the sliced predictions (kept rows, columns from index 1
        onward) and a callable that scatters updated values back via :meth:`uncut`.
    """
    return preds[mask, 1:], lambda updated: self.uncut(preds, mask, updated)

uncut

uncut(init, mask, preds)

Scatter corrected rows back into the original tensor.

Re-attaches the detection column and writes the corrected rows back into the positions selected by mask, leaving non-selected rows untouched.

Parameters:

Name Type Description Default
init

The original full prediction tensor before cutting.

required
mask

The boolean row mask used by :meth:cut.

required
preds

The corrected predictions for the kept rows (without column 0).

required

Returns:

Type Description

A new tensor with the shape of init, with the corrected rows written at the positions selected by mask (init itself is not modified).

Source code in pishield/propositional_requirements/detection_threshold.py
def uncut(self, init, mask, preds):
    """Scatter corrected rows back into the original tensor.

    Re-attaches the detection column and writes the corrected rows back into the
    positions selected by ``mask``, leaving non-selected rows untouched.

    Args:
        init: The original full prediction tensor before cutting.
        mask: The boolean row mask used by :meth:`cut`.
        preds: The corrected predictions for the kept rows (without column 0).

    Returns:
        A new tensor with the shape of ``init``, with the corrected rows written at
        the positions selected by ``mask`` (``init`` itself is not modified).
    """
    preds = torch.cat((init[mask, 0].reshape(-1, 1), preds), dim=1)
    index = torch.arange(init.shape[0], device=mask.device)
    index = index[mask]

    # init = torch.cat((init[:, 0].reshape(-1, 1), torch.zeros_like(init[:, 1:])), dim=1)
    return init.index_copy(0, index, preds)

cutter

cutter(preds)

Build a cut function for predictions using the detection threshold.

Parameters:

Name Type Description Default
preds

A prediction tensor whose column 0 is the detection variable.

required

Returns:

Type Description

A callable mapping a prediction tensor to the result of :meth:cut, using a row mask derived from the detection column of the preds passed to cutter.

Source code in pishield/propositional_requirements/detection_threshold.py
def cutter(self, preds):
    """Build a cut function for predictions using the detection threshold.

    Args:
        preds: A prediction tensor whose column 0 is the detection variable.

    Returns:
        A callable mapping a prediction tensor to the result of :meth:`cut` using
        a row mask derived from this tensor's detection column.
    """
    mask = preds[:, 0] > self.threshold
    return lambda preds: self.cut(preds, mask)
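
The cut and uncut round trip is easier to follow on a tiny example. The sketch below re-creates the same data flow with plain Python lists instead of tensors; the free functions `cut` and `uncut` are simplified stand-ins for the methods above, and the "correction" is a placeholder.

```python
def cut(rows, mask):
    """Keep masked rows and drop column 0 (the detection variable)."""
    return [row[1:] for row, keep in zip(rows, mask) if keep]


def uncut(init, mask, updated):
    """Return a copy of `init` with the kept rows replaced by corrected ones."""
    result = [list(row) for row in init]
    kept = [i for i, keep in enumerate(mask) if keep]
    for i, new_tail in zip(kept, updated):
        result[i] = [init[i][0]] + list(new_tail)  # re-attach column 0
    return result


threshold = 0.5
preds = [
    [0.9, 0.2, 0.7],   # detected
    [0.1, 0.6, 0.6],   # not detected: left untouched
    [0.8, 0.4, 0.1],   # detected
]
mask = [row[0] > threshold for row in preds]
sliced = cut(preds, mask)
corrected = [[1.0, row[1]] for row in sliced]   # placeholder correction
restored = uncut(preds, mask, corrected)
```

Only the first and third rows pass the threshold, so only they are corrected; the detection column and the second row come back exactly as they went in.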

Slicer

slicer

Partial application of the Shield Layer.

The Slicer lets the Shield Layer apply only a prefix of its stratified correction modules to a subset of the variables (atoms), which is used to gradually enable the requirements during training.

Slicer

Slicer(atoms, modules)

Selects a subset of atoms and a prefix of correction modules.

Attributes:

Name Type Description
atoms

The list of variable (atom) indices the slicer restricts to.

modules

The number of leading correction modules to keep.

Store the atoms and module count to slice to.

Parameters:

Name Type Description Default
atoms

An iterable of variable indices to keep.

required
modules

The number of leading correction modules to keep.

required
Source code in pishield/propositional_requirements/slicer.py
def __init__(self, atoms, modules):
    """Store the atoms and module count to slice to.

    Args:
        atoms: An iterable of variable indices to keep.
        modules: The number of leading correction modules to keep.
    """
    self.atoms = list(atoms)
    self.modules = modules
    print(f"Created slicer for {modules} modules (atoms {atoms})")

slice_atoms

slice_atoms(tensor)

Select the slicer's columns from a tensor.

Parameters:

Name Type Description Default
tensor

A 2D tensor indexed by variable in its columns.

required

Returns:

Type Description

The tensor restricted to the slicer's atom columns.

Source code in pishield/propositional_requirements/slicer.py
def slice_atoms(self, tensor):
    """Select the slicer's columns from a tensor.

    Args:
        tensor: A 2D tensor indexed by variable in its columns.

    Returns:
        The tensor restricted to the slicer's atom columns.
    """
    return tensor[:, self.atoms]

slice_modules

slice_modules(modules)

Return the leading prefix of correction modules.

Parameters:

Name Type Description Default
modules

The full ordered sequence of correction modules.

required

Returns:

Type Description

The first self.modules of them.

Source code in pishield/propositional_requirements/slicer.py
def slice_modules(self, modules):
    """Return the leading prefix of correction modules.

    Args:
        modules: The full ordered sequence of correction modules.

    Returns:
        The first ``self.modules`` of them.
    """
    return modules[:self.modules]
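
Both slicing operations are plain indexing, as this small sketch with Python lists shows. The module names are placeholders and the variable `num_modules` plays the role of `self.modules`; neither comes from the library.

```python
atoms = [0, 2]       # variables the slicer restricts to
num_modules = 2      # number of leading correction modules to keep

preds = [[0.1, 0.2, 0.3],
         [0.4, 0.5, 0.6]]
modules = ["stratum_0", "stratum_1", "stratum_2"]   # placeholder modules

# Equivalent of slice_atoms: select columns by atom index.
sliced_preds = [[row[a] for a in atoms] for row in preds]

# Equivalent of slice_modules: keep the leading prefix.
sliced_modules = modules[:num_modules]
```

Raising `num_modules` step by step keeps a longer prefix each time, which is one way the requirements could be enabled gradually.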

Profiler

profiler

Lightweight time and GPU-memory profiling.

Provides a hierarchical profiler used to instrument the constraint modules: named "watches" (context managers, or the :meth:Profiler.wrap decorator) record the elapsed time and peak/net CUDA memory of code regions, aggregating the measurements into a tree of :class:Stats. On CPU-only systems the memory measurements degrade to zero.

MaxStack

MaxStack()

A stack whose update raises stored values to a running maximum.

Supports the usual push/pop plus an update(x) that raises the top of the stack to at least x; popping a value also propagates it into the new top, so nested peak measures accumulate correctly.

Attributes:

Name Type Description
stack

The underlying list of values.

Initialise an empty stack.

Source code in pishield/propositional_requirements/profiler.py
def __init__(self):
    """Initialise an empty stack."""
    self.stack = []

push

push(value)

Push a value onto the stack.

Parameters:

Name Type Description Default
value

The value to push.

required
Source code in pishield/propositional_requirements/profiler.py
def push(self, value):
    """Push a value onto the stack.

    Args:
        value: The value to push.
    """
    self.stack.append(value)

update

update(value)

Raise the top of the stack to at least value.

Parameters:

Name Type Description Default
value

The lower bound to apply to the current top (no-op if empty).

required
Source code in pishield/propositional_requirements/profiler.py
def update(self, value):
    """Raise the top of the stack to at least ``value``.

    Args:
        value: The lower bound to apply to the current top (no-op if empty).
    """
    if len(self.stack):
        last = self.stack.pop()
        self.stack.append(max(last, value))

pop

pop()

Pop the top value, propagating it into the new top as a maximum.

Returns:

Type Description

The popped value.

Source code in pishield/propositional_requirements/profiler.py
def pop(self):
    """Pop the top value, propagating it into the new top as a maximum.

    Returns:
        The popped value.
    """
    value = self.stack.pop()
    if len(self.stack):
        last = self.stack.pop()
        self.stack.append(max(last, value))
    return value

PeakMemoryManager

PeakMemoryManager()

Singleton tracking nested peak CUDA-memory measurements.

Uses a :class:MaxStack so that entering a region resets the CUDA peak counter and exiting returns the peak observed within that region, while still propagating it to enclosing regions.

Attributes:

Name Type Description
stack

The :class:MaxStack of in-flight peak measurements.

Initialise with an empty measurement stack.

Source code in pishield/propositional_requirements/profiler.py
def __init__(self):
    """Initialise with an empty measurement stack."""
    self.stack = MaxStack()

enter

enter()

Begin a nested measurement region, resetting the CUDA peak counter.

Source code in pishield/propositional_requirements/profiler.py
def enter(self):
    """Begin a nested measurement region, resetting the CUDA peak counter."""
    self.stack.update(get_peak())
    reset_peak()
    self.stack.push(0)

exit

exit()

End the current measurement region.

Returns:

Type Description

The peak CUDA memory observed within the region.

Source code in pishield/propositional_requirements/profiler.py
def exit(self):
    """End the current measurement region.

    Returns:
        The peak CUDA memory observed within the region.
    """
    self.stack.update(get_peak())
    return self.stack.pop()
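
To see how nested regions interact, the following sketch replays the enter/exit logic against a fake memory counter. `FakeDevice` is a hypothetical stand-in for the CUDA peak counter behind `get_peak` and `reset_peak`, whose definitions are not shown on this page, and `MaxStack` is condensed from the listing above.

```python
class MaxStack:
    """Condensed copy of the MaxStack behaviour documented above."""

    def __init__(self):
        self.stack = []

    def push(self, value):
        self.stack.append(value)

    def update(self, value):
        if self.stack:
            self.stack[-1] = max(self.stack[-1], value)

    def pop(self):
        value = self.stack.pop()
        if self.stack:
            self.stack[-1] = max(self.stack[-1], value)
        return value


class FakeDevice:
    """Hypothetical peak-memory counter used only for this illustration."""

    def __init__(self):
        self.peak = 0

    def allocate(self, amount):
        self.peak = max(self.peak, amount)

    def reset_peak(self):
        self.peak = 0


device = FakeDevice()
stack = MaxStack()


def enter():
    stack.update(device.peak)   # save the enclosing region's peak so far
    device.reset_peak()
    stack.push(0)


def exit_region():
    stack.update(device.peak)
    return stack.pop()          # also propagates into the enclosing region


enter()                         # outer region
device.allocate(10)
enter()                         # inner region
device.allocate(4)
inner_peak = exit_region()
outer_peak = exit_region()
```

The inner region reports only its own peak (4), while the outer region still reports 10 even though the counter was reset when the inner region began.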

Stats

Stats(peak, diff, sum, tdiff, tsum)

Aggregated time and memory statistics for a profiled region.

Attributes:

Name Type Description
peak

The maximum peak memory observed.

diff

The maximum net memory change observed.

sum

The summed net memory change across measurements.

tdiff

The maximum elapsed time observed.

tsum

The summed elapsed time across measurements.

Store the individual statistics.

Parameters:

Name Type Description Default
peak

The peak memory.

required
diff

The net memory change.

required
sum

The summed net memory change.

required
tdiff

The elapsed time.

required
tsum

The summed elapsed time.

required
Source code in pishield/propositional_requirements/profiler.py
def __init__(self, peak, diff, sum, tdiff, tsum):
    """Store the individual statistics.

    Args:
        peak: The peak memory.
        diff: The net memory change.
        sum: The summed net memory change.
        tdiff: The elapsed time.
        tsum: The summed elapsed time.
    """
    self.peak = peak
    self.diff = diff
    self.sum = sum

    self.tdiff = tdiff
    self.tsum = tsum

single classmethod

single(peak, diff, tdiff)

Build a Stats from a single measurement (sum equals the value).

Parameters:

Name Type Description Default
peak

The peak memory.

required
diff

The net memory change.

required
tdiff

The elapsed time.

required

Returns:

Type Description

A Stats whose sum fields equal the corresponding single values.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def single(cls, peak, diff, tdiff):
    """Build a Stats from a single measurement (sum equals the value).

    Args:
        peak: The peak memory.
        diff: The net memory change.
        tdiff: The elapsed time.

    Returns:
        A Stats whose sum fields equal the corresponding single values.
    """
    return cls(peak, diff, diff, tdiff, tdiff)

null classmethod

null()

Return a zero-valued Stats (the additive identity).

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def null(cls):
    """Return a zero-valued Stats (the additive identity)."""
    return cls.single(0, 0, 0)

memory

memory(long=True)

Return the memory statistics as a tuple.

Parameters:

Name Type Description Default
long

If True include the summed change; otherwise only peak and diff.

True

Returns:

Type Description

(peak, diff, sum) when long else (peak, diff).

Source code in pishield/propositional_requirements/profiler.py
def memory(self, long=True):
    """Return the memory statistics as a tuple.

    Args:
        long: If True include the summed change; otherwise only peak and diff.

    Returns:
        ``(peak, diff, sum)`` when ``long`` else ``(peak, diff)``.
    """
    if long:
        return (self.peak, self.diff, self.sum)
    else:
        return (self.peak, self.diff)

time

time(long=True)

Return the timing statistics as a tuple.

Parameters:

Name Type Description Default
long

If True include the summed time; otherwise only the max.

True

Returns:

Type Description

(tdiff, tsum) when long else (tdiff,).

Source code in pishield/propositional_requirements/profiler.py
def time(self, long=True):
    """Return the timing statistics as a tuple.

    Args:
        long: If True include the summed time; otherwise only the max.

    Returns:
        ``(tdiff, tsum)`` when ``long`` else ``(tdiff,)``.
    """
    if long:
        return (self.tdiff, self.tsum)
    else:
        return (self.tdiff,)

tuple

tuple(long=True)

Return the combined time and memory statistics.

Parameters:

Name Type Description Default
long

Whether to include the summed fields.

True

Returns:

Type Description

A tuple (time_tuple, memory_tuple).

Source code in pishield/propositional_requirements/profiler.py
def tuple(self, long=True):
    """Return the combined time and memory statistics.

    Args:
        long: Whether to include the summed fields.

    Returns:
        A tuple ``(time_tuple, memory_tuple)``.
    """
    return (self.time(long), self.memory(long))
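
A short usage sketch of the accessors above, using a condensed copy of the class so that it runs on its own. The numeric values are made up for illustration.

```python
class Stats:
    """Condensed copy of the Stats class documented above."""

    def __init__(self, peak, diff, sum, tdiff, tsum):
        self.peak, self.diff, self.sum = peak, diff, sum
        self.tdiff, self.tsum = tdiff, tsum

    @classmethod
    def single(cls, peak, diff, tdiff):
        # For one measurement the summed fields equal the single values.
        return cls(peak, diff, diff, tdiff, tdiff)

    def memory(self, long=True):
        return (self.peak, self.diff, self.sum) if long else (self.peak, self.diff)

    def time(self, long=True):
        return (self.tdiff, self.tsum) if long else (self.tdiff,)

    def tuple(self, long=True):
        return (self.time(long), self.memory(long))


stats = Stats.single(peak=8, diff=2, tdiff=0.5)
full = stats.tuple()               # ((0.5, 0.5), (8, 2, 2))
short = stats.tuple(long=False)    # ((0.5,), (8, 2))
```

Note that the time tuple comes first and the memory tuple second.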

Profiler

Profiler(watches=None)

Hierarchical recorder of timing and memory statistics.

A profiler owns a tree of named "watches"; each leaf accumulates a list of :class:Stats. Sub-profilers created via :meth:branch share the parent's subtree, allowing nested instrumentation. Profiling can be globally toggled with :meth:enable/:meth:disable.

Attributes:

Name Type Description
watches

The nested dict of recorded watches (subtrees or lists of Stats).

manager

The shared :class:PeakMemoryManager.

Create a profiler over an optional existing watches subtree.

Parameters:

Name Type Description Default
watches

An existing watches dict to attach to, or None for a fresh tree.

None
Source code in pishield/propositional_requirements/profiler.py
def __init__(self, watches=None):
    """Create a profiler over an optional existing watches subtree.

    Args:
        watches: An existing watches dict to attach to, or None for a fresh tree.
    """
    self.watches = dict() if watches is None else watches
    self.manager = PeakMemoryManager()

enable classmethod

enable()

Globally enable profiling.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def enable(cls):
    """Globally enable profiling."""
    cls._enabled = True

disable classmethod

disable()

Globally disable profiling.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def disable(cls):
    """Globally disable profiling."""
    cls._enabled = False

enabled classmethod

enabled()

Return whether profiling is currently enabled.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def enabled(cls):
    """Return whether profiling is currently enabled."""
    return cls._enabled

shared classmethod

shared()

Return the shared singleton profiler instance.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def shared(cls):
    """Return the shared singleton profiler instance."""
    if not hasattr(cls, '_shared_'):
        cls._shared_ = cls()
    return cls._shared_

norm staticmethod

norm(x)

Convert a byte count to mebibytes.

Parameters:

Name Type Description Default
x

A value in bytes.

required

Returns:

Type Description

The value in MiB.

Source code in pishield/propositional_requirements/profiler.py
@staticmethod
def norm(x):
    """Convert a byte count to mebibytes.

    Args:
        x: A value in bytes.

    Returns:
        The value in MiB.
    """
    return x / 1024 / 1024

branch

branch(name)

Return a sub-profiler for a named subtree, creating it if needed.

Parameters:

Name Type Description Default
name

The name of the branch.

required

Returns:

Type Description

A Profiler scoped to the named subtree.

Source code in pishield/propositional_requirements/profiler.py
def branch(self, name):
    """Return a sub-profiler for a named subtree, creating it if needed.

    Args:
        name: The name of the branch.

    Returns:
        A Profiler scoped to the named subtree.
    """
    if name not in self.watches:
        self.watches[name] = dict()
    return Profiler(self.watches[name])
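The sharing behaviour of branch can be illustrated with a minimal stand-in. The MiniProfiler class below is hypothetical: it reproduces only the watches dict and the branch logic shown above, not the library's full Profiler.

```python
class MiniProfiler:
    """Hypothetical stand-in that mirrors only ``watches`` and ``branch``."""

    def __init__(self, watches=None):
        self.watches = dict() if watches is None else watches

    def branch(self, name):
        # Create the subtree on first use, then hand the same dict object to the child.
        if name not in self.watches:
            self.watches[name] = dict()
        return MiniProfiler(self.watches[name])


root = MiniProfiler()
forward = root.branch("forward")

# Anything the child records is visible from the parent because the dict is shared.
forward.watches["matmul"] = ["measurement"]

# Branching again with the same name returns a view of the same subtree.
same = root.branch("forward")
```

Because the child receives a reference to the parent's dict rather than a copy, no explicit merge step is needed when reading results from the root.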

register

register(name, peak, diff, tdiff)

Record one measurement under a named watch.

Parameters:

Name Type Description Default
name

The watch name.

required
peak

The peak memory in bytes.

required
diff

The net memory change in bytes.

required
tdiff

The elapsed time in seconds.

required
Source code in pishield/propositional_requirements/profiler.py
def register(self, name, peak, diff, tdiff):
    """Record one measurement under a named watch.

    Args:
        name: The watch name.
        peak: The peak memory in bytes.
        diff: The net memory change in bytes.
        tdiff: The elapsed time in seconds.
    """
    [peak, diff] = [Profiler.norm(x) for x in [peak, diff]]
    stats = Stats.single(peak, diff, tdiff)

    if name not in self.watches:
        self.watches[name] = [stats]
    else:
        self.watches[name].append(stats)

watch

watch(name)

Return a context manager that times and measures a named region.

Parameters:

Name Type Description Default
name

The watch name to record under.

required

Returns:

Type Description

A :class:Watch context manager.

Source code in pishield/propositional_requirements/profiler.py
def watch(self, name):
    """Return a context manager that times and measures a named region.

    Args:
        name: The watch name to record under.

    Returns:
        A :class:`Watch` context manager.
    """
    return Watch(name, self)

wrap

wrap(f)

Decorator that profiles a function under its own name.

Parameters:

Name Type Description Default
f

The function to wrap.

required

Returns:

Type Description

The wrapped function, profiled under f.__name__.

Source code in pishield/propositional_requirements/profiler.py
def wrap(self, f):
    """Decorator that profiles a function under its own name.

    Args:
        f: The function to wrap.

    Returns:
        The wrapped function, profiled under ``f.__name__``.
    """
    @functools.wraps(f)
    def profiled(*args, **kwargs):
        with self.watch(f.__name__):
            return f(*args, **kwargs)

    return profiled
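The decorator pattern used by wrap can be sketched without the library. The example below is an assumption-laden simplification: it records wall-clock time only, into a hypothetical module-level timings dict, and it chooses to record a call even when the wrapped function raises. The real Watch may behave differently.

```python
import functools
import time

timings = {}  # hypothetical store: function name -> list of elapsed seconds


def timed(f):
    """Record the wall-clock time of every call to ``f`` under ``f.__name__``."""
    @functools.wraps(f)
    def profiled(*args, **kwargs):
        start = time.perf_counter()
        try:
            return f(*args, **kwargs)
        finally:
            # Runs even if f raises, so failed calls are still recorded.
            timings.setdefault(f.__name__, []).append(time.perf_counter() - start)

    return profiled


@timed
def square(x):
    """Return x squared."""
    return x * x


result = square(7)
```

functools.wraps keeps the original name and docstring on the wrapper, which is what allows the measurement to be filed under f.__name__ and keeps the decorated function recognisable in tracebacks.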

map_dict classmethod

map_dict(f, node)

Recursively apply a function to every leaf of a nested dict.

Parameters:

Name Type Description Default
f

The function applied to each non-dict leaf value.

required
node

The nested dict to traverse.

required

Returns:

Type Description

A new nested dict with f applied to every leaf.

Source code in pishield/propositional_requirements/profiler.py
@classmethod
def map_dict(cls, f, node):
    """Recursively apply a function to every leaf of a nested dict.

    Args:
        f: The function applied to each non-dict leaf value.
        node: The nested dict to traverse.

    Returns:
        A new nested dict with ``f`` applied to every leaf.
    """
    result = dict()
    for key in node:
        if isinstance(node[key], dict):
            result[key] = cls.map_dict(f, node[key])
        else:
            result[key] = f(node[key])
    return result
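A standalone version of the same traversal shows what map_dict produces on a small tree. The watches data here is made up for illustration (plain floats instead of Stats objects).

```python
def map_dict(f, node):
    """Apply ``f`` to every non-dict leaf of a nested dict, returning a new dict."""
    result = {}
    for key, value in node.items():
        if isinstance(value, dict):
            result[key] = map_dict(f, value)
        else:
            result[key] = f(value)
    return result


# Hypothetical watches tree: leaves are lists of measurements.
watches = {"forward": {"matmul": [0.2, 0.3], "relu": [0.1]}, "backward": [0.5]}

# Count the measurements recorded under each watch.
counts = map_dict(len, watches)
# counts == {"forward": {"matmul": 2, "relu": 1}, "backward": 1}
```

The input tree is not modified unless f itself mutates the leaves, which is how reset uses the traversal.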

reset

reset()

Clear all recorded measurements, leaving a single null Stats per watch.

Source code in pishield/propositional_requirements/profiler.py
def reset(self):
    """Clear all recorded measurements, leaving a single null Stats per watch."""
    def zero(x):
        x.clear()
        x.append(Stats.null())

    Profiler.map_dict(zero, self.watches)

get_kind

get_kind(kind, long)

Return a selector mapping a Stats to the requested view.

Parameters:

Name Type Description Default
kind

One of 'all' (time and memory), 'gpu' (memory), or 'time'.

required
long

Whether the returned tuples should include the summed fields.

required

Returns:

Type Description

A callable mapping a :class:Stats to the selected tuple.

Raises:

Type Description
Exception

If kind is not recognised.

Source code in pishield/propositional_requirements/profiler.py
def get_kind(self, kind, long):
    """Return a selector mapping a Stats to the requested view.

    Args:
        kind: One of ``'all'`` (time and memory), ``'gpu'`` (memory), or
            ``'time'``.
        long: Whether the returned tuples should include the summed fields.

    Returns:
        A callable mapping a :class:`Stats` to the selected tuple.

    Raises:
        Exception: If ``kind`` is not recognised.
    """
    if kind == 'all':
        return lambda x: x.tuple(long)
    elif kind == 'gpu':
        return lambda x: x.memory(long)
    elif kind == 'time':
        return lambda x: x.time(long)
    else:
        raise Exception(f"Unknown kind {kind}")

all

all(kind='all')

Return every individual measurement per watch.

Parameters:

Name Type Description Default
kind

The view to extract (see :meth:get_kind).

'all'

Returns:

Type Description

A nested dict mapping each watch to the list of its per-measurement tuples.

Source code in pishield/propositional_requirements/profiler.py
def all(self, kind='all'):
    """Return every individual measurement per watch.

    Args:
        kind: The view to extract (see :meth:`get_kind`).

    Returns:
        A nested dict mapping each watch to the list of its per-measurement tuples.
    """
    kinder = self.get_kind(kind, long=False)
    return Profiler.map_dict(lambda xs: [kinder(x) for x in xs], self.watches)

combined

combined(kind='all')

Return the per-watch aggregate of all its measurements.

Parameters:

Name Type Description Default
kind

The view to extract (see :meth:get_kind).

'all'

Returns:

Type Description

A nested dict mapping each watch to its combined Stats tuple.

Source code in pishield/propositional_requirements/profiler.py
def combined(self, kind='all'):
    """Return the per-watch aggregate of all its measurements.

    Args:
        kind: The view to extract (see :meth:`get_kind`).

    Returns:
        A nested dict mapping each watch to its combined Stats tuple.
    """
    kinder = self.get_kind(kind, long=True)
    return Profiler.map_dict(lambda x: kinder(sum(x, Stats.null())), self.watches)

total

total(kind='all')

Return the aggregate over every watch in the tree.

Parameters:

Name Type Description Default
kind

The view to extract (see :meth:get_kind).

'all'

Returns:

Type Description

The combined Stats tuple summed across all watches.

Source code in pishield/propositional_requirements/profiler.py
def total(self, kind='all'):
    """Return the aggregate over every watch in the tree.

    Args:
        kind: The view to extract (see :meth:`get_kind`).

    Returns:
        The combined Stats tuple summed across all watches.
    """
    kinder = self.get_kind(kind, long=True)
    result = Stats.null()

    def update(xs):
        nonlocal result
        for x in xs: result = result + x

    Profiler.map_dict(update, self.watches)
    return kinder(result)

Watch

Watch(name, profiler)

Context manager that records the time and memory of its body.

On enter it snapshots the time and allocated memory; on exit it computes the elapsed time, net memory change and peak memory and registers them with the profiler. Both methods are no-ops when profiling is disabled.

Attributes:

Name Type Description
name

The watch name to register under.

profiler

The owning :class:Profiler.

Store the watch name and owning profiler.

Parameters:

Name Type Description Default
name

The watch name.

required
profiler

The owning Profiler.

required
Source code in pishield/propositional_requirements/profiler.py
def __init__(self, name, profiler):
    """Store the watch name and owning profiler.

    Args:
        name: The watch name.
        profiler: The owning Profiler.
    """
    self.name = name
    self.profiler = profiler
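The enter/exit behaviour described above can be sketched as a small context manager. TimedRegion is a hypothetical, time-only analogue: it omits memory tracking, stores results in a plain dict instead of a Profiler, and uses its own class-level enabled flag.

```python
import time


class TimedRegion:
    """Hypothetical, time-only analogue of a watch; memory tracking is omitted."""

    enabled = True  # class-level switch, similar in spirit to a global toggle

    def __init__(self, name, records):
        self.name = name
        self.records = records  # dict: name -> list of elapsed seconds

    def __enter__(self):
        if not TimedRegion.enabled:
            return self
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        if not TimedRegion.enabled:
            return False
        elapsed = time.perf_counter() - self.start
        self.records.setdefault(self.name, []).append(elapsed)
        return False  # never swallow exceptions raised in the body


records = {}
with TimedRegion("sum", records):
    total = sum(range(1000))

# With the switch off, the region runs but nothing is recorded.
TimedRegion.enabled = False
with TimedRegion("skipped", records):
    pass
TimedRegion.enabled = True
```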

singleton

singleton(cls)

Class decorator that turns a class into a lazily-created singleton.

Parameters:

Name Type Description Default
cls

The class to wrap.

required

Returns:

Type Description

A constructor function that always returns the single shared instance.

Source code in pishield/propositional_requirements/profiler.py
def singleton(cls):
    """Class decorator that turns a class into a lazily-created singleton.

    Args:
        cls: The class to wrap.

    Returns:
        A constructor function that always returns the single shared instance.
    """
    @functools.wraps(cls)
    def constructor(*args, **kwargs):
        if not hasattr(cls, '_instance_'):
            cls._instance_ = cls(*args, **kwargs)
        return cls._instance_

    constructor.__dict__.update(cls.__dict__)
    return constructor
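A short usage sketch makes the consequences of this decorator concrete. The decorator body is copied from the listing above so the example is self-contained; the Registry class is hypothetical.

```python
import functools


def singleton(cls):
    """Same logic as the decorator above, reproduced for a self-contained example."""
    @functools.wraps(cls)
    def constructor(*args, **kwargs):
        if not hasattr(cls, '_instance_'):
            cls._instance_ = cls(*args, **kwargs)
        return cls._instance_

    constructor.__dict__.update(cls.__dict__)
    return constructor


@singleton
class Registry:  # hypothetical class, used only for illustration
    def __init__(self, label):
        self.label = label


first = Registry("first")
second = Registry("second")  # arguments are ignored: the instance already exists
```

Two points follow from the implementation: constructor arguments only matter on the first call, and the decorated name is bound to a function rather than a class, so it can no longer be used as the second argument of isinstance.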

no_cuda

no_cuda()

Return True if CUDA is unavailable.

Source code in pishield/propositional_requirements/profiler.py
def no_cuda():
    """Return True if CUDA is unavailable."""
    return not torch.cuda.is_available()

get_allocated

get_allocated(device=None)

Return the currently allocated CUDA memory in bytes (0 without CUDA).

Parameters:

Name Type Description Default
device

Optional CUDA device.

None

Returns:

Type Description

The allocated memory in bytes, or 0 if CUDA is unavailable.

Source code in pishield/propositional_requirements/profiler.py
def get_allocated(device=None):
    """Return the currently allocated CUDA memory in bytes (0 without CUDA).

    Args:
        device: Optional CUDA device.

    Returns:
        The allocated memory in bytes, or 0 if CUDA is unavailable.
    """
    if no_cuda(): return 0
    return torch.cuda.memory_allocated(device)

get_peak

get_peak(device=None)

Return the peak allocated CUDA memory in bytes (0 without CUDA).

Parameters:

Name Type Description Default
device

Optional CUDA device.

None

Returns:

Type Description

The peak allocated memory in bytes, or 0 if CUDA is unavailable.

Source code in pishield/propositional_requirements/profiler.py
def get_peak(device=None):
    """Return the peak allocated CUDA memory in bytes (0 without CUDA).

    Args:
        device: Optional CUDA device.

    Returns:
        The peak allocated memory in bytes, or 0 if CUDA is unavailable.
    """
    if no_cuda(): return 0
    return torch.cuda.max_memory_allocated(device)

reset_peak

reset_peak(device=None)

Reset the CUDA peak-memory statistics (no-op without CUDA).

Parameters:

Name Type Description Default
device

Optional CUDA device.

None
Source code in pishield/propositional_requirements/profiler.py
def reset_peak(device=None):
    """Reset the CUDA peak-memory statistics (no-op without CUDA).

    Args:
        device: Optional CUDA device.
    """
    if no_cuda(): return None
    torch.cuda.reset_peak_memory_stats(device)

condition

condition(cond)

Build a decorator that only runs the wrapped function when a condition holds.

Parameters:

Name Type Description Default
cond

A boolean, or a callable returning a boolean, evaluated on each call.

required

Returns:

Type Description

A decorator that runs the function when the condition is truthy, else returns None.

Source code in pishield/propositional_requirements/profiler.py
def condition(cond):
    """Build a decorator that only runs the wrapped function when a condition holds.

    Args:
        cond: A boolean, or a callable returning a boolean, evaluated on each call.

    Returns:
        A decorator that runs the function when the condition is truthy, else returns
        None.
    """
    def conditioned(f):
        @functools.wraps(f)
        def decorated(*args, **kwargs):
            should = cond() if callable(cond) else cond
            return f(*args, **kwargs) if should else None

        return decorated

    return conditioned
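The difference between passing a boolean and passing a callable is easiest to see in use. The decorator is copied from the listing above; the flags dict and describe function are hypothetical.

```python
import functools


def condition(cond):
    """Same logic as above, reproduced so the example is self-contained."""
    def conditioned(f):
        @functools.wraps(f)
        def decorated(*args, **kwargs):
            should = cond() if callable(cond) else cond
            return f(*args, **kwargs) if should else None

        return decorated

    return conditioned


flags = {"verbose": False}  # hypothetical runtime switch


@condition(lambda: flags["verbose"])
def describe(x):
    return f"value={x}"


before = describe(3)       # condition is False, so the body is skipped
flags["verbose"] = True
after = describe(3)        # the callable is re-evaluated on every call
```

A plain boolean is fixed at decoration time, whereas a callable lets the condition change between calls.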

Utilities

util

Training, evaluation and ordering utilities for propositional requirements.

Helper routines used in examples and experiments: a standard PyTorch training loop and test loop that run predictions through a Shield Layer, a function for visualising the class predictions of a 2D model, and a helper for resolving the variable ordering / centrality used to stratify the requirements.

train

train(dataloader, model, clayer, loss_fn, optimizer, device, ratio=1.0)

Run one training epoch with predictions corrected by the Shield Layer.

For each batch, the model's predictions are passed through clayer (with the ground-truth labels as goal and a slicer that gradually enables the requirements), then the loss is computed on the sliced atoms and backpropagated.

Parameters:

Name Type Description Default
dataloader

A PyTorch DataLoader yielding (X, y) batches.

required
model

The prediction model.

required
clayer

The Shield Layer correcting the predictions.

required
loss_fn

The loss function applied to corrected predictions and labels.

required
optimizer

The optimizer updating the model parameters.

required
device

The torch device to run on.

required
ratio

The fraction of strata to enable via the layer's slicer (1.0 = all).

1.0
Source code in pishield/propositional_requirements/util.py
def train(dataloader, model, clayer, loss_fn, optimizer, device, ratio=1.):
    """Run one training epoch with predictions corrected by the Shield Layer.

    For each batch, the model's predictions are passed through ``clayer`` (with the
    ground-truth labels as goal and a slicer that gradually enables the requirements),
    then the loss is computed on the sliced atoms and backpropagated.

    Args:
        dataloader: A PyTorch DataLoader yielding ``(X, y)`` batches.
        model: The prediction model.
        clayer: The Shield Layer correcting the predictions.
        loss_fn: The loss function applied to corrected predictions and labels.
        optimizer: The optimizer updating the model parameters.
        device: The torch device to run on.
        ratio: The fraction of strata to enable via the layer's slicer (1.0 = all).
    """
    size = len(dataloader.dataset)
    model, clayer = model.to(device), clayer.to(device)
    slicer = clayer.slicer(ratio)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        constrained = clayer(pred, goal=y, slicer=slicer)

        constrained, y = slicer.slice_atoms(constrained), slicer.slice_atoms(y)
        loss = loss_fn(constrained, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

test

test(dataloader, model, clayer, loss_fn, device)

Evaluate the model with Shield-Layer-corrected predictions.

Runs the model over the dataloader (no goal), corrects the predictions with clayer, and reports the average loss and per-class accuracy.

Parameters:

Name Type Description Default
dataloader

A PyTorch DataLoader yielding (X, y) batches.

required
model

The prediction model.

required
clayer

The Shield Layer correcting the predictions.

required
loss_fn

The loss function.

required
device

The torch device to run on.

required

Returns:

Type Description

A tuple (test_loss, correct) of the average loss and the list of per-class accuracy percentages.

Source code in pishield/propositional_requirements/util.py
def test(dataloader, model, clayer, loss_fn, device):
    """Evaluate the model with Shield-Layer-corrected predictions.

    Runs the model over the dataloader (no goal), corrects the predictions with
    ``clayer``, and reports the average loss and per-class accuracy.

    Args:
        dataloader: A PyTorch DataLoader yielding ``(X, y)`` batches.
        model: The prediction model.
        clayer: The Shield Layer correcting the predictions.
        loss_fn: The loss function.
        device: The torch device to run on.

    Returns:
        A tuple ``(test_loss, correct)`` of the average loss and the list of per-class
        accuracy percentages.
    """
    size = len(dataloader.dataset)
    model, clayer = model.to(device), clayer.to(device)
    model.eval()

    test_loss = 0.
    correct = 0.

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)
            pred = clayer(pred)
            test_loss += loss_fn(pred, y).item()
            correct += (torch.where(pred > 0.5, 1., 0.) == y).sum(dim=0)

    test_loss /= size
    correct /= size

    correct = [100 * rate for rate in correct]
    accuracy = ", ".join([f"{rate:>0.1f}%" for rate in correct])
    print(f"Test Error: \n Accuracy: {accuracy}")
    print(f" Avg loss: {test_loss:>8f} \n")
    return test_loss, correct
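The per-class accuracy computed in test can be worked through by hand on a tiny example. The sketch below uses plain Python lists and made-up numbers instead of tensors; it mirrors the thresholding at 0.5 and the column-wise count, nothing else.

```python
# Hypothetical corrected predictions (probabilities) and binary labels:
# 4 samples, 3 classes.
preds = [
    [0.9, 0.2, 0.7],
    [0.4, 0.8, 0.6],
    [0.6, 0.1, 0.3],
    [0.2, 0.7, 0.9],
]
labels = [
    [1, 0, 1],
    [0, 1, 0],
    [1, 1, 0],
    [0, 1, 1],
]

size = len(preds)
num_classes = len(preds[0])

# Threshold at 0.5, compare with the label, and count matches per class (column).
correct = [0] * num_classes
for row, target in zip(preds, labels):
    for j in range(num_classes):
        predicted = 1 if row[j] > 0.5 else 0
        correct[j] += int(predicted == target[j])

# Convert the counts to percentages of the dataset size.
accuracy = [100 * c / size for c in correct]
# accuracy == [100.0, 75.0, 75.0]
```

Each class is scored independently, so a sample can count as correct for one label and incorrect for another.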

draw_classes

draw_classes(model, draw=None, path=None, device='cpu', show=False)

Plot each class score of a model over the unit square.

Evaluates the model on a dense grid of points in [0, 1)^2 and renders one filled-contour subplot per output class.

Parameters:

Name Type Description Default
model

A model mapping 2D points to per-class scores.

required
draw

Optional callable draw(ax, i) to overlay extra artwork on subplot i.

None
path

Optional file path to save the figure to.

None
device

The torch device to run the model on.

'cpu'
show

If True, display the figure interactively.

False

Returns:

Type Description

The matplotlib Figure that was created.

Source code in pishield/propositional_requirements/util.py
def draw_classes(model, draw=None, path=None, device='cpu', show=False):
    """Plot each class score of a model over the unit square.

    Evaluates the model on a dense grid of points in ``[0, 1)^2`` and renders one
    filled-contour subplot per output class.

    Args:
        model: A model mapping 2D points to per-class scores.
        draw: Optional callable ``draw(ax, i)`` to overlay extra artwork on subplot i.
        path: Optional file path to save the figure to.
        device: The torch device to run the model on.
        show: If True, display the figure interactively.

    Returns:
        The matplotlib Figure that was created.
    """
    import matplotlib.pyplot as plt  # imported lazily: plotting is optional and pulls in a heavy dependency
    dots = np.arange(0., 1., 0.001, dtype="float32")
    grid = torch.tensor([(x, y) for y in dots for x in dots]).to(device)
    model = model.to(device)
    preds = model(grid).detach()

    classes = preds.shape[1]
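    # squeeze=False keeps a 2D array of axes even when the model has a single class
    fig, axes = plt.subplots(1, classes, figsize=(20, 20 * classes), squeeze=False)
    for i, ax in enumerate(axes[0]):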
        image = preds[:, i].view((len(dots), len(dots))).to('cpu')
        # ax.imshow(
        #     image, 
        #     cmap='hot', 
        #     interpolation='nearest', 
        #     origin='lower', 
        #     extent=(0., 1., 0., 1.),
        #     vmin=0.,
        #     vmax=1.
        # )
        ax.contourf(
            dots,
            dots,
            image,
            cmap='hot',
            origin='lower',
            extent=(0., 1., 0., 1.),
            vmin=0.1,
            vmax=1.
        )
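        if draw is not None: draw(ax, i)

    # Save before showing: with interactive backends the figure may be gone once plt.show() returns.
    if path is not None:
        fig.savefig(path)

    if show:
        plt.show()

    if path is not None:
        plt.close(fig)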

    return fig

get_order_and_centrality

get_order_and_centrality(ordering_choice: str, custom_ordering: str)

Resolve the centrality/ordering argument used to stratify the requirements.

If a custom ordering is supplied and the choice is a custom/given one, the ordering is parsed into an explicit array of atom indices (reversed when the choice contains 'rev'); otherwise the named centrality choice is returned as-is.

Parameters:

Name Type Description Default
ordering_choice str

The ordering choice name (e.g. a centrality measure, or one containing 'custom'/'given', optionally with 'rev').

required
custom_ordering str

An optional comma-separated string of atom indices.

required

Returns:

Type Description

Either the ordering choice name (str) or an array of atom indices giving an explicit order.

Source code in pishield/propositional_requirements/util.py
def get_order_and_centrality(ordering_choice: str, custom_ordering: str):
    """Resolve the centrality/ordering argument used to stratify the requirements.

    If a custom ordering is supplied and the choice is a custom/given one, the ordering
    is parsed into an explicit array of atom indices (reversed when the choice contains
    ``'rev'``); otherwise the named centrality choice is returned as-is.

    Args:
        ordering_choice: The ordering choice name (e.g. a centrality measure, or one
            containing ``'custom'``/``'given'``, optionally with ``'rev'``).
        custom_ordering: An optional comma-separated string of atom indices.

    Returns:
        Either the ordering choice name (str) or an array of atom indices giving an
        explicit order.
    """
    if custom_ordering is None:
        return ordering_choice
    if 'custom' in ordering_choice or 'given' in ordering_choice:
        order = custom_ordering.split(',')
        centrality = np.array([int(nr) for nr in order])
        if 'rev' in ordering_choice:
            centrality = centrality[::-1]
    else:
        centrality = ordering_choice
    return centrality
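The three possible outcomes of this resolution can be shown with a small stand-in. resolve_ordering is hypothetical: it follows the same branching as the listing above but returns plain lists instead of NumPy arrays, and the choice names "some_centrality" and "rev-custom" are placeholders rather than names confirmed by the library.

```python
def resolve_ordering(ordering_choice, custom_ordering):
    """Hypothetical list-based stand-in for the resolution logic above."""
    if custom_ordering is None:
        return ordering_choice
    if 'custom' in ordering_choice or 'given' in ordering_choice:
        order = [int(nr) for nr in custom_ordering.split(',')]
        return order[::-1] if 'rev' in ordering_choice else order
    # A custom ordering was supplied, but the choice does not ask for it.
    return ordering_choice


explicit = resolve_ordering("custom", "2,0,1")            # [2, 0, 1]
reversed_order = resolve_ordering("rev-custom", "2,0,1")  # [1, 0, 2]
ignored = resolve_ordering("some_centrality", "2,0,1")    # choice name returned as-is
default = resolve_ordering("some_centrality", None)       # choice name returned as-is
```

The checks are substring tests on ordering_choice, so it must be a string whenever custom_ordering is given; a custom ordering paired with a non-custom choice is silently ignored.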