I have an LSTM model that receives 5 inputs to predict 3 outputs:
import torch
import torch.nn as nn
class LstmModel(nn.Module):
    """Single LSTM followed by a linear head on the last hidden state.

    Args:
        input_size: number of features per time step.
        hidden_size: number of LSTM hidden units.
        output_size: number of predicted values.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # Zero-argument super(): naming a class that does not exist here
        # (e.g. CustomLSTMModel) raises NameError at construction time.
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, output_size)."""
        _, (hn, _) = self.lstm(x)
        # hn[-1]: last layer's hidden state after the final time step.
        return self.fc(hn[-1])
I want to prevent certain inputs from having any impact on certain outputs. For example, the first input should not have any effect on the prediction of the second output; in other words, the second prediction should not be a function of the first input.
One solution I have tried is using separate LSTMs for each output:
class LstmModel(nn.Module):
    """Reference model: a separate LSTM and linear head for every output.

    Each branch receives ``x`` with one feature column removed, so output k
    cannot depend on its excluded input. This is exact, but three LSTMs run
    per forward pass.

    Args:
        input_size: features per time step of ``x``. ``forward`` unpacks
            ``x`` into exactly five columns, so this must be 5.
        hidden_size: hidden units of each branch LSTM.
        output_size: size of each output head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Every branch drops one feature column before its LSTM, so each LSTM
        # sees input_size - 1 features. Building it for input_size fails at
        # runtime with a feature-size mismatch.
        branch_size = input_size - 1
        self.lstm1 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.lstm3 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``(output1, output2, output3)``, each (batch, output_size)."""
        # Assume x is of shape (batch_size, seq_length, input_size)
        # Split inputs
        input1, input2, input3, input4, input5 = x.split(1, dim=2)
        # For output1, exclude input2
        input1_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)
        # For output2, exclude input3
        input2_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)
        # For output3, exclude input4
        input3_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)
        # Each reduced input goes through its own LSTM; hn[-1] is the final
        # hidden state of that LSTM's last layer.
        _, (hn1, _) = self.lstm1(input1_for_output1)
        output1 = self.fc1(hn1[-1])
        _, (hn2, _) = self.lstm2(input2_for_output2)
        output2 = self.fc2(hn2[-1])
        _, (hn3, _) = self.lstm3(input3_for_output3)
        # output3 needs its own head; reusing fc2 would leave fc3 untrained.
        output3 = self.fc3(hn3[-1])
        return output1, output2, output3
The problem with this approach is that it takes at least 3 times longer to run the model (since I am running LSTM 3 times, 1 for each output). Is it possible to do what I want to achieve more efficiently, with one run?
I have an LSTM model that receives 5 inputs to predict 3 outputs:
import torch
import torch.nn as nn
class LstmModel(nn.Module):
    """Single LSTM followed by a linear head on the last hidden state.

    Args:
        input_size: number of features per time step.
        hidden_size: number of LSTM hidden units.
        output_size: number of predicted values.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # Zero-argument super(): naming a class that does not exist here
        # (e.g. CustomLSTMModel) raises NameError at construction time.
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, output_size)."""
        _, (hn, _) = self.lstm(x)
        # hn[-1]: last layer's hidden state after the final time step.
        return self.fc(hn[-1])
I want to prevent certain inputs from having any impact on certain outputs. For example, the first input should not have any effect on the prediction of the second output; in other words, the second prediction should not be a function of the first input.
One solution I have tried is using separate LSTMs for each output:
class LstmModel(nn.Module):
    """Reference model: a separate LSTM and linear head for every output.

    Each branch receives ``x`` with one feature column removed, so output k
    cannot depend on its excluded input. This is exact, but three LSTMs run
    per forward pass.

    Args:
        input_size: features per time step of ``x``. ``forward`` unpacks
            ``x`` into exactly five columns, so this must be 5.
        hidden_size: hidden units of each branch LSTM.
        output_size: size of each output head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Every branch drops one feature column before its LSTM, so each LSTM
        # sees input_size - 1 features. Building it for input_size fails at
        # runtime with a feature-size mismatch.
        branch_size = input_size - 1
        self.lstm1 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.lstm3 = nn.LSTM(branch_size, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``(output1, output2, output3)``, each (batch, output_size)."""
        # Assume x is of shape (batch_size, seq_length, input_size)
        # Split inputs
        input1, input2, input3, input4, input5 = x.split(1, dim=2)
        # For output1, exclude input2
        input1_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)
        # For output2, exclude input3
        input2_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)
        # For output3, exclude input4
        input3_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)
        # Each reduced input goes through its own LSTM; hn[-1] is the final
        # hidden state of that LSTM's last layer.
        _, (hn1, _) = self.lstm1(input1_for_output1)
        output1 = self.fc1(hn1[-1])
        _, (hn2, _) = self.lstm2(input2_for_output2)
        output2 = self.fc2(hn2[-1])
        _, (hn3, _) = self.lstm3(input3_for_output3)
        # output3 needs its own head; reusing fc2 would leave fc3 untrained.
        output3 = self.fc3(hn3[-1])
        return output1, output2, output3
The problem with this approach is that it takes at least 3 times longer to run the model (since I am running LSTM 3 times, 1 for each output). Is it possible to do what I want to achieve more efficiently, with one run?
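What you're trying to achieve is input masking. Instead of running a separate LSTM for each output (like taking a separate photo of each person), you can run the LSTM once (one group photo) and give each output its own fully connected layer.

Masking (zeroing) parts of the LSTM's final hidden state before each output layer is not enough on its own, though. Every hidden unit receives every input through the input-to-hidden weights, and every other unit through the hidden-to-hidden weights, at each time step. An input's influence is therefore spread across the whole hidden state. To guarantee that a particular input cannot affect a particular output, apply the mask to the LSTM's weights instead:

- give each output its own group of hidden units;
- remove the excluded input's connections into that group;
- make the hidden-to-hidden weights block-diagonal so the groups never mix.

The LSTM still runs only once.

I hope that helps.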
Edit:
OK, here's the modification needed, based on your code:
import torch
import torch.nn as nn
class LstmModel(nn.Module):
    """Multi-output LSTM in which each output ignores chosen inputs, in one pass.

    Why not mask the hidden state: every unit of an ``nn.LSTM`` receives all
    input features through ``weight_ih`` and every other unit through
    ``weight_hh`` at each time step. Zeroing some entries of the final hidden
    state therefore cannot remove an input's influence on the remaining
    entries.

    Instead, a single ``nn.LSTM`` with ``len(exclusions) * hidden_size`` units
    is split into one group of ``hidden_size`` units per output. Its weights
    are constrained so that

    * ``weight_ih_l0`` has no connection from group ``k``'s excluded input
      columns into group ``k``'s rows, and
    * ``weight_hh_l0`` is block-diagonal, so groups never exchange state.

    Group ``k`` is therefore equivalent to a separate LSTM that never sees its
    excluded inputs (the three-LSTM version), but all groups are computed by
    one LSTM call. The dense matmuls still multiply the zeroed blocks, so the
    saving is in the number of sequential passes, not in FLOPs.

    The constraint is enforced by zeroing the masked weights here and masking
    their gradients with tensor hooks. Gradient-based optimizers such as SGD
    or Adam (with or without weight decay) therefore keep those weights at
    exactly zero. ``copy.deepcopy`` does not carry tensor hooks over, so
    rebuild the model and use ``load_state_dict`` instead of deep-copying it.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM units per output group.
        output_size: size of each output head.
        exclusions: one entry per output. Entry ``k`` lists the 0-based input
            feature indices that output ``k`` must not depend on. The default
            follows the question: output1 ignores input2, output2 ignores
            input3, output3 ignores input4.

    Raises:
        ValueError: if ``exclusions`` is empty or names an index outside
            ``range(input_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size,
                 exclusions=((1,), (2,), (3,))):
        super().__init__()
        exclusions = [tuple(excluded) for excluded in exclusions]
        if not exclusions:
            raise ValueError("exclusions must contain at least one output")
        for excluded in exclusions:
            for col in excluded:
                if not 0 <= col < input_size:
                    raise ValueError(
                        f"excluded input index {col} is outside range({input_size})"
                    )

        n_groups = len(exclusions)
        total = n_groups * hidden_size
        self.hidden_size = hidden_size
        self.exclusions = tuple(exclusions)
        self.lstm = nn.LSTM(input_size, total, batch_first=True)
        # One fully connected head per output, each reading only its group.
        self.fcs = nn.ModuleList(
            [nn.Linear(hidden_size, output_size) for _ in range(n_groups)]
        )

        # nn.LSTM stacks its four gates row-wise, each block `total` rows
        # long. A (4, n_groups, hidden_size, ...) view therefore addresses
        # gate / group / unit, and the same mask applies to every gate.
        ih_mask = torch.ones(4 * total, input_size)
        hh_mask = torch.zeros(4 * total, total)
        ih_view = ih_mask.view(4, n_groups, hidden_size, input_size)
        hh_view = hh_mask.view(4, n_groups, hidden_size, n_groups, hidden_size)
        for group, excluded in enumerate(exclusions):
            for col in excluded:
                ih_view[:, group, :, col] = 0.0
            hh_view[:, group, :, group, :] = 1.0  # block-diagonal recurrence
        # Buffers follow .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer("ih_mask", ih_mask)
        self.register_buffer("hh_mask", hh_mask)

        with torch.no_grad():
            self.lstm.weight_ih_l0.mul_(self.ih_mask)
            self.lstm.weight_hh_l0.mul_(self.hh_mask)
        # Zero the gradient of every masked weight so training never revives it.
        self.lstm.weight_ih_l0.register_hook(lambda grad: grad * self.ih_mask)
        self.lstm.weight_hh_l0.register_hook(lambda grad: grad * self.hh_mask)

    def forward(self, x):
        """Run the LSTM once and return one tensor per output.

        Args:
            x: input sequence, presumably (batch, seq_len, input_size), since
                the LSTM uses ``batch_first=True``.

        Returns:
            Tuple with one (batch, output_size) tensor per entry of ``exclusions``.
        """
        _, (hn, _) = self.lstm(x)
        # hn[-1]: final hidden state; split it back into per-output groups.
        groups = hn[-1].split(self.hidden_size, dim=-1)
        return tuple(fc(h) for fc, h in zip(self.fcs, groups))
# === TESTING THE MODEL ===
# Example input: (batch_size=2, seq_length=10, input_size=5)
batch_size, seq_length = 2, 10
input_size, hidden_size = 5, 8
output_size = 1  # one scalar prediction per output head

# Build the model first, then draw the random input sequence.
model = LstmModel(input_size, hidden_size, output_size)
x = torch.randn(batch_size, seq_length, input_size)

# A single forward pass yields all three predictions.
output1, output2, output3 = model(x)
for number, value in enumerate((output1, output2, output3), start=1):
    print(f"Output {number}:", value)
The results without the hidden-state masks:
Output 1: tensor([[-0.0487],
[-0.0439]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.2588],
[-0.2890]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.1792],
[0.1249]], grad_fn=<AddmmBackward0>)
With the hidden-state masks:
Output 1: tensor([[0.3568],
[0.3477]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.3200],
[-0.3470]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.4120],
[0.2970]], grad_fn=<AddmmBackward0>)
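I really hope the comments I added in the code clarify each line and its role. You can use Google Colab to quickly test the code.
Note that seeing the outputs change when the masks are applied only shows that the masks do something. It does not show that an excluded input has no effect on its output. To check that, perturb the excluded input and confirm that the corresponding output stays exactly the same.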
Edit 2:
Since I used hard-coded values, here's a more robust way:
import torch
import torch.nn as nn
class MaskedLSTM(nn.Module):
    """One LSTM pass whose outputs each ignore a configurable set of inputs.

    Masking the final hidden state does not work: every LSTM unit mixes all
    inputs (``weight_ih``) and all other units (``weight_hh``) at each time
    step, so any input's influence reaches every hidden unit.

    This model instead gives each output its own group of ``hidden_size``
    units inside a single ``nn.LSTM`` and constrains its weights:

    * group ``k`` has no input-to-hidden connection from its excluded inputs;
    * the hidden-to-hidden matrix is block-diagonal, so groups never mix.

    Each output is then mathematically independent of its excluded inputs,
    while the LSTM still runs once. The masked weights are zeroed at
    construction and their gradients are zeroed by tensor hooks, so standard
    gradient-based optimizers keep them at zero. ``copy.deepcopy`` drops
    tensor hooks, so rebuild the model and use ``load_state_dict`` instead.
    """

    def __init__(self, input_size, hidden_size, output_size, exclusion_map):
        """
        Args:
            input_size: number of features per time step.
            hidden_size: LSTM units dedicated to each output.
            output_size: size of each output head.
            exclusion_map: Dictionary mapping output_idx to excluded_input_idx
                Example: {1: 2, 2: 1, 3: 3} # output1 excludes input3, output2 excludes input2, etc.
                Output indices are 1-based and fix the order of the returned
                outputs (ascending). Input indices are 0-based. A value may
                also be an iterable of input indices.

        Raises:
            ValueError: if exclusion_map is empty or an input index lies
                outside range(input_size).
        """
        super().__init__()
        if not exclusion_map:
            raise ValueError("exclusion_map must map at least one output")

        # Normalise every value to a tuple of input columns, ordered by output index.
        self.output_indices = tuple(sorted(exclusion_map))
        excluded_per_group = []
        for output_idx in self.output_indices:
            excluded = exclusion_map[output_idx]
            cols = (excluded,) if isinstance(excluded, int) else tuple(excluded)
            for col in cols:
                if not 0 <= col < input_size:
                    raise ValueError(
                        f"output {output_idx}: excluded input index {col} "
                        f"is outside range({input_size})"
                    )
            excluded_per_group.append(cols)

        n_groups = len(excluded_per_group)
        total = n_groups * hidden_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, total, batch_first=True)
        self.fc_layers = nn.ModuleList(
            [nn.Linear(hidden_size, output_size) for _ in range(n_groups)]
        )

        # nn.LSTM stacks its four gates row-wise (each `total` rows), so a
        # (4, n_groups, hidden_size, ...) view addresses gate / group / unit.
        ih_mask = torch.ones(4 * total, input_size)
        hh_mask = torch.zeros(4 * total, total)
        ih_view = ih_mask.view(4, n_groups, hidden_size, input_size)
        hh_view = hh_mask.view(4, n_groups, hidden_size, n_groups, hidden_size)
        for group, cols in enumerate(excluded_per_group):
            for col in cols:
                ih_view[:, group, :, col] = 0.0
            hh_view[:, group, :, group, :] = 1.0  # block-diagonal recurrence
        # Fixed (non-trainable) masks: buffers move with the module and are saved.
        self.register_buffer("ih_mask", ih_mask)
        self.register_buffer("hh_mask", hh_mask)

        with torch.no_grad():
            self.lstm.weight_ih_l0.mul_(self.ih_mask)
            self.lstm.weight_hh_l0.mul_(self.hh_mask)
        # Zero the gradient of every masked weight so training never revives it.
        self.lstm.weight_ih_l0.register_hook(lambda grad: grad * self.ih_mask)
        self.lstm.weight_hh_l0.register_hook(lambda grad: grad * self.hh_mask)

    def forward(self, x):
        """Run the LSTM once and return one output per entry of exclusion_map.

        ``x`` is presumably (batch, seq_len, input_size), since the LSTM uses
        ``batch_first=True``. Outputs are ordered by ascending output index.
        """
        _, (hn, _) = self.lstm(x)
        final_hidden = hn[-1]  # final hidden state of the (single) layer
        groups = final_hidden.split(self.hidden_size, dim=-1)
        return tuple(fc(h) for fc, h in zip(self.fc_layers, groups))
# Configuration: output index (1-based) -> excluded input index (0-based).
# Output1 drops input3, Output2 drops input2, Output3 drops input4.
exclusion_rules = {1: 2, 2: 1, 3: 3}
model = MaskedLSTM(
    input_size=5,
    hidden_size=10,
    output_size=1,
    exclusion_map=exclusion_rules,
)

# Smoke test: batch_size=3, seq_len=10, 5 features per step.
x = torch.randn(3, 10, 5)
out1, out2, out3 = model(x)
for label, result in zip(("Output 1:", "Output 2:", "Output 3:"), (out1, out2, out3)):
    print(label, result)
… at timestep $t$ contains information from all timesteps before $t$. If you don't want that, you should use a different model architecture. – Karl, commented Feb 1 at 19:56