Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,21 @@ of the above sections.
f(memoryview(b"")) # Ok
.. option:: --disallow-str-iteration

Disallow iterating over ``str`` values.
This also rejects using ``str`` where an ``Iterable[str]`` or ``Sequence[str]`` is expected.
To iterate over characters, call ``iter`` on the string explicitly.

.. code-block:: python
s = "hello"
for ch in s: # error: Iterating over "str" is disallowed
print(ch)
for ch in iter(s): # OK
print(ch)
.. option:: --extra-checks

This flag enables additional checks that are technically correct but may be
Expand Down
8 changes: 8 additions & 0 deletions docs/source/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,14 @@ section of the command line docs.
Disable treating ``bytearray`` and ``memoryview`` as subtypes of ``bytes``.
This will be enabled by default in *mypy 2.0*.

.. confval:: disallow_str_iteration

:type: boolean
:default: False

Disallow iterating over ``str`` values.
This also rejects using ``str`` where an ``Iterable[str]`` or ``Sequence[str]`` is expected.

.. confval:: strict

:type: boolean
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi
index bd425ff3c..5dae75dd9 100644
--- a/mypy/typeshed/stdlib/builtins.pyi
+++ b/mypy/typeshed/stdlib/builtins.pyi
@@ -1458,6 +1458,8 @@ class _GetItemIterable(Protocol[_T_co]):
@overload
def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ...
@overload
+def iter(object: str, /) -> Iterator[str]: ...
+@overload
def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ...
@overload
def iter(object: Callable[[], _T | None], sentinel: None, /) -> Iterator[_T]: ...
31 changes: 29 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from mypy.checkpattern import PatternChecker
from mypy.constraints import SUPERTYPE_OF
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
from mypy.errors import (
Expand Down Expand Up @@ -513,7 +514,11 @@ def check_first_pass(self) -> None:
Deferred functions will be processed by check_second_pass().
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
with (
state.strict_optional_set(self.options.strict_optional),
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
checker_state.set(self),
):
self.errors.set_file(
self.path, self.tree.fullname, scope=self.tscope, options=self.options
)
Expand Down Expand Up @@ -558,7 +563,11 @@ def check_second_pass(
"""
self.allow_constructor_cache = allow_constructor_cache
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
with (
state.strict_optional_set(self.options.strict_optional),
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
checker_state.set(self),
):
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(
Expand Down Expand Up @@ -5378,6 +5387,12 @@ def analyze_iterable_item_type_without_expression(
echk = self.expr_checker
iterable: Type
iterable = get_proper_type(type)

if disallow_str_iteration_state.disallow_str_iteration and self.is_str_iteration_type(
iterable
):
self.msg.str_iteration_disallowed(context, iterable)

iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

if (
Expand All @@ -5390,6 +5405,18 @@ def analyze_iterable_item_type_without_expression(
iterable = echk.check_method_call_by_name("__next__", iterator, [], [], context)[0]
return iterator, iterable

def is_str_iteration_type(self, typ: Type) -> bool:
typ = get_proper_type(typ)
if isinstance(typ, LiteralType):
return isinstance(typ.value, str)
if isinstance(typ, Instance):
return is_proper_subtype(typ, self.named_type("builtins.str"))
if isinstance(typ, UnionType):
return any(self.is_str_iteration_type(item) for item in typ.relevant_items())
if isinstance(typ, TypeVarType):
return self.is_str_iteration_type(typ.upper_bound)
return False

def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
"""Try to infer native int item type from arguments to range(...).
Expand Down
25 changes: 25 additions & 0 deletions mypy/disallow_str_iteration_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Final


class DisallowStrIterationState:
# Wrap this in a class since it's faster that using a module-level attribute.

def __init__(self, disallow_str_iteration: bool) -> None:
# Value varies by file being processed
self.disallow_str_iteration = disallow_str_iteration

@contextmanager
def set(self, value: bool) -> Iterator[None]:
saved = self.disallow_str_iteration
self.disallow_str_iteration = value
try:
yield
finally:
self.disallow_str_iteration = saved


disallow_str_iteration_state: Final = DisallowStrIterationState(disallow_str_iteration=False)
8 changes: 8 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,14 @@ def add_invertible_flag(
group=strictness_group,
)

add_invertible_flag(
"--disallow-str-iteration",
default=True,
strict_flag=False,
help="Disallow iterating over str instances",
group=strictness_group,
)

add_invertible_flag(
"--extra-checks",
default=False,
Expand Down
8 changes: 8 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable

from mypy import join
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
from mypy.erasetype import erase_type
from mypy.maptype import map_instance_to_supertype
from mypy.state import state
Expand All @@ -14,6 +15,7 @@
is_proper_subtype,
is_same_type,
is_subtype,
is_subtype_relation_ignored_to_disallow_str_iteration,
)
from mypy.typeops import is_recursive_pair, make_simplified_union, tuple_fallback
from mypy.types import (
Expand Down Expand Up @@ -596,6 +598,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
if right.type.fullname == "builtins.int" and left.type.fullname in MYPYC_NATIVE_INT_NAMES:
return True

if disallow_str_iteration_state.disallow_str_iteration:
if is_subtype_relation_ignored_to_disallow_str_iteration(left, right):
return False
elif is_subtype_relation_ignored_to_disallow_str_iteration(right, left):
return False

# Two unrelated types cannot be partially overlapping: they're disjoint.
if left.type.has_base(right.type.fullname):
left = map_instance_to_supertype(left, right.type)
Expand Down
14 changes: 14 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,10 @@ def wrong_number_values_to_unpack(
def unpacking_strings_disallowed(self, context: Context) -> None:
self.fail("Unpacking a string is disallowed", context, code=codes.STR_UNPACK)

def str_iteration_disallowed(self, context: Context, str_type: Type) -> None:
self.fail(f"Iterating over {format_type(str_type, self.options)} is disallowed", context)
self.note("This is because --disallow-str-iteration is enabled", context)

def type_not_iterable(self, type: Type, context: Context) -> None:
self.fail(f"{format_type(type, self.options)} object is not iterable", context)

Expand Down Expand Up @@ -2210,6 +2214,15 @@ def report_protocol_problems(
conflict_types = get_conflict_protocol_types(
subtype, supertype, class_obj=class_obj, options=self.options
)

if subtype.type.has_base("builtins.str") and supertype.type.has_base("typing.Container"):
# `str` doesn't properly conform to the `Container` protocol, but we don't want to show that as the reason for the error.
conflict_types = [
conflict_type
for conflict_type in conflict_types
if conflict_type[0] != "__contains__"
]

if conflict_types and (
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
Expand Down Expand Up @@ -3122,6 +3135,7 @@ def get_conflict_protocol_types(
Return them as a list of ('member', 'got', 'expected', 'is_lvalue').
"""
assert right.type.is_protocol

conflicts: list[tuple[str, Type, Type, bool]] = []
for member in right.type.protocol_members:
if member in ("__init__", "__new__"):
Expand Down
4 changes: 4 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BuildType:
"disallow_any_unimported",
"disallow_incomplete_defs",
"disallow_subclassing_any",
"disallow_str_iteration",
"disallow_untyped_calls",
"disallow_untyped_decorators",
"disallow_untyped_defs",
Expand Down Expand Up @@ -238,6 +239,9 @@ def __init__(self) -> None:
# Disable treating bytearray and memoryview as subtypes of bytes
self.strict_bytes = False

# Disallow iterating over str instances or using them as Sequence[T]
self.disallow_str_iteration = True

# Deprecated, use extra_checks instead.
self.strict_concatenate = False

Expand Down
26 changes: 26 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mypy.constraints
import mypy.typeops
from mypy.checker_state import checker_state
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
from mypy.erasetype import erase_type
from mypy.expandtype import (
expand_self_type,
Expand Down Expand Up @@ -479,6 +480,13 @@ def visit_instance(self, left: Instance) -> bool:
# dynamic base classes correctly, see #5456.
return not isinstance(self.right, NoneType)
right = self.right

if (
disallow_str_iteration_state.disallow_str_iteration
and isinstance(right, Instance)
and is_subtype_relation_ignored_to_disallow_str_iteration(left, right)
):
return False
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
if isinstance(right, TupleType):
Expand Down Expand Up @@ -2311,3 +2319,21 @@ def is_erased_instance(t: Instance) -> bool:
elif not isinstance(get_proper_type(arg), AnyType):
return False
return True


def is_subtype_relation_ignored_to_disallow_str_iteration(left: Instance, right: Instance) -> bool:
return (
left.type.has_base("builtins.str")
and not right.type.has_base("builtins.str")
and any(
right.type.has_base(base)
for base in (
"collections.abc.Collection",
"collections.abc.Iterable",
"collections.abc.Sequence",
"typing.Collection",
"typing.Iterable",
"typing.Sequence",
)
)
)
2 changes: 2 additions & 0 deletions mypy/typeshed/stdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,8 @@ class _GetItemIterable(Protocol[_T_co]):
@overload
def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ...
@overload
def iter(object: str, /) -> Iterator[str]: ...
@overload
def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ...
@overload
def iter(object: Callable[[], _T | None], sentinel: None, /) -> Iterator[_T]: ...
Expand Down
74 changes: 74 additions & 0 deletions test-data/unit/check-flags.test
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,80 @@ f(bytearray(b"asdf")) # E: Argument 1 to "f" has incompatible type "bytearray";
f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview"; expected "bytes"
[builtins fixtures/primitives.pyi]

[case testDisallowStrIteration]
# flags: --disallow-str-iteration
from abc import abstractmethod
from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, TypeVar, Union

def takes_str(x: str):
for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled
reveal_type(ch) # N: Revealed type is "builtins.str"
[ch for ch in x] # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled

s = "hello"

def takes_seq_str(x: Sequence[str]) -> None: ...
takes_seq_str(s) # E: Argument 1 to "takes_seq_str" has incompatible type "str"; expected "Sequence[str]"

def takes_iter_str(x: Iterable[str]) -> None: ...
takes_iter_str(s) # E: Argument 1 to "takes_iter_str" has incompatible type "str"; expected "Iterable[str]"

def takes_collection_str(x: Collection[str]) -> None: ...
takes_collection_str(s) # E: Argument 1 to "takes_collection_str" has incompatible type "str"; expected "Collection[str]"

seq: Sequence[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Sequence[str]")
iterable: Iterable[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Iterable[str]")
collection: Collection[str] = s # E: Incompatible types in assignment (expression has type "str", variable has type "Collection[str]")

def takes_maybe_seq(x: "str | Sequence[int]") -> None:
for ch in x: # E: Iterating over "str | Sequence[int]" is disallowed # N: This is because --disallow-str-iteration is enabled
reveal_type(ch) # N: Revealed type is "builtins.str | builtins.int"

T = TypeVar('T', bound=str)
_T_co = TypeVar('_T_co', covariant=True)

def takes_str_upper_bound(x: T) -> None:
for ch in x: # E: Iterating over "T" is disallowed # N: This is because --disallow-str-iteration is enabled
reveal_type(ch) # N: Revealed type is "builtins.str"

class StrSubclass(str):
def __contains__(self, x: object) -> bool: ...

def takes_str_subclass(x: StrSubclass):
for ch in x: # E: Iterating over "StrSubclass" is disallowed # N: This is because --disallow-str-iteration is enabled
reveal_type(ch) # N: Revealed type is "builtins.str"

class CollectionSubclass(Collection[_T_co], Protocol[_T_co]):
@abstractmethod
def __missing_impl__(self): ...

def takes_collection_subclass(x: CollectionSubclass[str]) -> None: ...

takes_collection_subclass(s) # E: Argument 1 to "takes_collection_subclass" has incompatible type "str"; expected "CollectionSubclass[str]" \
# N: "str" is missing following "CollectionSubclass" protocol member: \
# N: __missing_impl__

takes_collection_subclass(StrSubclass()) # E: Argument 1 to "takes_collection_subclass" has incompatible type "StrSubclass"; expected "CollectionSubclass[str]" \
# N: "StrSubclass" is missing following "CollectionSubclass" protocol member: \
# N: __missing_impl__

def dict_unpacking_unaffected_by_union_simplification(x: Mapping[str, Union[str, Sequence[str]]]) -> None:
x = {**x}

def narrowing(x: "str | Sequence[str]"):
if isinstance(x, str):
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "typing.Sequence[builtins.str]"

[builtins fixtures/str-iter.pyi]
[typing fixtures/typing-str-iter.pyi]

[case testIterStrOverload]
# flags: --disallow-str-iteration
reveal_type(iter("foo")) # N: Revealed type is "typing.Iterable[builtins.str]"
[builtins fixtures/dict.pyi]

[case testNoCrashFollowImportsForStubs]
# flags: --config-file tmp/mypy.ini
{**{"x": "y"}}
Expand Down
3 changes: 3 additions & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,7 @@ class ellipsis: pass
class BaseException: pass

def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass
@overload
def iter(__iterable: str) -> Iterable[str]: pass
@overload
def iter(__iterable: Iterable[T]) -> Iterator[T]: pass
Loading
Loading