
TensorBoard cannot correctly recognize the scalar or tensor tags in tfevents files generated by tensorboard-sdk-scala3 #7012

@mullerhai

Description

Hello, I need your help. I'm developing a TensorBoard SDK for Scala 3. The tfevents files it generates can be read back correctly by the Scala 3 TensorBoard SDK, and all elements are present. I also generated tfevents files with PyTorch's TensorBoard component in Python, and those can be read correctly by the Scala 3 SDK as well.
Using the Scala 3 SDK, I generated scalars and tensors separately, but TensorBoard does not recognize the corresponding scalar or tensor tags. I'm really confused. I've shared the generated files below and hope you can investigate them. I compared them myself, and they look quite similar to the files generated by PyTorch. I wrote the tfevents files using Event and the other classes generated from the provided proto files.
Could it be a version issue, an encoding problem, byte misalignment, or a missing declaration? It's very difficult for me to tell, because a tfevents file is binary and I can't see its internal structure. I sincerely ask for your help in finding the cause, and I'd be extremely grateful. If we can identify the issue, we'll have an official version of the TensorBoard Scala 3 SDK. I'm really eager to use TensorBoard normally in Scala 3. Thank you for your hard work.
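On the byte-alignment question: as far as I understand TensorFlow's TFRecord format (which tfevents files use), each record is a uint64 length, a uint32 masked CRC of that length, the serialized Event, and a uint32 masked CRC of the data, with the integers stored little-endian. A quick way to peek inside a file without any protobuf tooling is to decode only the first 8 bytes (the first record's length) both ways. The sketch below does that; the object and method names are hypothetical.

```scala
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}

// Hypothetical helper: read the first record's 8-byte length field
// as both little-endian and big-endian, to see which one looks sane.
object RecordHeaderPeek {
  def firstLengthBothWays(bytes: Array[Byte]): (Long, Long) = {
    require(bytes.length >= 8, "need at least 8 bytes")
    val little = ByteBuffer.wrap(bytes, 0, 8).order(ByteOrder.LITTLE_ENDIAN).getLong()
    val big = ByteBuffer.wrap(bytes, 0, 8).order(ByteOrder.BIG_ENDIAN).getLong()
    (little, big)
  }

  def main(args: Array[String]): Unit = {
    // e.g. logz/events.out.tfevents or logz/train_scala3.tfevents
    val bytes = Files.readAllBytes(Paths.get(args(0)))
    val (le, be) = firstLengthBothWays(bytes)
    println(s"first record length: little-endian=$le, big-endian=$be")
  }
}
```

If the PyTorch files show a small length in the little-endian column while the Scala files only make sense big-endian, that would point at the writer: `DataOutputStream` and a default `ByteBuffer` both use big-endian byte order.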
Because tfevents files can't be uploaded to this issue, they are linked below.

tfevents files generated by scala3-sdk-tensorboard:

https://github.com/mullerhai/storchBoard/blob/master/logz/train2-scala3-scalars-gen
https://github.com/mullerhai/storchBoard/blob/master/logz/train3-scala3-tensor-gen.tfevents
https://github.com/mullerhai/storchBoard/blob/master/logz/train_scala3.tfevents

tfevents files generated by PyTorch:
https://github.com/mullerhai/storchBoard/blob/master/logz/events.tensor.tfevents
https://github.com/mullerhai/storchBoard/blob/master/logz/events.out.tfevents
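The file names may also matter. My understanding is that TensorBoard's log-directory loader only treats a file as an event file when its base name contains the substring `tfevents`, so `train2-scala3-scalars-gen` would likely be skipped regardless of its contents. The sketch below mirrors that rule (it is not TensorBoard's own code; the object name is hypothetical) and applies it to some of the files above.

```scala
import java.nio.file.Paths

// Hypothetical check mirroring the rule described above:
// the base name of the file must contain "tfevents".
object EventFileNameCheck {
  def looksLikeEventFile(path: String): Boolean =
    Paths.get(path).getFileName.toString.contains("tfevents")

  def main(args: Array[String]): Unit = {
    val files = Seq(
      "logz/train2-scala3-scalars-gen",
      "logz/train3-scala3-tensor-gen.tfevents",
      "logz/train_scala3.tfevents",
      "logz/events.out.tfevents"
    )
    files.foreach(f => println(s"$f -> ${looksLikeEventFile(f)}"))
  }
}
```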

Scala 3 code used to generate the tfevents file:

package example

import org.tensorflow.framework.summary.{Summary, SummaryMetadata}
import org.tensorflow.framework.tensor.TensorProto
import org.tensorflow.framework.tensor_shape.TensorShapeProto
import org.tensorflow.framework.types.DataType
import org.tensorflow.util.event.Event
import java.io.{DataInputStream, DataOutputStream, FileInputStream, FileOutputStream}
import java.nio.ByteBuffer
import java.util.zip.CRC32

object TFRecordExampleScala3 extends App {
  val logDir = "logz"
  val logFilePath = s"$logDir/train_scala3.tfevents"

  // Write the event log file
  writeToLogFile(logFilePath)

  // Read the event log file back
  readFromLogFile(logFilePath)

  def writeToLogFile(filePath: String): Unit = {
    val fileOutputStream = new FileOutputStream(filePath)
    val dataOutputStream = new DataOutputStream(fileOutputStream)

    try {
      // Simulate a training run
      val numSteps = 50
      for (step <- 0L until numSteps) {
        // Simulate loss and accuracy values
        val loss = Math.exp(-step / 10.0)
        val accuracy = 1 - 0.5 * Math.exp(-step / 20.0)

        // Record the loss
        writeEvent(dataOutputStream, "Loss/train", loss, step)
        // Record the accuracy
        writeEvent(dataOutputStream, "Accuracy/train", accuracy, step)
      }
    } finally {
      dataOutputStream.close()
    }
  }

  private def writeEvent(outputStream: DataOutputStream, tag: String, value: Double, step: Long): Unit = {
    // Create a TensorProto holding the scalar value
    val tensorProto = TensorProto(
      dtype = DataType.DT_FLOAT,
      tensorShape = Some(TensorShapeProto()),
      floatVal = Seq(value.toFloat)
    )

    // Create the SummaryMetadata
    val metadata = SummaryMetadata()

    // Create the Summary.Value
    val summaryValue = Summary.Value(
      tag = tag,
      metadata = Some(metadata),
      value = Summary.Value.Value.Tensor(tensorProto)
    )

    // Create the Summary
    val summary = Summary(
      value = Seq(summaryValue)
    )

    // Create the Event message
    val event = Event(
      wallTime = System.currentTimeMillis() / 1000.0,
      step = step,
      what = Event.What.Summary(summary)
    )

    // Serialize the Event to a byte array
    val serializedEvent = event.toByteArray

    // Write the record length
    val length = serializedEvent.length.toLong
    val lengthBytes = ByteBuffer.allocate(8).putLong(length).array()
    outputStream.write(lengthBytes)

    // Write the CRC checksum of the length
    val lengthCrc = calculateCrc(lengthBytes)
    outputStream.writeInt(lengthCrc)

    // Write the data
    outputStream.write(serializedEvent)

    // Write the CRC checksum of the data
    val dataCrc = calculateCrc(serializedEvent)
    outputStream.writeInt(dataCrc)
  }

  private def calculateCrc(data: Array[Byte]): Int = {
    val crc32 = new CRC32()
    crc32.update(data)
    (crc32.getValue & 0xFFFFFFFFL).toInt
  }

  def readFromLogFile(filePath: String): Unit = {
    val fileInputStream = new FileInputStream(filePath)
    val dataInputStream = new DataInputStream(fileInputStream)

    try {
      while (dataInputStream.available() > 0) {
        // Read the record length
        val lengthBytes = new Array[Byte](8)
        dataInputStream.readFully(lengthBytes)
        val lengthBuffer = ByteBuffer.wrap(lengthBytes)
        val length = lengthBuffer.getLong()

        // Read and verify the CRC checksum of the length
        val expectedLengthCrc = dataInputStream.readInt()
        val actualLengthCrc = calculateCrc(lengthBytes)
        if (expectedLengthCrc != actualLengthCrc) {
          throw new RuntimeException("Length CRC mismatch")
        }

        // Read the data
        val data = new Array[Byte](length.toInt)
        dataInputStream.readFully(data)

        // Read and verify the CRC checksum of the data
        val expectedDataCrc = dataInputStream.readInt()
        val actualDataCrc = calculateCrc(data)
        if (expectedDataCrc != actualDataCrc) {
          throw new RuntimeException("Data CRC mismatch")
        }

        // Parse the Event
        val event = Event.parseFrom(data)
        event.what match {
          case Event.What.Summary(summary) =>
            summary.value.foreach { summaryValue =>
              summaryValue.value match {
                case Summary.Value.Value.Tensor(tensorProto) =>
                  tensorProto.floatVal.headOption.foreach { floatValue =>
                    println(s"Step: ${event.step}, Tag: ${summaryValue.tag}, Value: $floatValue")
                  }
                case _ =>
                  println(s"Unsupported value type for tag: ${summaryValue.tag}")
              }
            }
          case _ =>
            println("No summary found in the event")
        }
      }
    } catch {
      case e: Exception =>
        println(s"Error reading events: ${e.getMessage}")
    } finally {
      dataInputStream.close()
    }
  }
}
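For comparison, here is a minimal sketch of record framing as I understand TensorFlow's TFRecord format: integers little-endian, and each checksum a *masked* CRC32C (Castagnoli), computed as `((crc >>> 15) | (crc << 17)) + 0xa282ead8` truncated to 32 bits. The `writeEvent` method above instead writes big-endian values through `DataOutputStream` and uses plain `java.util.zip.CRC32`, and because `readFromLogFile` follows the same convention, the files round-trip in Scala even if TensorBoard cannot parse them. This is an assumption to verify, not a confirmed diagnosis. The object and method names are hypothetical; `frame(event.toByteArray)` would stand in for the four write steps in `writeEvent`. `java.util.zip.CRC32C` requires Java 9 or later.

```scala
import java.nio.{ByteBuffer, ByteOrder}
import java.util.zip.CRC32C
import scala.collection.mutable.ArrayBuffer

// Sketch of TFRecord framing (hypothetical names), assuming the
// little-endian layout and masked CRC32C described above.
object TFRecordFraming {
  private val MaskDelta = 0xa282ead8L

  // CRC32C of the bytes, then TensorFlow-style masking, as an unsigned 32-bit value in an Int.
  def maskedCrc32c(data: Array[Byte]): Int = {
    val crc = new CRC32C()
    crc.update(data, 0, data.length)
    val c = crc.getValue // unsigned 32-bit CRC held in a Long
    ((((c >>> 15) | (c << 17)) + MaskDelta) & 0xffffffffL).toInt
  }

  // One record: uint64 length | uint32 masked CRC of length | data | uint32 masked CRC of data.
  def frame(data: Array[Byte]): Array[Byte] = {
    val lengthBytes = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
      .putLong(data.length.toLong).array()
    ByteBuffer.allocate(8 + 4 + data.length + 4).order(ByteOrder.LITTLE_ENDIAN)
      .put(lengthBytes)
      .putInt(maskedCrc32c(lengthBytes))
      .put(data)
      .putInt(maskedCrc32c(data))
      .array()
  }

  // Split a byte array into record payloads, verifying both checksums of each record.
  def unframe(bytes: Array[Byte]): Seq[Array[Byte]] = {
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val out = ArrayBuffer.empty[Array[Byte]]
    while (buf.remaining() > 0) {
      val lengthBytes = new Array[Byte](8)
      buf.get(lengthBytes)
      val length = ByteBuffer.wrap(lengthBytes).order(ByteOrder.LITTLE_ENDIAN).getLong()
      if (buf.getInt() != maskedCrc32c(lengthBytes))
        throw new RuntimeException("Length CRC mismatch")
      val data = new Array[Byte](length.toInt)
      buf.get(data)
      if (buf.getInt() != maskedCrc32c(data))
        throw new RuntimeException("Data CRC mismatch")
      out += data
    }
    out.toSeq
  }
}
```

If the framing turns out to be fine, two other differences from the PyTorch output may be worth checking. PyTorch's `add_scalar` writes a `simple_value` by default rather than a tensor, and my understanding is that TensorBoard's scalars dashboard only lists tensor summaries whose `SummaryMetadata` has `plugin_data.plugin_name` set to `"scalars"`, whereas the code above passes an empty `SummaryMetadata()`.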
