Tensor to Proto Bug with SparseTensor: "java.lang.IllegalArgumentException: size of dimensions must equals size of values" #855

@austinzh

Description

Call Stacks

  java.lang.IllegalArgumentException: size of dimensions must equals size of values
  at ml.combust.mleap.tensor.Tensor$.normalizeDimensions(Tensor.scala:63)
  at ml.combust.mleap.tensor.Tensor$.create(Tensor.scala:33)
  at ml.combust.bundle.tensor.TensorSerializer$.fromProto(TensorSerializer.scala:74)
  at ml.combust.bundle.dsl.Value.getTensor(Value.scala:323)

Possible cause
In ml.combust.bundle.tensor.TensorSerializer$.toProto, we save rawValues,
but in ml.combust.bundle.tensor.TensorSerializer$.fromProto we always load the result as a DenseTensor.
For a SparseTensor, rawValues is a much smaller array than the dimensions imply, so the size check fails with this error.
I suggest we handle SparseTensor and DenseTensor separately.

  def toProto[T](t: tensor.Tensor[T]): Tensor = {
    val (tpe, values) = t.base.runtimeClass match {
      case tensor.Tensor.BooleanClass =>
        (BasicType.BOOLEAN, BooleanArraySerializer.write(t.rawValues.asInstanceOf[Array[Boolean]]))
      case tensor.Tensor.ByteClass =>
        (BasicType.BYTE, ByteArraySerializer.write(t.rawValues.asInstanceOf[Array[Byte]]))
      case tensor.Tensor.ShortClass =>
        (BasicType.SHORT, ShortArraySerializer.write(t.rawValues.asInstanceOf[Array[Short]]))
      case tensor.Tensor.IntClass =>
        (BasicType.INT, IntArraySerializer.write(t.rawValues.asInstanceOf[Array[Int]]))
      case tensor.Tensor.LongClass =>
        (BasicType.LONG, LongArraySerializer.write(t.rawValues.asInstanceOf[Array[Long]]))
      case tensor.Tensor.FloatClass =>
        (BasicType.FLOAT, FloatArraySerializer.write(t.rawValues.asInstanceOf[Array[Float]]))
      case tensor.Tensor.DoubleClass =>
        (BasicType.DOUBLE, DoubleArraySerializer.write(t.rawValues.asInstanceOf[Array[Double]]))
      case tensor.Tensor.StringClass =>
        (BasicType.STRING, StringArraySerializer.write(t.rawValues.asInstanceOf[Array[String]]))
      case tensor.Tensor.ByteStringClass =>
        (BasicType.BYTE_STRING, ByteStringArraySerializer.write(t.rawValues.asInstanceOf[Array[ByteString]]))
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }
    // ... (rest of toProto omitted in this excerpt)

  def fromProto[T](t: Tensor): tensor.Tensor[T] = {
    val dimensions = t.shape.get.dimensions.map(_.size)
    val valueBytes = t.value.toByteArray

    val tn = t.base match {
      case BasicType.BOOLEAN =>
        tensor.Tensor.create(BooleanArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE =>
        tensor.Tensor.create(ByteArraySerializer.read(valueBytes), dimensions)
      case BasicType.SHORT =>
        tensor.Tensor.create(ShortArraySerializer.read(valueBytes), dimensions)
      case BasicType.INT =>
        tensor.Tensor.create(IntArraySerializer.read(valueBytes), dimensions)
      case BasicType.LONG =>
        tensor.Tensor.create(LongArraySerializer.read(valueBytes), dimensions)
      case BasicType.FLOAT =>
        tensor.Tensor.create(FloatArraySerializer.read(valueBytes), dimensions)
      case BasicType.DOUBLE =>
        tensor.Tensor.create(DoubleArraySerializer.read(valueBytes), dimensions)
      case BasicType.STRING =>
        tensor.Tensor.create(StringArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE_STRING =>
        tensor.Tensor.create(ByteStringArraySerializer.read(valueBytes), dimensions)
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }
    // ... (rest of fromProto omitted in this excerpt)

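The mismatch described above can be reproduced in a small, self-contained sketch that does not depend on MLeap. Everything here is hypothetical and for illustration only: `MiniTensor`, `Dense`, `Sparse`, `Wire`, and `MiniSerializer` are stand-ins, not the library's types, and the size check in `createDense` is an assumption about what the failing check does, based only on the error message. The sketch shows a dense-only read failing for a sparse tensor, and one possible direction for the suggested separation: carry the indices in the serialized form and branch on them when reading.

```scala
// Simplified stand-ins for the tensor types; all names are illustrative only.
sealed trait MiniTensor { def dimensions: Seq[Int] }
final case class Dense(values: Array[Double], dimensions: Seq[Int]) extends MiniTensor
final case class Sparse(indices: Seq[Seq[Int]], values: Array[Double], dimensions: Seq[Int]) extends MiniTensor

// Hypothetical wire format: `indices` is None for dense tensors.
final case class Wire(values: Array[Double], dimensions: Seq[Int], indices: Option[Seq[Seq[Int]]])

object MiniSerializer {
  // Assumed behaviour of the dense constructor: the element count must match the shape.
  def createDense(values: Array[Double], dimensions: Seq[Int]): Dense = {
    require(dimensions.product == values.length, "size of dimensions must equals size of values")
    Dense(values, dimensions)
  }

  // Write side: keep the indices when the tensor is sparse.
  def toWire(t: MiniTensor): Wire = t match {
    case Dense(v, d)     => Wire(v, d, None)
    case Sparse(i, v, d) => Wire(v, d, Some(i))
  }

  // Mirrors the reported behaviour: ignores indices and always builds a dense tensor.
  def fromWireDenseOnly(w: Wire): MiniTensor = createDense(w.values, w.dimensions)

  // Suggested direction: build a sparse tensor when indices are present.
  def fromWire(w: Wire): MiniTensor = w.indices match {
    case Some(i) => Sparse(i, w.values, w.dimensions)
    case None    => createDense(w.values, w.dimensions)
  }
}

object Demo {
  def main(args: Array[String]): Unit = {
    // A length-5 sparse vector with two stored entries, at positions 1 and 4.
    val sparse = Sparse(Seq(Seq(1), Seq(4)), Array(2.0, 3.0), Seq(5))
    val wire = MiniSerializer.toWire(sparse)

    // Two stored values against a shape of 5 elements: the dense-only read fails.
    println(scala.util.Try(MiniSerializer.fromWireDenseOnly(wire)).isFailure)

    // The branching read reconstructs a sparse tensor instead.
    println(MiniSerializer.fromWire(wire).isInstanceOf[Sparse])
  }
}
```

Whether the real fix belongs in the proto schema (an extra field for indices) or in a separate serializer per tensor kind is a design decision for the maintainers; the sketch only illustrates why the dense path cannot accept sparse rawValues as-is.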