python - Filtering from index and comparing row value with all values in column - Stack Overflow


Starting with this DataFrame:

df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

┌───────┬───────┬───────┬───────┐
│ name  ┆ index ┆ limit ┆ price │
│ ---   ┆   --- ┆   --- ┆   --- │
│ str   ┆   i64 ┆   i64 ┆   i64 │
╞═══════╪═══════╪═══════╪═══════╡
│ Alpha ┆     0 ┆    12 ┆    10 │
│ Alpha ┆     3 ┆    18 ┆    15 │
│ Alpha ┆     4 ┆    11 ┆    12 │
│ Alpha ┆     7 ┆     5 ┆     8 │
│ Alpha ┆     9 ┆     9 ┆    11 │
└───────┴───────┴───────┴───────┘

I need to add a new column that tells me, for each row, the smallest index greater than the current one at which the price is equal to or higher than the current row's limit.
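Stated as a formula (the notation is mine, not the asker's: $i$ is the current row and $j$ ranges over all rows of the DataFrame), with the result being null when the set is empty:

```latex
\text{min\_index}_i = \min \{\, \text{index}_j \;:\; \text{index}_j > \text{index}_i \;\wedge\; \text{price}_j \ge \text{limit}_i \,\}
```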

With this example above, the expected output is:

┌───────┬───────┬───────┬───────┬───────────┐
│ name  ┆ index ┆ limit ┆ price ┆ min_index │
│ ---   ┆   --- ┆   --- ┆   --- ┆       --- │
│ str   ┆   i64 ┆   i64 ┆   i64 ┆       i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆     0 ┆    12 ┆    10 ┆         3 │
│ Alpha ┆     3 ┆    18 ┆    15 ┆      null │
│ Alpha ┆     4 ┆    11 ┆    12 ┆         9 │
│ Alpha ┆     7 ┆     5 ┆     8 ┆         9 │
│ Alpha ┆     9 ┆     9 ┆    11 ┆      null │
└───────┴───────┴───────┴───────┴───────────┘

Explaining the "min_index" column results:

  • 1st row, where the limit is 12: from the 2nd row onwards, the minimum index whose price is equal to or greater than 12 is 3.
  • 2nd row, where the limit is 18: from the 3rd row onwards, there is no index whose price is equal to or greater than 18.
  • 3rd row, where the limit is 11: from the 4th row onwards, the minimum index whose price is equal to or greater than 11 is 9.
  • 4th row, where the limit is 5: from the 5th row onwards, the minimum index whose price is equal to or greater than 5 is 9.
  • 5th row, where the limit is 9: this is the last row, so there is no later index whose price is equal to or greater than 9.
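To pin the rule down before reaching for Polars, here is a plain-Python brute-force sketch (the function name `min_index_reference` is hypothetical). It is O(n²) and only meant as a reference to check a Polars solution against:

```python
from typing import Optional


def min_index_reference(
    index: list[int], limit: list[int], price: list[int]
) -> list[Optional[int]]:
    """For each row i, return the smallest index[j] with index[j] > index[i]
    and price[j] >= limit[i], or None (null) if no such row exists."""
    result: list[Optional[int]] = []
    for i in range(len(index)):
        # Candidate rows: strictly later index, price at least the current limit.
        candidates = [
            index[j]
            for j in range(len(index))
            if index[j] > index[i] and price[j] >= limit[i]
        ]
        # min() of an empty list would raise, so fall back to None.
        result.append(min(candidates) if candidates else None)
    return result


print(min_index_reference([0, 3, 4, 7, 9], [12, 18, 11, 5, 9], [10, 15, 12, 8, 11]))
# [3, None, 9, 9, None]
```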

My solution is shown below, but what would be a neat Polars way of doing it? I was able to solve it in 8 steps, but I'm sure there is a more efficient way.

# Import Polars.
import polars as pl

# Create a sample DataFrame.
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

# Group by name, so that all rows' values are collected into one list per column.
df_2 = df_1.group_by('name').agg(pl.all())

# Join the lists back onto the original DataFrame.
df_3 = df_1.join(
    other=df_2,
    on='name',
    suffix='_list'
)

# Explode the list columns to long format: one row per (original row, candidate row) pair.
df_3 = df_3.explode([
    'index_list',
    'limit_list',
    'price_list',
])

# Filter the DataFrame for the condition we want.
df_3 = df_3.filter(
    (pl.col('index_list') > pl.col('index')) &
    (pl.col('price_list') >= pl.col('limit'))
)

# Get the minimum matching index for each original 'index'.
df_3 = df_3.with_columns(
    pl.col('index_list').min().over('index').alias('min_index')
)

# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
    pl.col(['index', 'min_index'])
).unique()

# Finally join the result.
df_final = df_1.join(
    other=df_3,
    on='index',
    how='left'
)

print(df_final)

Asked Jan 4 at 19:58 by Danilo Setton; edited Jan 4 at 20:00 by jonrsharpe.

1 Answer


Option 1: df.join_where (experimental)

out = (
    df_1.join(
        df_1
        .join_where(
            df_1.select('index', 'price'),
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)

Output:

shape: (5, 5)
┌───────┬───────┬───────┬───────┬───────────┐
│ name  ┆ index ┆ limit ┆ price ┆ min_index │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---       │
│ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64       │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3         │
│ Alpha ┆ 3     ┆ 18    ┆ 15    ┆ null      │
│ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9         │
│ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9         │
│ Alpha ┆ 9     ┆ 9     ┆ 11    ┆ null      │
└───────┴───────┴───────┴───────┴───────────┘

Explanation / Intermediates

  • Use df.join_where with df_1.select('index', 'price') as other (the right-hand 'limit' column is not needed) and pass the filter predicates.
# df_1.join_where(...)

shape: (4, 6)
┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
│ name  ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---         ┆ ---         │
│ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64         ┆ i64         │
╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3           ┆ 15          │
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 4           ┆ 12          │
│ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9           ┆ 11          │
│ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9           ┆ 11          │
└───────┴───────┴───────┴───────┴─────────────┴─────────────┘
  • Since df.join_where does not preserve row order, use df.group_by with pl.Expr.min to get the smallest 'index_right' per 'index'.
# df_1.join_where(...).group_by('index').agg(...)

shape: (3, 2)
┌───────┬───────────┐
│ index ┆ min_index │
│ ---   ┆ ---       │
│ i64   ┆ i64       │
╞═══════╪═══════════╡
│ 0     ┆ 3         │
│ 7     ┆ 9         │
│ 4     ┆ 9         │
└───────┴───────────┘
  • Finally, add the result to df_1 with a left join; rows without a match get null.

Option 2: df.join with "cross" + df.filter

(Adding this option since df.join_where is experimental. It will be more expensive, though, because the cross join materializes all n × n row pairs before filtering.)

out2 = (
    df_1.join(
        df_1
        .join(df_1.select('index', 'price'), how='cross')
        .filter(
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)

out2.equals(out)
# True