Skip to content

Commit d512ba3

Browse files
committed
Make record type enum
1 parent 7266306 commit d512ba3

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

modules/cli/src/main/scala/tfr/Cli.scala

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,29 @@ object Cli {
3434
given ioContextShift as ContextShift[IO] =
3535
IO.contextShift(scala.concurrent.ExecutionContext.Implicits.global)
3636

37+
object Options {
38+
enum RecordType(value: String) {
39+
case Example extends RecordType("example")
40+
case PredictionLog extends RecordType("prediction_log")
41+
}
42+
43+
given recordValueConverter as ValueConverter[RecordType] =
44+
singleArgConverter[RecordType] { s =>
45+
RecordType.valueOf(s.split("_").fold("")(_ + _.capitalize))
46+
}
47+
}
48+
3749
final class Options(arguments: Seq[String]) extends ScallopConf(arguments) {
50+
import Options.{given _, _}
3851
printedName = "tfr"
3952
banner("""Usage: tfr [options] <files? | STDIN>
4053
|TensorFlow TFRecord reader CLI tool
4154
|Options:
4255
|""".stripMargin)
4356

44-
val record: ScallopOption[String] =
45-
opt[String](
46-
default = Some("example"),
57+
val record: ScallopOption[RecordType] =
58+
opt[RecordType](
59+
default = Some(RecordType.Example),
4760
descr = "Record type to be read { example | prediction_log }"
4861
)
4962
val checkCrc32 = opt[Boolean](
@@ -55,7 +68,7 @@ object Cli {
5568
default = Some(false),
5669
descr = "Output examples as flat JSON objects"
5770
)
58-
val files =trailArg[List[String]](
71+
val files = trailArg[List[String]](
5972
required = false,
6073
descr = "files? | STDIN",
6174
default = Some(List.empty)
@@ -65,19 +78,18 @@ object Cli {
6578

6679
def main(args: Array[String]): Unit = {
6780
val options = Options(ArraySeq.unsafeWrapArray(args))
68-
println(options.files())
6981
val resources = options.files() match
7082
case Nil => Resources.stdin[IO] :: Nil
7183
case l => l.iterator.map(Resources.file[IO]).toList
7284

7385
options.record() match
74-
case "example" =>
86+
case Options.RecordType.Example =>
7587
given exampleEncoder as Encoder[Example] =
7688
if options.flat() then flat.exampleEncoder
7789
else tfr.instances.example.exampleEncoder
7890

7991
run[Example](options, resources)
80-
case "prediction_log" =>
92+
case Options.RecordType.PredictionLog =>
8193
given predictionLogEncoder as Encoder[PredictionLog] =
8294
tfr.instances.prediction.predictionLogEncoder
8395

0 commit comments

Comments
 (0)