Skip to content

fixes #24720; std lib iterators unnecessarily require value copies #24723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/semstmts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ proc semForVars(c: PContext, n: PNode; flags: TExprFlags): PNode =
# BUGFIX: don't use `iter` here as that would strip away
# the ``tyGenericInst``! See ``tests/compile/tgeneric.nim``
# for an example:
v.typ = iterBase
v.typ = makeIterTupleType(c, iterBase)
n[0] = newSymNode(v)
if sfGenSym notin v.flags and not isDiscardUnderscore(v): addDecl(c, v)
elif v.owner == nil: setOwner(v, getCurrOwner(c))
Expand Down
20 changes: 20 additions & 0 deletions compiler/semtypes.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,23 @@ proc maybeAliasType(c: PContext; typeExpr, prev: PType): PType =
else:
result = nil

proc makeIterTupleType(c: PContext; typ: PType): PType =
if typ.kind == tyTuple:
var hasView = false
result = newTypeS(tyTuple, c)

for i in 0..<typ.len:
if typ[i].kind in {tyVar, tyLent}:
hasView = true
rawAddSon(result, typ[i].skipTypes({tyVar, tyLent}))

if hasView:
result.n = typ.n
else:
result = typ
else:
result = typ

proc fixupTypeOf(c: PContext, prev: PType, typ: PType) =
if prev != nil:
let result = newTypeS(tyAlias, c)
Expand Down Expand Up @@ -2021,6 +2038,9 @@ proc semTypeOf2(c: PContext; n: PNode; prev: PType): PType =
result = base
fixupTypeOf(c, prev, result)

if result.kind == tyTuple:
result = makeIterTupleType(c, result)

proc semTypeIdent(c: PContext, n: PNode): PSym =
if n.kind == nkSym:
result = getGenSym(c, n.sym)
Expand Down
48 changes: 46 additions & 2 deletions compiler/transf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,47 @@ template assignTupleUnpacking(c: PTransf, e: PNode) =
let rhs = transform(c, newTupleAccess(c.graph, e, i))
result.add(asgnTo(lhs, rhs))

proc makeTupleUnpack(c: PTransf; lhs: PNode; rhs: PNode): PNode =
result = newNodeI(nkStmtList, lhs.info)

let temp = newSym(skTemp, getIdent(c.graph.cache, "tmpTupleAsgn"), c.idgen, getCurrOwner(c), rhs.info)
temp.typ = rhs.typ
temp.flags.incl(sfGenSym)
var v = newNodeI(nkLetSection, rhs.info)
let tempNode = newSymNode(temp) #newIdentNode(getIdent(genPrefix & $temp.id), value.info)
var vpart = newNodeI(nkIdentDefs, v.info, 3)
vpart[0] = tempNode
vpart[1] = c.graph.emptyNode
vpart[2] = rhs
v.add vpart
result.add(v)

var tupleConstr = newNodeIT(nkTupleConstr, lhs.info, lhs.typ)

for i in 0..<rhs.typ.len:
var field: PNode = nil
if rhs.typ[i].kind in {tyVar, tyLent}:
let tupleType = newTupleAccessRaw(tempNode, i)
tupleType.typ() = rhs.typ[i]
field = newDeref(tupleType)
else:
field = newTupleAccessRaw(tempNode, i)

field.typ() = rhs.typ[i].skipTypes({tyVar, tyLent})

tupleConstr.add field

result.add newAsgnStmt(c, nkFastAsgn, lhs, tupleConstr, false)
result = transform(c, result)

proc hasViewTypes(typ: PType): bool =
if typ.kind == tyTuple:
result = false
for i in 0..<typ.len:
if typ[i].kind in {tyVar, tyLent}:
return true
else:
return false

