Skip to content

Commit 12d8b7d

Browse files
author
Nicolas Laurent
committed
[GR-26883] Rewrite Array#recursively_flatten to not use Thread#detect_recursion and reduce splitting.
PullRequest: truffleruby/2054
2 parents 62d491b + 0984608 commit 12d8b7d

File tree

6 files changed

+191
-40
lines changed

6 files changed

+191
-40
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved. This
2+
# code is released under a tri EPL/GPL/LGPL license. You can use it,
3+
# redistribute it and/or modify it under the terms of the:
4+
#
5+
# Eclipse Public License version 2.0, or
6+
# GNU General Public License version 2, or
7+
# GNU Lesser General Public License version 2.1.
8+
9+
array = []
10+
current_array = array
11+
100.times do
12+
next_array = []
13+
current_array.append(0, 1, next_array, 3, 4)
14+
current_array = next_array
15+
end
16+
17+
benchmark 'core-array-flatten-recursive' do
18+
array.flatten
19+
end

bench/micro/array/flatten-simple.rb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved. This
2+
# code is released under a tri EPL/GPL/LGPL license. You can use it,
3+
# redistribute it and/or modify it under the terms of the:
4+
#
5+
# Eclipse Public License version 2.0, or
6+
# GNU General Public License version 2, or
7+
# GNU Lesser General Public License version 2.1.
8+
9+
array = [ [[1, 2], [3,4]], [[5,6], [7,8]] ]
10+
11+
benchmark 'core-array-flatten-simple' do
12+
array.flatten
13+
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) 2020 Oracle and/or its affiliates. All rights reserved. This
3+
* code is released under a tri EPL/GPL/LGPL license. You can use it,
4+
* redistribute it and/or modify it under the terms of the:
5+
*
6+
* Eclipse Public License version 2.0, or
7+
* GNU General Public License version 2, or
8+
* GNU Lesser General Public License version 2.1.
9+
*/
10+
package org.truffleruby.collections;
11+
12+
import org.truffleruby.core.array.ArrayUtils;
13+
14+
/** Simplistic array stack implementation that will partial-evaluate nicely, unlike {@link java.util.ArrayDeque}. */
15+
@SuppressWarnings("unchecked")
16+
public class SimpleStack<T> {
17+
18+
Object[] storage;
19+
int index = -1;
20+
21+
public SimpleStack() {
22+
this(16);
23+
}
24+
25+
public SimpleStack(int length) {
26+
this.storage = new Object[length];
27+
}
28+
29+
public boolean isEmpty() {
30+
return index == -1;
31+
}
32+
33+
public void push(T value) {
34+
if (++index == storage.length) {
35+
storage = ArrayUtils.grow(storage, storage.length * 2);
36+
}
37+
storage[index] = value;
38+
}
39+
40+
public T peek() {
41+
return (T) storage[index];
42+
}
43+
44+
public T pop() {
45+
T out = (T) storage[index];
46+
storage[index] = null;
47+
--index;
48+
return out;
49+
}
50+
51+
public int size() {
52+
return index + 1;
53+
}
54+
}

