Skip to content

Fixed typing of arithmetic methods #454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Upcoming Version
gap tolerance.
* Improve the mapping of termination conditions for the SCIP solver
* Treat GLPK's `integer undefined` status as not-OK
* Fixed variable/expression arithmetic methods so that they correctly handle types

Version 0.5.3
--------------
Expand Down
74 changes: 33 additions & 41 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from itertools import product, zip_longest
from typing import (
TYPE_CHECKING,
Any,
)
from typing import TYPE_CHECKING, Any
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -487,13 +484,16 @@
)
print(self)

def __add__(self, other: SideLike) -> LinearExpression:
def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
"""
Add an expression to others.

Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if isinstance(other, QuadraticExpression):
return other.__add__(self)

try:
if np.isscalar(other):
return self.assign(const=self.const + other)
Expand All @@ -503,9 +503,8 @@
except TypeError:
return NotImplemented

def __radd__(self, other: int) -> LinearExpression | NotImplementedType:
# This is needed for using python's sum function
return self if other == 0 else NotImplemented
def __radd__(self, other: ConstantLike) -> LinearExpression:
return self.__add__(other)

def __sub__(self, other: SideLike) -> LinearExpression:
"""
Expand All @@ -514,6 +513,9 @@
Note: If other is a numpy array or pandas object without axes names,
dimension names of self will be filled in other
"""
if isinstance(other, QuadraticExpression):
return other.__rsub__(self)

Check warning on line 517 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L517

Added line #L517 was not covered by tests

try:
if np.isscalar(other):
return self.assign_multiindex_safe(const=self.const - other)
Expand All @@ -523,7 +525,7 @@
except TypeError:
return NotImplemented

def __neg__(self) -> LinearExpression | QuadraticExpression:
def __neg__(self) -> LinearExpression:
"""
Get the negative of the expression.
"""
Expand All @@ -536,14 +538,11 @@
"""
Multiply the expr by a factor.
"""
if isinstance(other, QuadraticExpression):
return other.__rmul__(self) # type: ignore

Check warning on line 542 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L542

Added line #L542 was not covered by tests

try:
if isinstance(other, QuadraticExpression):
raise TypeError(
"unsupported operand type(s) for *: "
f"{type(self)} and {type(other)}. "
"Higher order non-linear expressions are not yet supported."
)
elif isinstance(other, (variables.Variable, variables.ScalarVariable)):
if isinstance(other, (variables.Variable, variables.ScalarVariable)):
other = other.to_linexpr()

if isinstance(other, (LinearExpression, ScalarLinearExpression)):
Expand Down Expand Up @@ -593,7 +592,7 @@
raise ValueError("Power must be 2.")
return self * self # type: ignore

def __rmul__(self, other: ConstantLike) -> LinearExpression | QuadraticExpression:
def __rmul__(self, other: ConstantLike) -> LinearExpression:
"""
Right-multiply the expr by a factor.
"""
Expand Down Expand Up @@ -1545,9 +1544,7 @@
data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM))
self._data = data

def __mul__(
self, other: ConstantLike | VariableLike | ExpressionLike
) -> QuadraticExpression:
def __mul__(self, other: ConstantLike) -> QuadraticExpression:
"""
Multiply the expr by a factor.
"""
Expand All @@ -1567,13 +1564,14 @@
)
return super().__mul__(other) # type: ignore

def __rmul__(self, other: ConstantLike) -> QuadraticExpression:
return self.__mul__(other)

@property
def type(self) -> str:
return "QuadraticExpression"

def __add__(
self, other: ConstantLike | VariableLike | ExpressionLike
) -> QuadraticExpression:
def __add__(self, other: SideLike) -> QuadraticExpression:
"""
Add an expression to others.

Expand All @@ -1592,21 +1590,13 @@
except TypeError:
return NotImplemented

def __radd__(
self, other: LinearExpression | int
) -> LinearExpression | QuadraticExpression:
def __radd__(self, other: ConstantLike) -> QuadraticExpression:
"""
Add others to expression.
"""
if type(other) is LinearExpression:
other = other.to_quadexpr()
return other.__add__(self)
elif other == 0:
return self
else:
return NotImplemented
return other.__add__(self)

