Skip to content

Commit

Permalink
Support string variables and operators in ExprWrapper (#1066)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raine-Yang-UofT committed Jul 23, 2024
1 parent 2a48ce9 commit a0c2e47
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Custom checkers:
- Fixed minor typo in an error message in `python_ta/cfg/visitor.py`
- Updated `ExprWrapper` to support `set/list/tuple` literals and `in/not in` operators
- Updated `snapshot.py` and `test_snapshot.py` to align with MemoryViz 0.2.0 updates
- Updated `ExprWrapper` to support string variables and `==`, `in/not in`, indexing and slicing operators
- Added protected `_z3_vars` attribute to `ControlFlowGraph` to store variables to be used in Z3 solver
- Removed unused imports from `python_ta/cfg/graph.py`
- Extended functionality of `ExprWrapper` class to include function definitions' arguments and name assignments
Expand Down
108 changes: 101 additions & 7 deletions python_ta/transforms/ExprWrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Union

import astroid
import z3
Expand Down Expand Up @@ -67,6 +67,8 @@ def reduce(self, node: astroid.NodeNG = None) -> z3.ExprRef:
node = self.apply_name(node.name)
elif isinstance(node, (nodes.List, nodes.Tuple, nodes.Set)):
node = self.parse_container_op(node)
elif isinstance(node, nodes.Subscript):
node = self.parse_subscript_op(node)
else:
raise Z3ParseException(f"Unhandled node type {type(node)}.")

Expand All @@ -82,6 +84,7 @@ def apply_name(self, name: str) -> z3.ExprRef:
"int": z3.Int,
"float": z3.Real,
"bool": z3.Bool,
"str": z3.String,
}
if typ in type_to_z3:
x = type_to_z3[typ](name)
Expand Down Expand Up @@ -137,10 +140,10 @@ def apply_bin_op(
return left < right
elif op == ">":
return left > right
elif op == "in" and isinstance(right, list):
return z3.Or(*[left == element for element in right])
elif op == "not in" and isinstance(right, list):
return z3.And(*[left != element for element in right])
elif op == "in":
return self.apply_in_op(left, right)
elif op == "not in":
return self.apply_in_op(left, right, negate=True)
else:
raise Z3ParseException(
f"Unhandled binary operation {op} with operator types {left} and {right}."
Expand Down Expand Up @@ -184,11 +187,102 @@ def parse_bool_op(self, node: astroid.BoolOp) -> z3.ExprRef:
return self.apply_bool_op(op, values)

def parse_container_op(
self, node: Union[nodes.List, nodes.Set, nodes.Tuple]
self, node: Union[nodes.List, astroid.Set, astroid.Tuple]
) -> List[z3.ExprRef]:
"""Convert an astroid List, Set, Tuple node to a list of z3 expressions."""
return [self.reduce(element) for element in node.elts]

def apply_in_op(
self,
left: Union[z3.ExprRef, str],
right: Union[z3.ExprRef, List[z3.ExprRef], str],
negate: bool = False,
) -> z3.ExprRef:
"""
Apply `in` or `not in` operator on a list or string and return the
resulting z3 expression. Raise Z3ParseException if the operands
do not support `in` operator
"""
if isinstance(right, list): # container type (list/set/tuple)
return (
z3.And(*[left != element for element in right])
if negate
else z3.Or(*[left == element for element in right])
)
elif isinstance(left, (str, z3.SeqRef)) and isinstance(
right, (str, z3.SeqRef)
): # string literal or variable
return z3.Not(z3.Contains(right, left)) if negate else z3.Contains(right, left)
else:
op = "not in" if negate else "in"
raise Z3ParseException(
f"Unhandled binary operation {op} with operator types {left} and {right}."
)

def _parse_number_literal(self, node: astroid.NodeNG) -> Optional[Union[int, float]]:
"""
If the subtree from `node` represent a number literal, return the value
Otherwise, return None
"""
# positive number
if isinstance(node, nodes.Const) and isinstance(node.value, (int, float)):
return node.value
# negative number
elif (
isinstance(node, nodes.UnaryOp)
and node.op == "-"
and isinstance(node.operand, nodes.Const)
and isinstance(node.operand.value, (int, float))
):
return -node.operand.value
else:
return None

def parse_subscript_op(self, node: astroid.Subscript) -> z3.ExprRef:
"""
Convert an astroid Subscript node to z3 expression.
This method only supports string values and integer literal (both positive and negative) indexes
"""
value = self.reduce(node.value)
slice = node.slice

# check for invalid node type
if not z3.is_seq(value):
raise Z3ParseException(f"Unhandled subscript operand type {value}")

# handle indexing
index = self._parse_number_literal(slice)
if isinstance(index, int):
return z3.SubString(value, index, 1)

# handle slicing
if isinstance(slice, nodes.Slice):
lower = 0 if slice.lower is None else self._parse_number_literal(slice.lower)
upper = (
z3.Length(value) if slice.upper is None else self._parse_number_literal(slice.upper)
)
step = 1 if slice.step is None else self._parse_number_literal(slice.step)

if not (
isinstance(lower, int)
and isinstance(upper, (int, z3.ArithRef))
and isinstance(step, int)
):
raise Z3ParseException(f"Invalid slicing indexes {lower}, {upper}, {step}")

if step == 1:
return z3.SubString(value, lower, upper - lower)

# unhandled case: the upper bound is indeterminant
if step != 1 and upper == z3.Length(value):
raise Z3ParseException(
"Unable to convert a slicing operation with a non-unit step length and an indeterminant upper bound"
)

return z3.Concat(*(z3.SubString(value, i, 1) for i in range(lower, upper, step)))

raise Z3ParseException(f"Unhandled subscript operator type {slice}")

def parse_arguments(self, node: astroid.Arguments) -> Dict[str, z3.ExprRef]:
"""Convert an astroid Arguments node's parameters to z3 variables."""
z3_vars = {}
Expand All @@ -205,7 +299,7 @@ def parse_arguments(self, node: astroid.Arguments) -> Dict[str, z3.ExprRef]:

self.types[arg.name] = inferred.name

if arg.name in self.types and self.types[arg.name] in {"int", "float", "bool"}:
if arg.name in self.types and self.types[arg.name] in {"int", "float", "bool", "str"}:
z3_vars[arg.name] = self.reduce(arg)

return z3_vars
182 changes: 178 additions & 4 deletions tests/test_z3_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,88 @@ def not_in_empty_tuple(x: int):
""",
]

# test cases for strings expressions
string_list = [
"""
def string_equality(x: str, y: str, z: str):
'''
Preconditions:
- x == y
- z == x + y
'''
pass
""",
"""
def in_string(x: str, y: str):
'''
Preconditions:
- x in "abc"
- x in y
'''
pass
""",
"""
def not_in_string(x: str, y: str):
'''
Preconditions:
- x not in "abc"
- x not in y
'''
pass
""",
"""
def string_indexing_positive(x: str, y: str):
'''
Preconditions:
- x[0] == y
- x[1] == "a"
'''
pass
""",
"""
def string_indexing_negative(x: str, y: str):
'''
Preconditions:
- x[-1] == "b"
- x[-2] == y
'''
pass
""",
"""
def string_slicing_positive(x: str, y: str, z: str):
'''
Preconditions:
- x[1:4] == y
- x[4:5] == "a"
- x[4:] == "abc"
- x[:3] == "def"
- x[:] == z
'''
pass
""",
"""
def string_slicing_negative(x: str, y: str, z: str):
'''
Preconditions:
- x[-4:-1] == y
- x[-4:-3] == "a"
- x[-4:] == "abc"
- x[:-3] == "def"
'''
pass
"""
"""
def string_step_length(x: str, y: str):
'''
Preconditions:
- x[1:5:2] == y
- x[:5:4] == "ab"
- x[6:3:-2] == "cd"
'''
pass
""",
]


# expected arithmetic expressions
x = z3.Int("x")
Expand Down Expand Up @@ -154,9 +236,45 @@ def not_in_empty_tuple(x: int):
[z3.BoolVal(True)],
]

# expected string expressions
x = z3.String("x")
y = z3.String("y")
z = z3.String("z")
string_expected = [
[x == y, z == x + y],
[z3.Contains("abc", x), z3.Contains(y, x)],
[z3.Not(z3.Contains("abc", x)), z3.Not(z3.Contains(y, x))],
[
z3.SubString(x, 0, 1) == y,
z3.SubString(x, 1, 1) == "a",
],
[
z3.SubString(x, z3.Length(x) - 1, 1) == "b",
z3.SubString(x, z3.Length(x) - 2, 1) == y,
],
[
z3.SubString(x, 1, 3) == y,
z3.SubString(x, 4, 1) == "a",
z3.SubString(x, 4, z3.Length(x) - 4) == "abc",
z3.SubString(x, 0, 3) == "def",
x == z,
],
[
z3.SubString(x, z3.Length(x) - 4, 3) == y,
z3.SubString(x, z3.Length(x) - 4, 1) == "a",
z3.SubString(x, z3.Length(x) - 4, 4) == "abc",
z3.SubString(x, 0, z3.Length(x) - 3) == "def",
],
[
z3.Concat(z3.SubString(x, 1, 1), z3.SubString(x, 3, 1)) == y,
z3.Concat(z3.SubString(x, 0, 1), z3.SubString(x, 4, 1)) == "ab",
z3.Concat(z3.SubString(x, 6, 1), z3.SubString(x, 4, 1)) == "cd",
],
]

# lists of all test cases
code_list = [arithmetic_list, boolean_list, container_list]
expected_list = [arithmetic_expected, boolean_expected, container_expected]
code_list = [arithmetic_list, boolean_list, container_list, string_list]
expected_list = [arithmetic_expected, boolean_expected, container_expected, string_expected]


def _get_constraints_from_code(code) -> List[z3.ExprRef]:
Expand All @@ -180,6 +298,62 @@ def test_constraint(code, expected):
assert solver.check() == z3.sat


# test cases for invalid inputs
#
# Explanation on unhandled_slicing_index:
# a string slicing operation with indeterminanat upper bound
# (such as the variable's lnegth) and a step length not equal
# to 1 is currently not supported.
invalid_input_list = [
"""
def invalid_in_op_type(x: int, y: bool):
'''
Preconditions:
- x in y
'''
pass
""",
"""
def invalid_string_index(x: str, a):
'''
Preconditions:
- x[a] == "a"
'''
pass
""",
"""
def invalid_slicing_index(x: str, a, b, c):
'''
Preconditions:
- x[a:b:c] == "a"
'''
pass
""",
"""
def invalid_subscript_type(x: int):
'''
Preconditions:
- x[1] == 0
'''
pasd
""",
"""
def unhandled_slicing_index(x: str):
'''
Preconditions:
- x[::2] == "abc"
- x[::-2] == "abc"
'''
pass
""",
]


@pytest.mark.parametrize("invalid_code", invalid_input_list)
def test_invalid_input(invalid_code):
assert _get_constraints_from_code(invalid_code) == []


def test_cfg_z3_vars_initialization():
"""
Test that the cfg's z3 variable mapping is correctly initialized.
Expand All @@ -189,8 +363,8 @@ def test_cfg_z3_vars_initialization():
cfg = ControlFlowGraph()
cfg.add_arguments(node.args)

# Note that this first assert implicitly includes the assertion that 'a' not in cfg._z3_vars
assert len(cfg._z3_vars) == 3
assert len(cfg._z3_vars) == 4
assert cfg._z3_vars["x"] == z3.Int("x")
assert cfg._z3_vars["y"] == z3.Real("y")
assert cfg._z3_vars["z"] == z3.Bool("z")
assert cfg._z3_vars["a"] == z3.String("a")

0 comments on commit a0c2e47

Please sign in to comment.