src/main/java/org/truffleruby/core/array/ArrayNodes.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import com.oracle.truffle.api.dsl.CachedLanguage;
1919
import com.oracle.truffle.api.profiles.LoopConditionProfile;
20+
import org.graalvm.collections.EconomicSet;
21+
import org.graalvm.collections.Equivalence;
2022
import org.truffleruby.Layouts;
2123
import org.truffleruby.RubyLanguage;
2224
import org.truffleruby.builtins.CoreMethod;
@@ -26,6 +28,7 @@
2628
import org.truffleruby.builtins.Primitive;
2729
import org.truffleruby.builtins.PrimitiveArrayArgumentsNode;
2830
import org.truffleruby.builtins.YieldingCoreMethodNode;
31+
import org.truffleruby.collections.SimpleStack;
2932
import org.truffleruby.core.CoreLibrary;
3033
import org.truffleruby.core.Hashing;
3134
import org.truffleruby.core.array.ArrayBuilderNode.BuilderState;
@@ -2338,4 +2341,98 @@ protected boolean isStoreNative(RubyArray array,
23382341
return stores.isNative(array.store);
23392342
}
23402343
}
2344+
2345+
@Primitive(name = "array_flatten_helper", lowerFixnum = 2)
2346+
public abstract static class FlattenHelperNode extends PrimitiveArrayArgumentsNode {
2347+
2348+
@Specialization(guards = "!canContainObject.execute(array)")
2349+
protected boolean flattenHelperPrimitive(RubyArray array, RubyArray out, int maxLevels,
2350+
@Cached ArrayAppendManyNode concat,
2351+
@Cached TypeNodes.CanContainObjectNode canContainObject) {
2352+
concat.executeAppendMany(out, array);
2353+
return false;
2354+
}
2355+
2356+
@Specialization(replaces = "flattenHelperPrimitive")
2357+
protected boolean flattenHelper(RubyArray array, RubyArray out, int maxLevels,
2358+
@CachedLanguage RubyLanguage language,
2359+
@Cached TypeNodes.CanContainObjectNode canContainObject,
2360+
@Cached ArrayAppendManyNode concat,
2361+
@Cached AtNode at,
2362+
@Cached DispatchNode convert,
2363+
@Cached ArrayAppendOneNode append) {
2364+
2365+
boolean modified = false;
2366+
final EconomicSet<RubyArray> visited = EconomicSet.create(Equivalence.IDENTITY);
2367+
class Entry {
2368+
final RubyArray array;
2369+
final int index;
2370+
2371+
Entry(RubyArray array, int index) {
2372+
this.array = array;
2373+
this.index = index;
2374+
}
2375+
}
2376+
final SimpleStack<Entry> workStack = new SimpleStack<>();
2377+
workStack.push(new Entry(array, 0));
2378+
2379+
while (!workStack.isEmpty()) {
2380+
final Entry e = workStack.pop();
2381+
2382+
if (e.index == 0) {
2383+
if (!canContainObject.execute(e.array)) {
2384+
concat.executeAppendMany(out, e.array);
2385+
continue;
2386+
} else if (contains(visited, e.array)) {
2387+
throw new RaiseException(
2388+
getContext(),
2389+
coreExceptions().argumentError("tried to flatten recursive array", this));
2390+
} else if (maxLevels == workStack.size()) {
2391+
concat.executeAppendMany(out, e.array);
2392+
continue;
2393+
}
2394+
add(visited, e.array);
2395+
}
2396+
2397+
int i = e.index;
2398+
for (; i < e.array.size; ++i) {
2399+
final Object obj = at.executeAt(e.array, i);
2400+
final Object converted = convert.call(
2401+
coreLibrary().truffleTypeModule,
2402+
"rb_check_convert_type",
2403+
obj,
2404+
coreLibrary().arrayClass,
2405+
language.coreSymbols.TO_ARY);
2406+
if (converted == nil) {
2407+
append.executeAppendOne(out, obj);
2408+
} else {
2409+
modified = true;
2410+
workStack.push(new Entry(e.array, i + 1));
2411+
workStack.push(new Entry((RubyArray) converted, 0));
2412+
break;
2413+
}
2414+
}
2415+
if (i == e.array.size) {
2416+
remove(visited, e.array);
2417+
}
2418+
}
2419+
2420+
return modified;
2421+
}
2422+
2423+
@TruffleBoundary
2424+
private static boolean contains(EconomicSet<RubyArray> set, RubyArray array) {
2425+
return set.contains(array);
2426+
}
2427+
2428+
@TruffleBoundary
2429+
private static void remove(EconomicSet<RubyArray> set, RubyArray array) {
2430+
set.remove(array);
2431+
}
2432+
2433+
@TruffleBoundary
2434+
private static void add(EconomicSet<RubyArray> set, RubyArray array) {
2435+
set.add(array);
2436+
}
2437+
}
23412438
}

src/main/java/org/truffleruby/core/support/TypeNodes.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ protected Object objectHiddenVarSet(RubyDynamicObject object, Object identifier,
261261
@ImportStatic(ArrayGuards.class)
262262
public abstract static class CanContainObjectNode extends PrimitiveArrayArgumentsNode {
263263

264+
public static CanContainObjectNode create() {
265+
return TypeNodesFactory.CanContainObjectNodeFactory.create(null);
266+
}
267+
268+
abstract public boolean execute(RubyArray array);
269+
264270
@Specialization(
265271
guards = {
266272
"stores.accepts(array.store)",

src/main/ruby/truffleruby/core/array.rb

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def flatten(level=-1)
444444
return self.dup if level == 0
445445

446446
out = self.class.allocate # new_reserved size
447-
recursively_flatten(self, out, level)
447+
Primitive.array_flatten_helper(self, out, level)
448448
Primitive.infect(out, self)
449449
out
450450
end
@@ -456,7 +456,7 @@ def flatten!(level=-1)
456456
return nil if level == 0
457457

458458
out = self.class.allocate # new_reserved size
459-
if recursively_flatten(self, out, level)
459+
if Primitive.array_flatten_helper(self, out, level)
460460
Primitive.steal_array_storage(self, out)
461461
return self
462462
end
@@ -1251,44 +1251,6 @@ def zip(*others)
12511251
end
12521252
end
12531253

1254-
# Helper to recurse through flattening since the method
1255-
# is not allowed to recurse itself. Detects recursive structures.
1256-
def recursively_flatten(array, out, max_levels = -1)
1257-
modified = false
1258-
1259-
# Strict equality since < 0 means 'infinite'
1260-
if max_levels == 0
1261-
out.concat(array)
1262-
return false
1263-
end
1264-
1265-
max_levels -= 1
1266-
recursion = Truffle::ThreadOperations.detect_recursion(array) do
1267-
array = Truffle::Type.coerce_to(array, Array, :to_ary)
1268-
1269-
i = 0
1270-
size = array.size
1271-
1272-
while i < size
1273-
o = array.at i
1274-
1275-
tmp = Truffle::Type.rb_check_convert_type(o, Array, :to_ary)
1276-
if Primitive.nil? tmp
1277-
out << o
1278-
else
1279-
modified = true
1280-
recursively_flatten tmp, out, max_levels
1281-
end
1282-
1283-
i += 1
1284-
end
1285-
end
1286-
1287-
raise ArgumentError, 'tried to flatten recursive array' if recursion
1288-
modified
1289-
end
1290-
private :recursively_flatten
1291-
12921254
private def sort_fallback(&block)
12931255
# Use this instead of #dup as we want an instance of Array
12941256
sorted = Array.new(self)

0 commit comments

Comments
 (0)