def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression:
def __sub__(self, other: SideLike) -> QuadraticExpression:
"""
Subtract others from expression.

Expand All @@ -1624,15 +1614,17 @@
except TypeError:
return NotImplemented

def __rsub__(self, other: LinearExpression) -> QuadraticExpression:
def __rsub__(self, other: SideLike) -> QuadraticExpression:
"""
Subtract expression from others.
"""
if type(other) is LinearExpression:
other = other.to_quadexpr()
return other.__sub__(self)
else:
return NotImplemented
return self.__neg__().__add__(other)

def __neg__(self) -> QuadraticExpression:
"""
Get the negative of the expression.
"""
return super().__neg__() # type: ignore

@property
def solution(self) -> DataArray:
Expand Down Expand Up @@ -1875,7 +1867,7 @@
A scalar linear expression container.

In contrast to the LinearExpression class, a ScalarLinearExpression
only contains only one label. Use this class to create a constraint
only contains one label. Use this class to create a constraint
in a rule.
"""

Expand Down
50 changes: 25 additions & 25 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,24 @@
"""
return self.to_linexpr(-1)

def __mul__(
self, other: float | int | ndarray | Variable
) -> LinearExpression | QuadraticExpression:
def __mul__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
"""
Multiply variables with a coefficient.
Multiply variables with a coefficient, variable, or expression.
"""
try:
if isinstance(
other, (expressions.LinearExpression, Variable, ScalarVariable)
):
if isinstance(other, (Variable, ScalarVariable)):
return self.to_linexpr() * other

return self.to_linexpr(other)
except TypeError:
return NotImplemented

def __rmul__(self, other: ConstantLike) -> LinearExpression:
"""
Right-multiply variables by a constant
"""
return self.to_linexpr(other)

def __pow__(self, other: int) -> QuadraticExpression:
"""
Power of the variables with a coefficient. The only coefficient allowed is 2.
Expand All @@ -407,15 +409,6 @@
return expr._multiply_by_linear_expression(expr)
return NotImplemented

def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression:
"""
Right-multiply variables with a coefficient.
"""
try:
return self.to_linexpr(other)
except TypeError:
return NotImplemented

def __matmul__(
self, other: LinearExpression | ndarray | Variable
) -> QuadraticExpression | LinearExpression:
Expand Down Expand Up @@ -449,9 +442,7 @@
except TypeError:
return NotImplemented

def __add__(
self, other: int | QuadraticExpression | LinearExpression | Variable
) -> QuadraticExpression | LinearExpression:
def __add__(self, other: SideLike) -> LinearExpression:
"""
Add variables to linear expressions or other variables.
"""
Expand All @@ -460,13 +451,13 @@
except TypeError:
return NotImplemented

def __radd__(self, other: int) -> Variable | NotImplementedType:
# This is needed for using python's sum function
return self if other == 0 else NotImplemented
def __radd__(self, other: ConstantLike) -> LinearExpression:
try:
return self.__add__(other)
except ValueError:
return NotImplemented

Check warning on line 458 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L457-L458

Added lines #L457 - L458 were not covered by tests

def __sub__(
self, other: QuadraticExpression | LinearExpression | Variable
) -> QuadraticExpression | LinearExpression:
def __sub__(self, other: SideLike) -> LinearExpression:
"""
Subtract linear expressions or other variables from the variables.
"""
Expand All @@ -475,6 +466,15 @@
except TypeError:
return NotImplemented

def __rsub__(self, other: ConstantLike) -> LinearExpression:
"""
Subtract linear expressions or other variables from the variables.
"""
try:
return self.to_linexpr(-1) + other
except TypeError:
return NotImplemented

Check warning on line 476 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L473-L476

Added lines #L473 - L476 were not covered by tests

def __le__(self, other: SideLike) -> Constraint:
return self.to_linexpr().__le__(other)

Expand Down
32 changes: 32 additions & 0 deletions test/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import xarray as xr
from mypy import api

import linopy


def test_operations_with_data_arrays_are_typed_correctly() -> None:
m = linopy.Model()

a: xr.DataArray = xr.DataArray([1, 2, 3])

v: linopy.Variable = m.add_variables(lower=0.0, name="v")
e: linopy.LinearExpression = v * 1.0
q = v * v
assert isinstance(q, linopy.QuadraticExpression)

_ = a * v
_ = v * a
_ = v + a

_ = a * e
_ = e * a
_ = e + a

_ = a * q
_ = q * a
_ = q + a

# Get the path of this file
file_path = __file__
result = api.run([file_path])
assert result[2] == 0, "Mypy returned issues: " + result[0]
Loading