How to prevent certain input from impacting certain output of neural networks in PyTorch?


I have an LSTM model that receives 5 inputs to predict 3 outputs:

import torch
import torch.nn as nn

class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LstmModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch_size, seq_length, input_size)
        _, (hn, _) = self.lstm(x)
        return self.fc(hn[-1])  # all outputs predicted from the final hidden state

I want to prevent certain inputs from having any impact on certain outputs. Let's say the first input should not have any effect on the prediction of the second output; in other words, the second prediction should not be a function of the first input.
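A quick way to see the dependence I want to remove (a hypothetical check, assuming input_size=5, output_size=3 and the forward pass above): the gradient of the second output with respect to the first input would have to be zero everywhere, but for a plain LSTM it is not.

model = LstmModel(input_size=5, hidden_size=8, output_size=3)
x = torch.randn(2, 10, 5, requires_grad=True)  # (batch, seq_length, input_size)
out = model(x)                                 # (batch, 3)
out[:, 1].sum().backward()                     # backprop from the second output only
print(x.grad[:, :, 0])                         # sensitivity to the first input: non-zero here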

One solution I have tried is using separate LSTMs for each output:

class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LstmModel, self).__init__()
        # Each branch's LSTM sees input_size - 1 features, since one input is excluded per output
        self.lstm1 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        self.lstm3 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Assume x is of shape (batch_size, seq_length, input_size)
        # Split inputs
        input1, input2, input3, input4, input5 = x.split(1, dim=2)

        # Mask inputs for each output
        # For output1, exclude input2
        input1_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)
        
        # For output2, exclude input3
        input2_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)
        
        # For output3, exclude input4
        input3_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)

        # Process through LSTM
        _, (hn1, _) = self.lstm1(input1_for_output1)
        output1 = self.fc1(hn1[-1])

        _, (hn2, _) = self.lstm2(input2_for_output2)
        output2 = self.fc2(hn2[-1])

        _, (hn3, _) = self.lstm3(input3_for_output3)
        output3 = self.fc3(hn3[-1])

        return output1, output2, output3

The problem with this approach is that it takes at least three times longer to run the model, since I am running an LSTM three times, once for each output. Is it possible to achieve what I want more efficiently, in a single run?

  • 1 Since you're first passing through an LSTM and then a FC layer, would you be fine just masking out the last hidden layer produced by the LSTM and feeding that through the FC? Or you don't want the inputs to even have an effect on the LSTM itself? – Chrispresso Commented Jan 31 at 17:48
  • @Chrispresso I am not sure if I understand your question, but when a prediction is made for the first output, I know that the second input should have zero impact. – bird Commented Jan 31 at 22:18
  • 1 What do you mean by "certain input" and "certain output"? You are training an LSTM. The hidden state at time t contains information from all timesteps before t. If you don't want that, you should use a different model architecture. – Karl Commented Feb 1 at 19:56
  • @Karl I have 5 inputs, and when I predict the 1st output, the second input should not have any impact. That is what I mean by certain input. Essentially, I want to be able to manually choose which input can impact which output. – bird Commented Feb 2 at 13:50
  • 1 Can you provide a practical example of what something like this would be useful for? It could help reduce any misunderstanding of the question. – Sachin Hosmani Commented Mar 29 at 21:44

1 Answer


What you're trying to achieve is input masking. Instead of running the LSTM separately for each output (like taking a separate photo for each person), you can run the LSTM once to capture all the information (one photo) and then use a separate fully connected layer for each output.

To make sure that particular inputs don't affect particular outputs, you can mask or zero out portions of the LSTM's hidden state before sending it to the corresponding output layer. By running the LSTM just once, you accomplish the same goal much more efficiently.
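In code, the core idea looks like this (a minimal sketch; lstm, fc1 and mask1 stand for the shared LSTM, the per-output head and the per-output mask defined in full below):

# One LSTM pass shared by all outputs, then a per-output mask on the hidden state
_, (hn, _) = lstm(x)           # hn: (num_layers, batch, hidden_size)
hidden = hn[-1]                # final hidden state: (batch, hidden_size)
output1 = fc1(hidden * mask1)  # zero selected hidden dimensions before the head for output 1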

I hope that helps.

Edit:

OK, here's the modification needed based on your code:

import torch
import torch.nn as nn

class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LstmModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        
        # Fully connected layers for different outputs
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

        # Create masks to remove certain inputs' effects
        self.mask1 = nn.Parameter(torch.tensor([1, 1, 0, 1, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # Ignore input3 for output1, padded with 0's
        self.mask2 = nn.Parameter(torch.tensor([1, 0, 1, 1, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # Ignore input2 for output2, padded with 0's
        self.mask3 = nn.Parameter(torch.tensor([1, 1, 1, 0, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # Ignore input4 for output3, padded with 0's


    def forward(self, x):
        batch_size, seq_length, input_size = x.shape
        
        # LSTM Forward pass
        _, (hn, _) = self.lstm(x)  # hn shape: (num_layers, batch, hidden_size)
        hidden = hn[-1]  # Get the final hidden state (batch, hidden_size)

        # Apply masks by element-wise multiplication with the hidden state
        hidden_masked1 = hidden * self.mask1  # Apply mask to hidden state for output1
        hidden_masked2 = hidden * self.mask2  # Apply mask for output2
        hidden_masked3 = hidden * self.mask3  # Apply mask for output3

        # Generate outputs
        output1 = self.fc1(hidden_masked1)
        output2 = self.fc2(hidden_masked2)
        output3 = self.fc3(hidden_masked3)

        return output1, output2, output3


# === TESTING THE MODEL ===
# Example input: (batch_size=2, seq_length=10, input_size=5)
batch_size = 2
seq_length = 10
input_size = 5
hidden_size = 8
output_size = 1  # Single value per output

# Create model
model = LstmModel(input_size, hidden_size, output_size)

# Generate some random input data
x = torch.randn(batch_size, seq_length, input_size)

# Forward pass
output1, output2, output3 = model(x)

print("Output 1:", output1)
print("Output 2:", output2)
print("Output 3:", output3)

The results without the hidden_masked step (feeding hidden directly to the fc layers):

Output 1: tensor([[-0.0487],
        [-0.0439]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.2588],
        [-0.2890]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.1792],
        [0.1249]], grad_fn=<AddmmBackward0>)

With the hidden_masked step applied:

Output 1: tensor([[0.3568],
        [0.3477]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.3200],
        [-0.3470]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.4120],
        [0.2970]], grad_fn=<AddmmBackward0>)

I really hope the comments I added in the code clarify each line and its role.

You can use Google Colab to quickly test the code.

Edit 2:

Since I used hard-coded values, here's a more robust way:

import torch
import torch.nn as nn

class MaskedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, exclusion_map):
        """
        Args:
            exclusion_map: Dictionary mapping output index (1-based) to the excluded input index (0-based).
            Example: {1: 2, 2: 1, 3: 3}  # output1 excludes input3 (index 2), output2 excludes input2 (index 1), etc.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc_layers = nn.ModuleList([nn.Linear(hidden_size, output_size) 
                                      for _ in range(len(exclusion_map))])
        
        # Create fixed (non-trainable) masks, one per output
        self.masks = nn.ParameterDict()
        for output_idx, excluded_input in exclusion_map.items():
            mask = torch.ones(hidden_size)
            
            # Zero the slice of hidden dimensions assumed to correspond to the excluded input
            input_span = hidden_size // input_size
            start = excluded_input * input_span
            end = (excluded_input + 1) * input_span
            mask[start:end] = 0
            
            self.masks[f"mask_{output_idx}"] = nn.Parameter(mask, requires_grad=False)

    def forward(self, x):
        lstm_out, (hn, _) = self.lstm(x)
        final_hidden = hn[-1]  # (batch_size, hidden_size)
        
        outputs = []
        for idx, fc in enumerate(self.fc_layers):
            masked_hidden = final_hidden * self.masks[f"mask_{idx+1}"]
            outputs.append(fc(masked_hidden))
            
        return tuple(outputs)

# Configuration
exclusion_rules = {
    1: 2,  # Output1 excludes input3
    2: 1,  # Output2 excludes input2
    3: 3   # Output3 excludes input4
}

model = MaskedLSTM(input_size=5, hidden_size=10, 
                 output_size=1, exclusion_map=exclusion_rules)

# Test
x = torch.randn(3, 10, 5)  # batch_size=3, seq_len=10
out1, out2, out3 = model(x)

print("Output 1:", out1)
print("Output 2:", out2)
print("Output 3:", out3)