@@ -4,43 +4,85 @@ import java.nio.charset.StandardCharsets
4
4
import java .time .LocalDateTime
5
5
6
6
import cats .effect .IO
7
+ import cats .implicits .catsSyntaxApplicativeId
8
+ import fs2 .Pipe
9
+
7
10
import com .intenthq .action_processor .integrations .aggregations .Aggregate
8
11
import com .intenthq .action_processor .integrations .feeds .{Feed , FeedContext }
9
12
import com .intenthq .action_processor .integrations .serializations .csv .CsvSerialization
10
- import fs2 .Pipe
13
+ import com .intenthq .embeddings .Mapping
14
+
11
15
import weaver .SimpleIOSuite
12
16
13
17
/** Spec for a feed whose rows are mapped to interests through the feed context's
  * embeddings and then aggregated per (name, timestamp, interest).
  */
object FeedAggregatedSpec extends SimpleIOSuite {

  test("should return a stream of aggregated csv feed rows") {
    // Every fixture row shares one timestamp, so it is parsed once.
    val timestamp = LocalDateTime.parse("2000-01-01T00:00:00")

    val aggregatedFeed = new PersonsAggregatedByScoreFeed(
      Person("Peter", timestamp, "music.com", 3),
      Person("Peter", timestamp, "rap.com", 5),
      Person("Peter", timestamp, "rock.com", 7),
      Person("Gabriela", timestamp, "music.com", 2),
      Person("Gabriela", timestamp, "rap.com", 10),
      // "unknown.com" has no entry in the mapping below; the expected rows
      // contain no contribution from this record.
      Person("Gabriela", timestamp, "unknown.com", 10),
      Person("Jolie", timestamp, "rap.com", 1),
      Person("Jolie", timestamp, "rock.com", 7)
    )

    // In-memory domain -> interests lookup standing in for the embeddings service.
    val mapping = new Mapping[String, List[String], IO] {
      private val map = Map(
        "music.com" -> List("Music"),
        "rap.com" -> List("Music", "Rap"),
        "rock.com" -> List("Music", "Rock")
      )
      override def get(key: String): IO[Option[List[String]]] = map.get(key).pure[IO]
    }

    // One CSV line per (name, timestamp, interest) with the summed count,
    // e.g. Peter/Music = 3 + 5 + 7. Each line is expected to end with a newline.
    val expectedResult: Set[String] = Set(
      "Peter,2000-01-01T00:00:00,Music,15",
      "Peter,2000-01-01T00:00:00,Rap,5",
      "Peter,2000-01-01T00:00:00,Rock,7",
      "Gabriela,2000-01-01T00:00:00,Music,12",
      "Gabriela,2000-01-01T00:00:00,Rap,10",
      "Jolie,2000-01-01T00:00:00,Music,8",
      "Jolie,2000-01-01T00:00:00,Rock,7",
      "Jolie,2000-01-01T00:00:00,Rap,1"
    ).map(_ + '\n')

    for {
      feedStreamLinesBytes <- aggregatedFeed.stream(TestDefaults.feedContext.copy(embeddings = Some(mapping))).compile.toList
      // Compared as a Set, so the order in which rows are emitted is not asserted.
      feedStreamLines = feedStreamLinesBytes.map(bytes => new String(bytes, StandardCharsets.UTF_8)).toSet
    } yield expect(feedStreamLines == expectedResult)
  }
}
35
57
36
/** Raw input record: a named person seen on `domain` `count` times at `timestamp`. */
final case class Person(name: String, timestamp: LocalDateTime, domain: String, count: Int)

/** A [[Person]] whose domain has been replaced by the list of interests it maps to. */
final case class MappedPerson(name: String, timestamp: LocalDateTime, interests: List[String], count: Int)

/** Aggregation key: one entry per person, timestamp and single interest. */
final case class AggregatedPerson(name: String, timestamp: LocalDateTime, interest: String)
61
+
62
/** Test feed over a fixed set of [[Person]] rows.
  *
  * Rows are first mapped to [[MappedPerson]] via the feed context's embeddings, then
  * aggregated with one key per interest and serialized as CSV.
  *
  * @param persons the rows the feed emits
  */
class PersonsAggregatedByScoreFeed(persons: Person*) extends Feed[MappedPerson, AggregatedPerson] {

  /** Emits the constructor-supplied persons, mapped through [[mapPersons]]. */
  override def inputStream(feedContext: FeedContext[IO]): fs2.Stream[IO, MappedPerson] =
    fs2
      .Stream(persons: _*)
      .through(mapPersons(feedContext))

  /** Looks up each person's domain in `feedContext.embeddings`.
    *
    * The stream fails with a `RuntimeException` when the context carries no embeddings.
    * Persons whose domain lookup yields `None` are dropped from the output.
    */
  private def mapPersons(feedContext: FeedContext[IO]): Pipe[IO, Person, MappedPerson] = { in =>
    fs2.Stream
      .eval(IO.fromOption(feedContext.embeddings)(new RuntimeException("Mapping required")))
      .flatMap { mappings =>
        in.evalMap { person =>
          mappings
            .get(person.domain)
            .map(_.map(interests => MappedPerson(person.name, person.timestamp, interests, person.count)))
        }.unNone // discard persons with no mapped interests
      }
  }

  /** Aggregates by one [[AggregatedPerson]] key per interest, using `count` as the counter.
    *
    * NOTE(review): presumably `aggregateByKeys` sums the counter per key; the totals
    * expected by the spec are consistent with that, but confirm against `Aggregate`.
    */
  override def transform(feedContext: FeedContext[IO]): Pipe[IO, MappedPerson, (AggregatedPerson, Long)] =
    Aggregate.aggregateByKeys[MappedPerson, AggregatedPerson](
      feedContext,
      person => person.interests.map(AggregatedPerson(person.name, person.timestamp, _)),
      _.count.toLong
    )

  /** Serializes the key together with its aggregated counter as a CSV row. */
  override def serialize(o: AggregatedPerson, counter: Long): Array[Byte] = CsvSerialization.serialize((o, counter))
}
0 commit comments