Skip to content

Commit

Permalink
[red-knot] Infer target types for unpacked tuple assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Sep 10, 2024
1 parent b7cef6c commit c66a85d
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 20 deletions.
42 changes: 31 additions & 11 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::semantic_index::SemanticIndex;
use crate::Db;

use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
use super::definition::{AssignmentKind, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};

pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
Expand Down Expand Up @@ -517,8 +517,17 @@ where
debug_assert!(self.current_assignment.is_none());
self.visit_expr(&node.value);
self.add_standalone_expression(&node.value);
self.current_assignment = Some(node.into());
for target in &node.targets {
let kind = match target {
ast::Expr::Name(_) => AssignmentKind::Name,
ast::Expr::List(_) | ast::Expr::Tuple(_) => AssignmentKind::Sequence(0),
ast::Expr::Starred(_) => AssignmentKind::Starred,
ast::Expr::Attribute(_) => AssignmentKind::Attribute,
ast::Expr::Subscript(_) => AssignmentKind::Subscript,
// TODO: is this a good default for an error recovery case like `1 = 2`?
_ => continue,
};
self.current_assignment = Some(CurrentAssignment::Assign { node, kind });
self.visit_expr(target);
}
self.current_assignment = None;
Expand Down Expand Up @@ -699,12 +708,13 @@ where
let symbol = self.add_or_update_symbol(id.clone(), flags);
if flags.contains(SymbolFlags::IS_DEFINED) {
match self.current_assignment {
Some(CurrentAssignment::Assign(assignment)) => {
Some(CurrentAssignment::Assign { node, kind }) => {
self.add_definition(
symbol,
AssignmentDefinitionNodeRef {
assignment,
assignment: node,
target: name_node,
kind,
},
);
}
Expand Down Expand Up @@ -851,6 +861,19 @@ where
self.visit_expr(key);
self.visit_expr(value);
}
ast::Expr::Tuple(ast::ExprTuple { elts, ctx, .. }) => {
for (index, element) in elts.iter().enumerate() {
if let Some(CurrentAssignment::Assign {
kind: AssignmentKind::Sequence(target_index),
..
}) = self.current_assignment.as_mut()
{
*target_index = index;
}
self.visit_expr(element);
}
self.visit_expr_context(ctx);
}
_ => {
walk_expr(self, expr);
}
Expand Down Expand Up @@ -957,7 +980,10 @@ where

#[derive(Copy, Clone, Debug)]
enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign),
Assign {
node: &'a ast::StmtAssign,
kind: AssignmentKind,
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
Expand All @@ -969,12 +995,6 @@ enum CurrentAssignment<'a> {
WithItem(&'a ast::WithItem),
}

impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAssign) -> Self {
Self::Assign(value)
}
}

impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAnnAssign) -> Self {
Self::AnnAssign(value)
Expand Down
32 changes: 26 additions & 6 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) assignment: &'a ast::StmtAssign,
pub(crate) target: &'a ast::ExprName,
pub(crate) kind: AssignmentKind,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -203,12 +204,15 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => {
DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef {
assignment,
target,
kind,
}) => DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
kind,
}),
DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
}
Expand Down Expand Up @@ -276,6 +280,7 @@ impl DefinitionNodeRef<'_> {
Self::Assignment(AssignmentDefinitionNodeRef {
assignment: _,
target,
kind: _,
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Expand Down Expand Up @@ -381,6 +386,7 @@ impl ImportFromDefinitionKind {
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target: AstNodeRef<ast::ExprName>,
kind: AssignmentKind,
}

impl AssignmentDefinitionKind {
Expand All @@ -391,6 +397,20 @@ impl AssignmentDefinitionKind {
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn kind(&self) -> AssignmentKind {
self.kind
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AssignmentKind {
Attribute,
Subscript,
Starred,
Name,
/// list or tuple with an index into the list of targets.
Sequence(usize),
}

#[derive(Clone, Debug)]
Expand Down
13 changes: 13 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ impl<'db> Type<'db> {
matches!(self, Type::Never)
}

pub const fn as_tuple_type(&self) -> Option<&TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
_ => None,
}
}

pub const fn into_class_type(self) -> Option<ClassType<'db>> {
match self {
Type::Class(class_type) => Some(class_type),
Expand Down Expand Up @@ -672,3 +679,9 @@ pub struct TupleType<'db> {
#[return_ref]
elements: Box<[Type<'db>]>,
}

impl<'db> TupleType<'db> {
pub fn get(&self, db: &'db dyn Db, index: usize) -> Option<&Type<'db>> {
self.elements(db).get(index)
}
}
40 changes: 37 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ use ruff_text_size::Ranged;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
use crate::semantic_index::definition::{
AssignmentKind, Definition, DefinitionKind, DefinitionNodeKey,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
Expand Down Expand Up @@ -380,6 +382,7 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(
assignment.target(),
assignment.kind(),
assignment.assignment(),
definition,
);
Expand Down Expand Up @@ -957,19 +960,34 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_assignment_definition(
&mut self,
target: &ast::ExprName,
kind: AssignmentKind,
assignment: &ast::StmtAssign,
definition: Definition<'db>,
) {
let expression = self.index.expression(assignment.value.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);

let value_ty = self
.types
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));

let target_ty = match (value_ty, kind) {
(Type::Tuple(tuple_type), AssignmentKind::Sequence(target_index)) => {
// TODO: when does this happen?
tuple_type
.get(self.db, target_index)
.copied()
.unwrap_or(Type::Unknown)
}
_ => value_ty,
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), value_ty);
self.types.definitions.insert(definition, value_ty);
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);

self.types.definitions.insert(definition, target_ty);
}

fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
Expand Down Expand Up @@ -4057,6 +4075,22 @@ mod tests {
Ok(())
}

#[test]
fn unpacked_tuple_assignment() {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
x, y = 1, 2
",
)
.unwrap();

assert_public_ty(&db, "/src/a.py", "x", "Literal[1]");
assert_public_ty(&db, "/src/a.py", "y", "Literal[2]");
}

#[test]
fn list_literal() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit c66a85d

Please sign in to comment.