proc transformYield(c: PTransf, n: PNode): PNode =
proc asgnTo(lhs: PNode, rhs: PNode): PNode =
Expand All @@ -394,7 +435,10 @@ proc transformYield(c: PTransf, n: PNode): PNode =
case lhs.kind
of nkSym:
internalAssert c.graph.config, lhs.sym.kind == skForVar
result = newAsgnStmt(c, nkFastAsgn, lhs, rhs, false)
if rhs.typ.kind == tyTuple and hasViewTypes(rhs.typ):
result = makeTupleUnpack(c, lhs, rhs)
else:
result = newAsgnStmt(c, nkFastAsgn, lhs, rhs, false)
of nkDotExpr:
result = newAsgnStmt(c, nkAsgn, lhs, rhs, false)
else:
Expand Down Expand Up @@ -463,7 +507,7 @@ proc transformYield(c: PTransf, n: PNode): PNode =
result.add(asgnTo(lhs, rhs))
else:
let lhs = c.transCon.forStmt[0]
let rhs = transform(c, e)
let rhs = transform(c, e)
result.add(asgnTo(lhs, rhs))


Expand Down
12 changes: 6 additions & 6 deletions lib/pure/collections/tables.nim
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ template withValue*[A, B](t: Table[A, B], key: A,
discard


iterator pairs*[A, B](t: Table[A, B]): (A, B) =
iterator pairs*[A, B](t: Table[A, B]): (lent A, lent B) =
## Iterates over any `(key, value)` pair in the table `t`.
##
## See also:
Expand Down Expand Up @@ -1201,7 +1201,7 @@ proc `==`*[A, B](s, t: TableRef[A, B]): bool =



iterator pairs*[A, B](t: TableRef[A, B]): (A, B) =
iterator pairs*[A, B](t: TableRef[A, B]): (lent A, lent B) =
## Iterates over any `(key, value)` pair in the table `t`.
##
## See also:
Expand Down Expand Up @@ -1789,7 +1789,7 @@ proc `==`*[A, B](s, t: OrderedTable[A, B]): bool =



iterator pairs*[A, B](t: OrderedTable[A, B]): (A, B) =
iterator pairs*[A, B](t: OrderedTable[A, B]): (lent A, lent B) =
## Iterates over any `(key, value)` pair in the table `t` in insertion
## order.
##
Expand Down Expand Up @@ -2212,7 +2212,7 @@ proc `==`*[A, B](s, t: OrderedTableRef[A, B]): bool =



iterator pairs*[A, B](t: OrderedTableRef[A, B]): (A, B) =
iterator pairs*[A, B](t: OrderedTableRef[A, B]): (lent A, lent B) =
## Iterates over any `(key, value)` pair in the table `t` in insertion
## order.
##
Expand Down Expand Up @@ -2622,7 +2622,7 @@ proc `==`*[A](s, t: CountTable[A]): bool =
equalsImpl(s, t)


iterator pairs*[A](t: CountTable[A]): (A, int) =
iterator pairs*[A](t: CountTable[A]): (lent A, int) =
## Iterates over any `(key, value)` pair in the table `t`.
##
## See also:
Expand Down Expand Up @@ -2899,7 +2899,7 @@ proc `==`*[A](s, t: CountTableRef[A]): bool =
else: result = s[] == t[]


iterator pairs*[A](t: CountTableRef[A]): (A, int) =
iterator pairs*[A](t: CountTableRef[A]): (lent A, int) =
## Iterates over any `(key, value)` pair in the table `t`.
##
## See also:
Expand Down
20 changes: 20 additions & 0 deletions tests/arc/t24720.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
discard """
matrix: "--mm:orc"
output: '''
found entry
'''
"""

import std/tables
type NoCopies = object

proc `=copy`(a: var NoCopies, b: NoCopies) {.error.}

# bug #24720
proc foo() =
var t: Table[int, NoCopies]
t[3] = NoCopies() # only moves
for k, v in t.pairs(): # lent values, no need to copy!
echo "found entry"

foo()
Loading