(Java) Using Compression.GZIP with TFRecordIO


      In short, `TFRecordIO.read()` with `Compression.GZIP` does not seem to work when a record being read is longer than about 8,192 bytes (`byte[]` length). `TFRecordIO.write()` seems to handle such records fine (based on some experiments). Perhaps there is a hard-coded limit of this size somewhere in the SDK; if so, I'm wondering whether it can be increased or made configurable.

      I've posted this on StackOverflow, but I was advised to report it here.

      Here are the details:

      We use the Beam Java SDK heavily (running batch jobs on Google Cloud Dataflow), and we noticed unexpected behavior (possibly a bug) when we tried to use `TFRecordIO` with `Compression.GZIP`. We were able to write sample code that reproduces the errors we see.

      To be clear, we are using Beam Java SDK 2.4.

      Suppose we have a `PCollection<byte[]>`, for instance a PCollection of serialized proto messages.
      We usually write this to GCS (Google Cloud Storage) either as newline-delimited Base64 strings or with `TFRecordIO` (without compression). We have had no issues reading the data back from GCS this way for a long time (2.5+ years for the former and ~1.5 years for the latter).
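For comparison, the newline-delimited Base64 format can be sketched with the JDK's `java.util.Base64` (the job at the end of this report uses Commons Codec's `Base64` instead). `Base64Lines`, `encode` and `decode` are hypothetical names. A newline is a safe delimiter because the standard Base64 alphabet never contains one.

```java
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

// Hypothetical helper: one Base64-encoded record per line.
public class Base64Lines {
  // Encodes each record on its own line (standard alphabet, no line wrapping).
  static String encode(List<byte[]> records) {
    StringBuilder sb = new StringBuilder();
    for (byte[] r : records) {
      sb.append(Base64.getEncoder().encodeToString(r)).append('\n');
    }
    return sb.toString();
  }

  // Decodes every non-empty line back into a record.
  // Note: an empty record encodes to an empty line and is therefore skipped.
  static List<byte[]> decode(String text) {
    List<byte[]> out = new ArrayList<>();
    for (String line : text.split("\n")) {
      if (!line.isEmpty()) {
        out.add(Base64.getDecoder().decode(line));
      }
    }
    return out;
  }
}
```

Record size plays no role in this format, since each record is just one (possibly long) line of text.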

      Recently, we tried `TFRecordIO` with the `Compression.GZIP` option, and we sometimes get an exception during reading because the data is judged invalid. The gzip files themselves are not corrupted; after testing various things, we reached the following conclusion.
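One way to check the gzip files independently of Beam is to decode them by hand. The sketch below assumes the standard TFRecord framing documented by TensorFlow (little-endian `uint64` length, masked CRC32C of the length, payload, masked CRC32C of the payload). It is not Beam's `TFRecordCodec`; `TFRecordGzipCheck` and its methods are hypothetical names, and `java.util.zip.CRC32C` requires Java 9 or later.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.CRC32C;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical checker for gzipped TFRecord data (assumes TensorFlow's framing).
public class TFRecordGzipCheck {
  // CRC32C of b[off, off+len), masked as in TFRecord:
  // rotate right by 15 bits, then add 0xa282ead8 (32-bit wraparound).
  static int maskedCrc(byte[] b, int off, int len) {
    CRC32C crc = new CRC32C();
    crc.update(b, off, len);
    int c = (int) crc.getValue();
    return ((c >>> 15) | (c << 17)) + 0xa282ead8;
  }

  // Writes one frame: u64 length, masked CRC of length, payload, masked CRC of payload.
  static void writeRecord(OutputStream out, byte[] data) throws IOException {
    ByteBuffer header = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN);
    header.putLong(data.length);
    header.putInt(maskedCrc(header.array(), 0, 8));
    out.write(header.array());
    out.write(data);
    ByteBuffer footer = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
    footer.putInt(maskedCrc(data, 0, data.length));
    out.write(footer.array());
  }

  // Writes all records into one in-memory gzip stream.
  static byte[] gzipRecords(List<byte[]> records) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (GZIPOutputStream gz = new GZIPOutputStream(bos)) {
      for (byte[] r : records) {
        writeRecord(gz, r);
      }
    }
    return bos.toByteArray();
  }

  // Decodes every frame of a gzipped TFRecord file, verifying both CRCs.
  // DataInputStream.readFully keeps reading until all requested bytes arrive,
  // so a record may be larger than any internal buffer.
  static List<byte[]> readGzipRecords(byte[] gzBytes) throws IOException {
    List<byte[]> records = new ArrayList<>();
    try (DataInputStream in =
        new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(gzBytes)))) {
      byte[] header = new byte[12];
      int first;
      while ((first = in.read()) >= 0) { // -1 here is a clean end of data
        header[0] = (byte) first;
        in.readFully(header, 1, 11);
        ByteBuffer hb = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
        if (hb.getInt(8) != maskedCrc(header, 0, 8)) {
          throw new IOException("length CRC mismatch");
        }
        byte[] data = new byte[Math.toIntExact(hb.getLong(0))];
        in.readFully(data);
        byte[] footer = new byte[4];
        in.readFully(footer);
        int dataCrc = ByteBuffer.wrap(footer).order(ByteOrder.LITTLE_ENDIAN).getInt();
        if (dataCrc != maskedCrc(data, 0, data.length)) {
          throw new IOException("data CRC mismatch");
        }
        records.add(data);
      }
    }
    return records;
  }
}
```

To check a shard copied locally from GCS, pass `java.nio.file.Files.readAllBytes(path)` to `readGzipRecords`; if it returns every record without a CRC mismatch, the file itself is intact.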

      When a `byte[]` written through `TFRecordIO` with GZIP is above a certain size (somewhere between 8,126 and 8,329 bytes in our tests, which suggests 8,192), `TFRecordIO.read().withCompression(Compression.GZIP)` fails.
      Specifically, it throws the following exception:


      // code placeholder
      Exception in thread "main" java.lang.IllegalStateException: Invalid data
      at org.apache.beam.sdk.repackaged.com.google.common.base.Preconditions.checkState(Preconditions.java:444)
      at org.apache.beam.sdk.io.TFRecordIO$TFRecordCodec.read(TFRecordIO.java:642)
      at org.apache.beam.sdk.io.TFRecordIO$TFRecordSource$TFRecordReader.readNextRecord(TFRecordIO.java:526)
      at org.apache.beam.sdk.io.CompressedSource$CompressedReader.readNextRecord(CompressedSource.java:426)
      at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.advanceImpl(FileBasedSource.java:473)
      at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.startImpl(FileBasedSource.java:468)
      at org.apache.beam.sdk.io.OffsetBasedSource$OffsetBasedReader.start(OffsetBasedSource.java:261)
      at org.apache.beam.runners.direct.BoundedReadEvaluatorFactory$BoundedReadEvaluator.processElement(BoundedReadEvaluatorFactory.java:141)
      at org.apache.beam.runners.direct.DirectTransformExecutor.processElements(DirectTransformExecutor.java:161)
      at org.apache.beam.runners.direct.DirectTransformExecutor.run(DirectTransformExecutor.java:125)
      at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
      at java.util.concurrent.FutureTask.run(FutureTask.java:266)
      at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
      at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
      at java.lang.Thread.run(Thread.java:748)


      This is easy to reproduce; see the code at the end. It also contains comments about the byte array length (after testing various sizes, I concluded that 8,192 is the magic number).
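This report does not establish the root cause. As an assumption only: one way a reader can break at a size like 8,192 bytes is by issuing a single `read()` on a `ReadableByteChannel` and treating a short result as invalid data. The channel contract allows partial reads, and OpenJDK's `Channels.newChannel(InputStream)` adapter copies at most 8,192 bytes per underlying `InputStream.read()` call. The sketch contrasts one `read()` call with a loop that reads until the buffer is full; `ChannelReadFully` and its methods are hypothetical names, not Beam API.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical helper showing partial reads versus a read-until-full loop.
public class ChannelReadFully {
  // ReadableByteChannel.read() may transfer fewer bytes than dst.remaining().
  // Loop until dst is full. Assumes a blocking channel, so each read()
  // transfers at least one byte or returns -1 at end of stream.
  static void readFully(ReadableByteChannel ch, ByteBuffer dst) throws IOException {
    while (dst.hasRemaining()) {
      if (ch.read(dst) < 0) {
        throw new EOFException(dst.remaining() + " bytes missing at end of stream");
      }
    }
  }

  // Returns a channel yielding the decompressed contents of a gzipped copy of data.
  static ReadableByteChannel gzipChannel(byte[] data) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (GZIPOutputStream gz = new GZIPOutputStream(bos)) {
      gz.write(data);
    }
    return Channels.newChannel(new GZIPInputStream(new ByteArrayInputStream(bos.toByteArray())));
  }

  // Bytes delivered by a single read() call; may be less than data.length.
  static int singleReadCount(byte[] data) throws IOException {
    try (ReadableByteChannel ch = gzipChannel(data)) {
      return ch.read(ByteBuffer.allocate(data.length));
    }
  }

  // Reads the whole payload back using the looping helper.
  static byte[] readAllViaLoop(byte[] data) throws IOException {
    ByteBuffer dst = ByteBuffer.allocate(data.length);
    try (ReadableByteChannel ch = gzipChannel(data)) {
      readFully(ch, dst);
    }
    return dst.array();
  }
}
```

How many bytes a single call returns depends on the stream implementation, so only the looping version is guaranteed to return all of an 8,329-byte record.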

      So I'm wondering whether this is a bug or a known issue. I couldn't find anything similar on Apache Beam's issue tracker [here][1], but if there is another forum or site I should check, please let me know!
      If this is indeed a bug, what would be the right channel to report it?

      The following code can reproduce the error we have.

      A successful run (with parameters 1, 39, 100 for `NUM_PROTOS`, `NUM_STRINGS`, `STRING_LEN`) prints the following at the end:

      // code placeholder
      ------------ counter metrics from CountDoFn
      [counter] plain_base64_proto_array_len: 8126
      [counter] plain_base64_proto_in: 1
      [counter] plain_base64_proto_val_cnt: 39
      [counter] tfrecord_gz_proto_array_len: 8126
      [counter] tfrecord_gz_proto_in: 1
      [counter] tfrecord_gz_proto_val_cnt: 39
      [counter] tfrecord_uncomp_proto_array_len: 8126
      [counter] tfrecord_uncomp_proto_in: 1
      [counter] tfrecord_uncomp_proto_val_cnt: 39


      With parameters (1, 40, 100), which push the byte array length to 8,329 (over 8,192), it throws the exception shown above.
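The jump across 8,192 can be checked by hand, as a sketch under two assumptions. First, each string has 2 × `STRING_LEN` = 200 characters, because `sb.append('A' + (rd.nextInt(26)))` in the code at the end appends an `int` (a two-digit character code) per iteration. Second, every string field uses a one-byte tag. Each string field then costs 1 (tag) + 2 (varint length of 200) + 200 = 203 bytes. The name plus 39 values is therefore 40 × 203 = 8,120 bytes, and the name plus 40 values is 41 × 203 = 8,323 bytes. The reported 8,126 and 8,329 are both 6 bytes larger, which is consistent with a `count` value being set (1 tag byte plus a 5-byte varint); that setter does not appear in the posted code, so this part is an assumption. `ProtoSizeEstimate` is a hypothetical helper.

```java
// Hypothetical estimator for the serialized size of TestProto (count field excluded).
public class ProtoSizeEstimate {
  // Bytes needed to encode v as a protobuf base-128 varint.
  static int varintSize(long v) {
    int n = 1;
    while ((v & ~0x7FL) != 0) {
      v >>>= 7;
      n++;
    }
    return n;
  }

  // One length-delimited field with a 1-byte tag holding an ASCII string of len chars.
  static int stringFieldSize(int len) {
    return 1 + varintSize(len) + len;
  }

  // 'name' plus numStrings 'values', each charsPerString characters long.
  static int estimate(int numStrings, int charsPerString) {
    return (1 + numStrings) * stringFieldSize(charsPerString);
  }
}
```

With `estimate(39, 200)` and `estimate(40, 200)` the two configurations land at 8,120 and 8,323 bytes, on either side of 8,192.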

      You can tweak the parameters (inside the `CreateRandomProtoData` DoFn) to see how the length of the gzipped `byte[]` matters.
      You may also find it helpful to use the protoc-generated Java class for `TestProto`, which is used in the main code below: [gist link][2]

      [1]: https://issues.apache.org/jira/projects/BEAM/issues/
      [2]: https://gist.github.com/hadenlee/ae127715837bd56f3bc6ba4fe2ccb176

      Main Code:

      (Note that the sample code writes to and reads from GCS (Google Cloud Storage), but as far as I tested, the problem has nothing to do with the storage.)

      // code placeholder
      package exp.moloco.dataflow2.compression; // NOTE: Change appropriately.
      import java.util.Arrays;
      import java.util.List;
      import java.util.Map;
      import java.util.Map.Entry;
      import java.util.Random;
      import java.util.TreeMap;
      import org.apache.beam.runners.direct.DirectRunner;
      import org.apache.beam.sdk.Pipeline;
      import org.apache.beam.sdk.PipelineResult;
      import org.apache.beam.sdk.io.Compression;
      import org.apache.beam.sdk.io.TFRecordIO;
      import org.apache.beam.sdk.io.TextIO;
      import org.apache.beam.sdk.metrics.Counter;
      import org.apache.beam.sdk.metrics.MetricResult;
      import org.apache.beam.sdk.metrics.Metrics;
      import org.apache.beam.sdk.metrics.MetricsFilter;
      import org.apache.beam.sdk.options.PipelineOptions;
      import org.apache.beam.sdk.options.PipelineOptionsFactory;
      import org.apache.beam.sdk.transforms.Create;
      import org.apache.beam.sdk.transforms.DoFn;
      import org.apache.beam.sdk.transforms.ParDo;
      import org.apache.beam.sdk.values.PCollection;
      import org.apache.commons.codec.binary.Base64;
      import org.joda.time.DateTime;
      import org.joda.time.DateTimeZone;
      import org.slf4j.Logger;
      import org.slf4j.LoggerFactory;
      import com.google.protobuf.InvalidProtocolBufferException;
      import com.moloco.dataflow.test.StackOverflow.TestProto;
      import com.moloco.dataflow2.Main;
      // @formatter:off
      // This code uses TestProto (java class) that is generated by protoc.
      // The message definition is as follows (in proto3, but it shouldn't matter):
      // message TestProto {
      //   int64 count = 1;
      //   string name = 2;
      //   repeated string values = 3;
      // }
      // Note that this code does not depend on whether this proto is used,
      // or any other byte[] is used (see the CreateRandomProtoData DoFn later, which generates the data used in the code).
      // We tested both, but are presenting this as a concrete example of how (our) code in production can be affected.
      // @formatter:on
      public class CompressionTester {
        private static final Logger LOG = LoggerFactory.getLogger(CompressionTester.class);
        static final List<String> lines = Arrays.asList("some dummy string that will not be used in this job.");
        // GCS paths where data will be written.
        // %s will be replaced by some timestamped String for easy debugging.
        static final String PATH_TO_GCS_PLAIN_BASE64 = Main.SOME_BUCKET + "/comp-test/%s/output-plain-base64";
        static final String PATH_TO_GCS_TFRECORD_UNCOMP = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-uncompressed";
        static final String PATH_TO_GCS_TFRECORD_GZ = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-gzip";

        // This DoFn reads byte[] which represents a proto message (TestProto).
        // It simply counts the number of proto objects it processes
        // as well as the number of Strings each proto object contains.
        // When the pipeline terminates, the values of the Counters will be printed out.
        static class CountDoFn extends DoFn<byte[], TestProto> {
          private final Counter protoIn;
          private final Counter protoValuesCnt;
          private final Counter protoByteArrayLength;

          public CountDoFn(String name) {
            protoIn = Metrics.counter(this.getClass(), name + "_proto_in");
            protoValuesCnt = Metrics.counter(this.getClass(), name + "_proto_val_cnt");
            protoByteArrayLength = Metrics.counter(this.getClass(), name + "_proto_array_len");
          }

          @ProcessElement
          public void processElement(ProcessContext c) throws InvalidProtocolBufferException {
            TestProto tp = TestProto.parseFrom(c.element());
            protoIn.inc(); // one more proto message
            protoValuesCnt.inc(tp.getValuesCount()); // number of 'values' strings in it
            protoByteArrayLength.inc(c.element().length); // serialized size in bytes
          }
        }

        // This DoFn emits a number of TestProto objects as byte[].
        // Input to this DoFn is ignored (not used).
        // Each TestProto object contains three fields: count (int64), name (string), and values (repeated string).
        // The three parameters in this DoFn determine
        // (1) the number of proto objects to be generated,
        // (2) the number of (repeated) strings to be added to each proto object, and
        // (3) the length of (each) string.
        // TFRecord with Compression (when reading) fails when the parameters are 1, 40, 100, for instance.
        // TFRecord with Compression (when reading) succeeds when the parameters are 1, 39, 100, for instance.
        static class CreateRandomProtoData extends DoFn<String, byte[]> {
          static final int NUM_PROTOS = 1; // Total number of TestProto objects to be emitted by this DoFn.
          static final int NUM_STRINGS = 40; // Total number of strings in each TestProto object ('repeated string').
          static final int STRING_LEN = 100; // Length parameter for each string object.

          // Returns a random string of 2 * len characters.
          // 'A' + rd.nextInt(26) is an int, so append(int) writes its decimal code (65 to 90,
          // two digits each) rather than a letter; the array lengths quoted below depend on this.
          static String getRandomString(Random rd, int len) {
            StringBuffer sb = new StringBuffer();
            for (int i = 0; i < len; i++) {
              sb.append('A' + (rd.nextInt(26)));
            }
            return sb.toString();
          }

          // Returns a randomly generated TestProto object.
          // Each string is generated randomly using getRandomString().
          static TestProto getRandomProto(Random rd) {
            TestProto.Builder tpBuilder = TestProto.newBuilder();
            tpBuilder.setName(getRandomString(rd, STRING_LEN));
            for (int i = 0; i < NUM_STRINGS; i++) {
              tpBuilder.addValues(getRandomString(rd, STRING_LEN));
            }
            return tpBuilder.build();
          }

          // Emits TestProto objects as byte[].
          @ProcessElement
          public void processElement(ProcessContext c) {
            // For debugging purposes, a fixed seed could be passed to Random here.
            Random rd = new Random();
            for (int n = 0; n < NUM_PROTOS; n++) {
              byte[] data = getRandomProto(rd).toByteArray();
              // With parameters (1, 39, 100), the array length is 8126. It works fine.
              // With parameters (1, 40, 100), the array length is 8329. It breaks TFRecord with GZIP.
              System.out.println("byte array length = " + data.length);
              c.output(data);
            }
          }
        }

        public static void execute() {
          PipelineOptions options = PipelineOptionsFactory.create();
          // For debugging purposes, write files under 'gcsSubDir' so we can easily distinguish.
          final String gcsSubDir =
              String.format("%s-%d", DateTime.now(DateTimeZone.UTC), DateTime.now(DateTimeZone.UTC).getMillis());
          // Write PCollection<TestProto> in 3 different ways to GCS.
          {
            Pipeline pipeline = Pipeline.create(options);
            // Create dummy data which is a PCollection of byte arrays (each array representing a proto message).
            PCollection<byte[]> data = pipeline.apply(Create.of(lines)).apply(ParDo.of(new CreateRandomProtoData()));
            // 1. Write as plain-text with base64 encoding.
            data.apply(ParDo.of(new DoFn<byte[], String>() {
              @ProcessElement
              public void processElement(ProcessContext c) {
                c.output(new String(Base64.encodeBase64(c.element())));
              }
            })).apply(TextIO.write().to(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir)).withNumShards(1));
            // 2. Write as TFRecord.
            data.apply(TFRecordIO.write().to(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir)).withNumShards(1));
            // 3. Write as TFRecord-gzip.
            data.apply(TFRecordIO.write().withCompression(Compression.GZIP)
                .to(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir)).withNumShards(1));
            // Finish writing before the read pipeline starts.
            pipeline.run().waitUntilFinish();
          }
          LOG.info("               READ TEST BEGINS ");
          // Read PCollection<TestProto> in 3 different ways from GCS.
          {
            Pipeline pipeline = Pipeline.create(options);
            // 1. Read as plain-text.
            pipeline.apply(TextIO.read().from(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir) + "*"))
                .apply(ParDo.of(new DoFn<String, byte[]>() {
                  @ProcessElement
                  public void processElement(ProcessContext c) {
                    c.output(Base64.decodeBase64(c.element()));
                  }
                })).apply("plain-base64", ParDo.of(new CountDoFn("plain_base64")));
            // 2. Read as TFRecord -> byte array.
            pipeline.apply(TFRecordIO.read().from(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir) + "*"))
                .apply("tfrecord-uncomp", ParDo.of(new CountDoFn("tfrecord_uncomp")));
            // 3. Read as TFRecord-gz -> byte array.
            // This seems to fail when 'data size' becomes large.
            pipeline.apply(TFRecordIO.read().withCompression(Compression.GZIP)
                    .from(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir) + "*"))
                .apply("tfrecord_gz", ParDo.of(new CountDoFn("tfrecord_gz")));
            // 4. Run pipeline.
            PipelineResult res = pipeline.run();
            res.waitUntilFinish();
            // Check CountDoFn's metrics.
            // The numbers should match.
            Map<String, Long> counterValues = new TreeMap<String, Long>();
            for (MetricResult<Long> counter : res.metrics().queryMetrics(MetricsFilter.builder().build()).counters()) {
              counterValues.put(counter.name().name(), counter.committed());
            }
            StringBuffer sb = new StringBuffer();
            sb.append("\n------------ counter metrics from CountDoFn\n");
            for (Entry<String, Long> entry : counterValues.entrySet()) {
              sb.append(String.format("[counter] %40s: %5d\n", entry.getKey(), entry.getValue()));
            }
            System.out.println(sb.toString());
          }
        }
      }




