@hjanuschka
Last active December 15, 2025 01:13
JXL Decoder: Chromium wrapper using jxl-rs typestate API

JXL Decoder for Chromium using jxl-rs

This is the Chromium/Blink JXL image decoder implementation using the pure-Rust jxl-rs library.

Architecture

┌─────────────────────────────────────────────────────────────┐
│                    Blink ImageDecoder                       │
│                  (jxl_image_decoder.cc/h)                   │
├─────────────────────────────────────────────────────────────┤
│                     CXX FFI Bridge                          │
│                   (wrapper_lib.rs)                          │
│              Uses JxlDecoderInner directly                  │
├─────────────────────────────────────────────────────────────┤
│                      jxl-rs library                         │
│              (pure Rust JPEG XL decoder)                    │
└─────────────────────────────────────────────────────────────┘

Key Design Decisions

Type-Erased API (JxlDecoderInner)

The wrapper uses JxlDecoderInner, the type-erased API of jxl-rs, instead of the typestate pattern. This provides:

  • Simpler code: no DecoderState enum holding a different decoder type per state; the wrapper only tracks a small internal Stage enum
  • Direct reset/rewind: decoder.reset() and decoder.rewind() operate in place on &mut self
  • Unified process(): a single method drives every stage (basic info, frame header, frame pixels) and reports which event occurred
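A minimal sketch of the unified process() pattern, assuming a toy stand-in for the decoder. FakeDecoder, Status and Drive() are hypothetical names, not jxl-rs or Blink API. One object keeps an internal stage, each process() call reports one event plus the number of bytes it consumed, and the driver advances its input offset by that amount, leaving it in place on NeedMoreInput, much like JXLImageDecoder::Decode() does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical event codes mirroring the JxlRsStatus values in the wrapper.
enum class Status { kSuccess, kError, kNeedMoreInput, kBasicInfo, kFrame, kFullImage };

struct ProcessResult {
  Status status;
  size_t bytes_consumed;
};

// Toy decoder: one object, one process() entry point, and an internal stage
// that decides which event comes next. In this model a header costs 1 byte
// and a frame body costs 4 bytes; the byte values themselves are ignored.
class FakeDecoder {
 public:
  explicit FakeDecoder(int num_frames) : num_frames_(num_frames) {}

  ProcessResult process(const uint8_t* /*data*/, size_t size) {
    switch (stage_) {
      case Stage::kInitial:
        if (size < 1) return {Status::kNeedMoreInput, 0};
        stage_ = Stage::kHaveBasicInfo;
        return {Status::kBasicInfo, 1};
      case Stage::kHaveBasicInfo:
        if (frames_done_ == num_frames_) return {Status::kSuccess, 0};
        if (size < 1) return {Status::kNeedMoreInput, 0};
        stage_ = Stage::kHaveFrameHeader;
        return {Status::kFrame, 1};
      case Stage::kHaveFrameHeader:
        if (size < 4) return {Status::kNeedMoreInput, 0};
        stage_ = Stage::kHaveBasicInfo;
        ++frames_done_;
        return {Status::kFullImage, 4};
    }
    return {Status::kError, 0};
  }

  // Rewinds the same object in place; no new decoder value is produced.
  void rewind() {
    stage_ = Stage::kInitial;
    frames_done_ = 0;
  }

 private:
  enum class Stage { kInitial, kHaveBasicInfo, kHaveFrameHeader };
  Stage stage_ = Stage::kInitial;
  int num_frames_;
  int frames_done_ = 0;
};

// Driver loop: feed the bytes after `offset`, record each event, and advance
// by bytes_consumed. Stops on Success, Error or NeedMoreInput; on
// NeedMoreInput `offset` stays where the incomplete part begins.
std::vector<Status> Drive(FakeDecoder& decoder, const std::vector<uint8_t>& data,
                          size_t& offset) {
  std::vector<Status> events;
  for (;;) {
    ProcessResult r = decoder.process(data.data() + offset, data.size() - offset);
    events.push_back(r.status);
    if (r.status == Status::kSuccess || r.status == Status::kError ||
        r.status == Status::kNeedMoreInput) {
      return events;
    }
    offset += r.bytes_consumed;
  }
}
```

Feeding a truncated buffer stops at NeedMoreInput with the offset parked before the incomplete frame; calling Drive() again with the longer buffer resumes from that offset.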

Animation Jitter Fix

JXL frames must be decoded sequentially, unlike WebP, GIF and AVIF, whose decoders can seek to individual frames. To prevent animation jitter on pages with multiple animations:

  1. Eager decoding: DecodeAllFrames() decodes every frame on the first pixel request for an animation, once all data has been received and every frame has been discovered
  2. Cache preservation: ClearCacheExceptFrame() keeps ALL frames once all_frames_decoded_ is true

This trades memory for smooth playback: evicting any frame would force an expensive re-decode starting from frame 0.
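To make the trade concrete, the sketch below counts frame decodes for a looping animation under a simplified model (an assumption, not Blink's real scheduler). Frames can only be produced in order, and requesting an earlier frame that is no longer cached forces a rewind to frame 0. CountFrameDecodes() is a hypothetical helper comparing a keep-everything cache with a keep-current-and-previous cache.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Counts how many frame decodes a looping playback costs with a
// sequential-only decoder.
//   keep_all = true:  every decoded frame stays cached.
//   keep_all = false: only the current and previous frame stay cached.
size_t CountFrameDecodes(size_t num_frames, size_t loops, bool keep_all) {
  std::set<size_t> cache;
  size_t next = 0;     // Next frame the decoder would produce.
  size_t decodes = 0;  // Total frames decoded so far.
  for (size_t loop = 0; loop < loops; ++loop) {
    for (size_t i = 0; i < num_frames; ++i) {
      if (cache.count(i) == 0) {
        if (i < next) {
          next = 0;  // Going backwards: rewind to frame 0.
        }
        for (; next <= i; ++next) {  // Decode forward up to frame i.
          cache.insert(next);
          ++decodes;
        }
      }
      if (!keep_all) {
        // Evict everything except frame i and frame i - 1.
        for (auto it = cache.begin(); it != cache.end();) {
          if (*it == i || *it + 1 == i) {
            ++it;
          } else {
            it = cache.erase(it);
          }
        }
      }
    }
  }
  return decodes;
}
```

With 10 frames played for 3 loops, keeping every frame costs 10 decodes in total, while keeping only two frames costs 30, because each loop restarts the decoder from frame 0.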

Pixel Formats

The decoder writes pixels natively in two formats:

  • RGBA8: standard 8-bit images
  • RGBA_F16: images with more than 8 bits per sample (high bit depth / HDR), when the decoder is created with HighBitDepthDecodingOption::kHighBitDepthToHalfFloat
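The two formats differ in how a pixel is stored. A minimal sketch follows, assuming a BGRA-ordered N32 word (alpha in the top byte, red at bit 16); Skia's real SK_*32_SHIFT values are platform-dependent, so kAShift and the other shift constants are illustrative. The F16 packing puts R in the lowest 16 bits and A in the highest, the order the decoder uses when writing through GetAddrF16(). UseHalfFloat() is a hypothetical helper mirroring the decoder's rule for choosing RGBA_F16.

```cpp
#include <cassert>
#include <cstdint>

// Assumed shifts for a BGRA-ordered 32-bit pixel (illustrative only).
constexpr int kAShift = 24, kRShift = 16, kGShift = 8, kBShift = 0;

// RGBA8 path: one 32-bit word per pixel, channels placed by shift.
uint32_t PackN32(uint8_t r, uint8_t g, uint8_t b, uint8_t a) {
  return (uint32_t{a} << kAShift) | (uint32_t{r} << kRShift) |
         (uint32_t{g} << kGShift) | (uint32_t{b} << kBShift);
}

// RGBA_F16 path: four IEEE half-float bit patterns in one 64-bit word,
// R in bits 0-15 and A in bits 48-63.
uint64_t PackF16(uint16_t r, uint16_t g, uint16_t b, uint16_t a) {
  return (uint64_t{a} << 48) | (uint64_t{b} << 32) | (uint64_t{g} << 16) |
         uint64_t{r};
}

// Half-float output only for >8 bits per sample when it was requested.
bool UseHalfFloat(uint32_t bits_per_sample, bool half_float_requested) {
  return bits_per_sample > 8 && half_float_requested;
}

// Output bytes per pixel: 4 x uint8 or 4 x half-float.
constexpr int BytesPerPixel(bool half_float) { return half_float ? 8 : 4; }
```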

Files

  • jxl_image_decoder.h: header with the JXLImageDecoder class definition
  • jxl_image_decoder.cc: implementation of Blink's ImageDecoder interface
  • wrapper_lib.rs: Rust FFI wrapper using CXX, providing a C++-compatible API

Building

Part of Chromium's build system. The Rust wrapper is built via GN/Ninja with CXX code generation.

autoninja -C out/Default chrome

Testing

out/Default/blink_platform_unittests --gtest_filter='*JXL*'

jxl_image_decoder.cc

// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "third_party/blink/renderer/platform/image-decoders/jxl/jxl_image_decoder.h"
#include <array>
#include "base/containers/span.h"
#include "base/logging.h"
#include "base/numerics/byte_conversions.h"
#include "base/numerics/checked_math.h"
#include "base/time/time.h"
#include "third_party/blink/renderer/platform/image-decoders/fast_shared_buffer_reader.h"
#include "third_party/skia/include/core/SkColorSpace.h"
#include "third_party/skia/include/core/SkTypes.h"
namespace blink {
namespace {
// The maximum JXL file size we are willing to decode. This helps prevent
// resource exhaustion from malicious files. Matches AVIF decoder limit.
constexpr uint64_t kMaxJxlFileSize = 0x10000000; // 256 MB
} // namespace
JXLImageDecoder::JXLImageDecoder(AlphaOption alpha_option,
HighBitDepthDecodingOption hbd_option,
ColorBehavior color_behavior,
cc::AuxImage aux_image,
wtf_size_t max_decoded_bytes,
AnimationOption animation_option)
: ImageDecoder(alpha_option,
hbd_option,
color_behavior,
aux_image,
max_decoded_bytes) {
basic_info_ = {};
basic_info_.have_animation = false;
}
JXLImageDecoder::~JXLImageDecoder() = default;
String JXLImageDecoder::FilenameExtension() const {
return "jxl";
}
const AtomicString& JXLImageDecoder::MimeType() const {
DEFINE_STATIC_LOCAL(const AtomicString, jxl_mime_type, ("image/jxl"));
return jxl_mime_type;
}
bool JXLImageDecoder::ImageIsHighBitDepth() {
return is_high_bit_depth_;
}
void JXLImageDecoder::OnSetData(scoped_refptr<SegmentReader> data) {
// OnSetData is called when more data becomes available for the same image.
// We should NOT reset metadata state here - that would destroy animation
// info. The Decode() method handles feeding data to the decoder and rewinding
// for animation loops when needed.
//
// Note: Unlike OnSetData implementations that reset state, we preserve:
// - basic_info_ (image dimensions, animation settings)
// - frame_info_ (frame durations, headers)
// - have_metadata_ (whether we've parsed the header)
// - all_frames_discovered_ (whether we know all frames)
//
// The decoder will be fed fresh data in Decode() which handles this properly.
}
bool JXLImageDecoder::MatchesJXLSignature(
const FastSharedBufferReader& fast_reader) {
uint8_t buffer[12];
if (fast_reader.size() < sizeof(buffer)) {
return false;
}
auto data = fast_reader.GetConsecutiveData(0, sizeof(buffer), buffer);
return jxl_rs_signature_check(
rust::Slice<const uint8_t>(data.data(), data.size()));
}
void JXLImageDecoder::DecodeSize() {
Decode(0, /*only_size=*/true);
}
wtf_size_t JXLImageDecoder::DecodeFrameCount() {
// IMPORTANT: Must parse metadata FIRST to know if this is an animation!
// Otherwise basic_info_.have_animation will be false (default) and we'll
// incorrectly return 1 frame.
if (!have_metadata_) {
Decode(0, /*only_size=*/true);
}
if (!basic_info_.have_animation) {
return 1;
}
// Return the number of frames we've discovered so far.
// Per the DecodeFrameCount API, it's valid to return an increasing count
// as frames are received and parsed (like PNG decoder does).
wtf_size_t count = frame_info_.size();
if (count == 0) {
count = 1;
}
// Ensure frame buffer cache is large enough.
if (frame_buffer_cache_.size() < count) {
frame_buffer_cache_.resize(count);
}
DVLOG(1) << "JXL DecodeFrameCount: " << count
<< " all_discovered=" << all_frames_discovered_
<< " have_animation=" << basic_info_.have_animation;
return count;
}
void JXLImageDecoder::InitializeNewFrame(wtf_size_t index) {
DCHECK_LT(index, frame_buffer_cache_.size());
auto& buffer = frame_buffer_cache_[index];
if (is_high_bit_depth_ &&
high_bit_depth_decoding_option_ == kHighBitDepthToHalfFloat) {
buffer.SetPixelFormat(ImageFrame::PixelFormat::kRGBA_F16);
}
buffer.SetHasAlpha(basic_info_.has_alpha);
buffer.SetPremultiplyAlpha(premultiply_alpha_);
buffer.SetOriginalFrameRect(gfx::Rect(Size()));
buffer.SetRequiredPreviousFrameIndex(kNotFound);
if (index < frame_info_.size()) {
buffer.SetDuration(frame_info_[index].duration);
// Calculate timestamp as sum of all previous frame durations.
base::TimeDelta timestamp;
for (wtf_size_t i = 0; i < index; ++i) {
timestamp += frame_info_[i].duration;
}
buffer.SetTimestamp(timestamp);
}
}
void JXLImageDecoder::Decode(wtf_size_t index) {
Decode(index, false);
}
void JXLImageDecoder::Decode(wtf_size_t index, bool only_size) {
if (Failed()) {
return;
}
// Check file size limit.
if (data_ && data_->size() > kMaxJxlFileSize) {
SetFailed();
return;
}
if (only_size && IsDecodedSizeAvailable() && have_metadata_) {
return;
}
// Early return if the requested frame is already fully decoded and cached.
// This avoids unnecessary re-decoding during animation loops.
if (!only_size && index < frame_buffer_cache_.size()) {
auto status = frame_buffer_cache_[index].GetStatus();
if (status == ImageFrame::kFrameComplete) {
return; // Frame is already cached.
}
}
// For animations, decode ALL frames when first requested.
// Unlike WebP/GIF which can seek to individual frames via their demuxer APIs,
// JXL must decode sequentially. Without eager decoding, requesting frame N
// while at frame M (M < N) would block while decoding M+1 through N,
// causing animation timing jitter.
if (!only_size && basic_info_.have_animation && IsAllDataReceived() &&
all_frames_discovered_ && !all_frames_decoded_) {
DecodeAllFrames();
// After decoding all frames, the requested frame should be cached.
if (index < frame_buffer_cache_.size() &&
frame_buffer_cache_[index].GetStatus() == ImageFrame::kFrameComplete) {
return;
}
}
// Get input data from Blink's buffer (no copying needed).
FastSharedBufferReader reader(data_.get());
size_t data_size = reader.size();
// Determine if we need to rewind the decoder.
bool need_rewind = false;
if (decoder_.has_value()) {
// Rewind when transitioning from metadata scan to actual decode.
// During metadata scan we process frames to discover count/durations,
// so we need to rewind to decode actual pixel data from the beginning.
if (!only_size && all_frames_discovered_ && num_decoded_frames_ == 0) {
need_rewind = true;
}
// Rewind for animation loop: requesting a frame before what we've decoded,
// but only if the frame isn't already cached to avoid re-decoding.
if (!only_size && basic_info_.have_animation) {
bool frame_already_cached =
index < frame_buffer_cache_.size() &&
frame_buffer_cache_[index].GetStatus() == ImageFrame::kFrameComplete;
if (!frame_already_cached) {
// Only rewind if we're truly going backwards (like looping to frame 0).
// Don't rewind if we're just continuing forward or filling gaps.
bool is_sequential_or_forward = index >= num_decoded_frames_;
if (!is_sequential_or_forward) {
// We're requesting a frame before what we've decoded.
// This is a loop/rewind situation.
need_rewind = true;
}
}
}
}
if (need_rewind) {
// Use rewind() for animations (preserves pixel format), reset() otherwise.
if (basic_info_.have_animation) {
(*decoder_)->rewind();
} else {
(*decoder_)->reset();
}
num_decoded_frames_ = 0;
num_frame_events_in_scan_ = 0;
input_offset_ = 0; // Reset input position for rewind.
// Note: We preserve all_frames_discovered_ - once we know the frame count,
// we don't need to re-scan. Only reset it if we're doing a fresh metadata
// scan (only_size=true), not when rewinding for pixel decode.
if (only_size) {
all_frames_discovered_ = false;
}
// Note: We don't clear frame pixel data here because:
// 1. For animations, ClearCacheExceptFrame() prevents clearing, so frames
// remain cached and we'll return early at the top of Decode().
// 2. For non-animated images, there's only one frame.
// 3. If a frame was externally cleared, its status is already kFrameEmpty.
}
// Create decoder if needed.
if (!decoder_.has_value()) {
decoder_ = jxl_rs_decoder_create();
num_decoded_frames_ = 0;
input_offset_ = 0;
}
// Process until we get what we need.
// Data is passed directly to the decoder without buffering.
for (;;) {
// Get remaining input data from current offset.
size_t remaining_size = data_size - input_offset_;
if (remaining_size == 0 && !IsAllDataReceived()) {
// No more data available yet, wait for more.
return;
}
// GetConsecutiveData() returns a view into Blink's segment when the range
// is contiguous and copies into chunk_buffer only when it spans segments.
// Reading at most kMaxChunkSize bytes per call bounds that buffer's size.
constexpr size_t kMaxChunkSize = 1024 * 1024; // 1MB chunks
size_t chunk_size = std::min(remaining_size, kMaxChunkSize);
Vector<uint8_t> chunk_buffer(chunk_size);
auto data_span = reader.GetConsecutiveData(input_offset_, chunk_size,
base::span(chunk_buffer));
JxlRsProcessResult result = (*decoder_)->process(
rust::Slice<const uint8_t>(data_span.data(), data_span.size()),
IsAllDataReceived() && (input_offset_ + chunk_size >= data_size));
JxlRsStatus status = result.status;
switch (status) {
case JxlRsStatus::Error:
SetFailed();
return;
case JxlRsStatus::NeedMoreInput:
// Don't advance input_offset_ - the decoder needs to see the same
// bytes again on the next call with more data appended.
if (IsAllDataReceived()) {
SetFailed();
}
return;
case JxlRsStatus::BasicInfo: {
basic_info_ = (*decoder_)->get_basic_info();
if (!SetSize(basic_info_.width, basic_info_.height)) {
return;
}
// Check for HDR.
if (basic_info_.bits_per_sample > 8) {
is_high_bit_depth_ = true;
}
// Enable F16 decoding for high bit depth images.
decode_to_half_float_ =
ImageIsHighBitDepth() &&
high_bit_depth_decoding_option_ == kHighBitDepthToHalfFloat;
// Configure decoder for F16 output when high bit depth.
if (decode_to_half_float_) {
(*decoder_)->set_pixel_format(JxlRsPixelFormat::RgbaF16);
}
// Extract and set ICC color profile for wide gamut support.
// Skip if color management is disabled (ColorBehavior::kIgnore).
if (!IgnoresColorSpace()) {
auto icc_data = (*decoder_)->get_icc_profile();
if (!icc_data.empty()) {
// Copy ICC data to a Vector for safe span access.
Vector<uint8_t> icc_copy;
icc_copy.AppendRange(icc_data.begin(), icc_data.end());
auto profile = ColorProfile::Create(base::span(icc_copy));
if (profile) {
SetEmbeddedColorProfile(std::move(profile));
}
}
}
have_metadata_ = true;
// For animations, reserve space for first frame info.
// The actual frame info will be filled in when we get the Frame event.
if (basic_info_.have_animation && frame_info_.empty()) {
frame_info_.resize(1);
}
// In only_size mode, we must continue processing to discover all
// frames, so we don't return here, just break.
break;
}
case JxlRsStatus::Frame: {
JxlRsFrameHeader header = (*decoder_)->get_frame_header();
if (basic_info_.have_animation) {
// Frame duration is already in milliseconds from jxl-rs.
FrameInfo info;
info.header = header;
info.duration = base::Milliseconds(header.duration);
info.received = false;
// Determine frame index based on mode.
wtf_size_t frame_idx =
only_size ? num_frame_events_in_scan_ : num_decoded_frames_;
if (frame_idx < frame_info_.size()) {
// Update existing entry (might be from a previous scan).
frame_info_[frame_idx] = info;
} else {
// Add new frame info.
frame_info_.push_back(info);
DVLOG(1) << "JXL discovered frame " << frame_idx
<< " (total: " << frame_info_.size() << ")"
<< " only_size=" << only_size;
}
}
break;
}
case JxlRsStatus::FullImage: {
if (only_size) {
// In metadata scan mode, we don't decode pixels, just update the
// frame count and continue scanning for more frames.
num_frame_events_in_scan_++;
if (!(*decoder_)->has_more_frames()) {
input_offset_ += result.bytes_consumed;
all_frames_discovered_ = true;
return; // End of metadata scan.
}
// Note: Don't advance input_offset_ here - it will be advanced
// after the switch when we continue scanning.
break; // Continue scanning.
}
// Full decode logic.
wtf_size_t frame_index = num_decoded_frames_;
// Ensure frame buffer cache is large enough.
if (frame_buffer_cache_.size() <= frame_index) {
frame_buffer_cache_.resize(frame_index + 1);
}
if (!InitFrameBuffer(frame_index)) {
SetFailed();
return;
}
ImageFrame& frame = frame_buffer_cache_[frame_index];
frame.SetHasAlpha(basic_info_.has_alpha);
base::CheckedNumeric<size_t> checked_pixel_count =
base::CheckMul(basic_info_.width, basic_info_.height);
if (!checked_pixel_count.IsValid()) {
SetFailed();
return;
}
const size_t pixel_count = checked_pixel_count.ValueOrDie();
bool premultiply = frame.PremultiplyAlpha() && frame.HasAlpha();
if (decode_to_half_float_) {
// Native F16 path for wide gamut/HDR.
// jxl-rs outputs F16 directly, 4 channels * 2 bytes = 8 bytes/pixel.
base::CheckedNumeric<size_t> checked_size =
base::CheckMul(pixel_count, 4, sizeof(uint16_t));
if (!checked_size.IsValid()) {
SetFailed();
return;
}
size_t f16_pixel_size = checked_size.ValueOrDie();
if (pixel_buffer_.size() < f16_pixel_size) {
pixel_buffer_.resize(f16_pixel_size);
}
// Get F16 pixels directly from decoder.
auto pixel_span =
rust::Slice<uint8_t>(pixel_buffer_.data(), f16_pixel_size);
if ((*decoder_)->get_pixels(pixel_span) != JxlRsStatus::Success) {
SetFailed();
return;
}
// Copy F16 pixels to frame buffer.
// Use row-based iteration to avoid per-pixel division.
base::span<const uint8_t> buffer_bytes(pixel_buffer_);
const uint32_t width = basic_info_.width;
const uint32_t height = basic_info_.height;
for (uint32_t y = 0; y < height; ++y) {
for (uint32_t x = 0; x < width; ++x) {
// Calculate byte offset for this pixel (row-major layout).
// Each pixel is 4 F16 values = 8 bytes.
size_t byte_offset = (static_cast<size_t>(y) * width + x) * 8;
auto pixel_bytes = buffer_bytes.subspan(byte_offset, 8u);
// Read F16 values (jxl-rs outputs in native endianness).
uint16_t r = base::U16FromNativeEndian(pixel_bytes.subspan<0, 2>());
uint16_t g = base::U16FromNativeEndian(pixel_bytes.subspan<2, 2>());
uint16_t b = base::U16FromNativeEndian(pixel_bytes.subspan<4, 2>());
uint16_t a = base::U16FromNativeEndian(pixel_bytes.subspan<6, 2>());
// TODO(nicholassig): Premultiply in F16 if needed.
// For now, premultiplication is not applied to F16 output.
// This matches the behavior of other HDR decoders.
(void)premultiply;
uint64_t* dst = frame.GetAddrF16(x, y);
*dst = (static_cast<uint64_t>(a) << 48) |
(static_cast<uint64_t>(b) << 32) |
(static_cast<uint64_t>(g) << 16) |
static_cast<uint64_t>(r);
}
}
} else {
// U8 path for standard 8-bit images.
base::CheckedNumeric<size_t> checked_size =
base::CheckMul(pixel_count, 4);
if (!checked_size.IsValid()) {
SetFailed();
return;
}
size_t pixel_size = checked_size.ValueOrDie();
if (pixel_buffer_.size() < pixel_size) {
pixel_buffer_.resize(pixel_size);
}
// Get U8 pixels from decoder.
auto pixel_span =
rust::Slice<uint8_t>(pixel_buffer_.data(), pixel_size);
if ((*decoder_)->get_pixels(pixel_span) != JxlRsStatus::Success) {
SetFailed();
return;
}
// Use row-based iteration to avoid per-pixel division.
base::span<const uint8_t> src_bytes(pixel_buffer_);
const uint32_t width = basic_info_.width;
const uint32_t height = basic_info_.height;
const size_t row_stride = width * 4;
if (premultiply) {
for (uint32_t y = 0; y < height; ++y) {
auto row = src_bytes.subspan(y * row_stride, row_stride);
for (uint32_t x = 0; x < width; ++x) {
auto pixel = row.subspan(x * 4, 4u);
uint8_t r = pixel[0];
uint8_t g = pixel[1];
uint8_t b = pixel[2];
uint8_t a = pixel[3];
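// Rounded premultiplication, (x * a + 127) / 255, computed exactly
// with shifts: t = x * a + 128, result = (t + (t >> 8)) >> 8.
// A plain (x * a + 128) >> 8 divides by 256 and maps 255 * 255 to 254.
const uint32_t tr = r * a + 128u;
const uint32_t tg = g * a + 128u;
const uint32_t tb = b * a + 128u;
r = static_cast<uint8_t>((tr + (tr >> 8)) >> 8);
g = static_cast<uint8_t>((tg + (tg >> 8)) >> 8);
b = static_cast<uint8_t>((tb + (tb >> 8)) >> 8);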
ImageFrame::PixelData* dst = frame.GetAddr(x, y);
*dst = (a << SK_A32_SHIFT) | (r << SK_R32_SHIFT) |
(g << SK_G32_SHIFT) | (b << SK_B32_SHIFT);
}
}
} else {
for (uint32_t y = 0; y < height; ++y) {
auto row = src_bytes.subspan(y * row_stride, row_stride);
for (uint32_t x = 0; x < width; ++x) {
auto pixel = row.subspan(x * 4, 4u);
ImageFrame::PixelData* dst = frame.GetAddr(x, y);
*dst = (pixel[3] << SK_A32_SHIFT) | (pixel[0] << SK_R32_SHIFT) |
(pixel[1] << SK_G32_SHIFT) | (pixel[2] << SK_B32_SHIFT);
}
}
}
}
frame.SetPixelsChanged(true);
frame.SetStatus(ImageFrame::kFrameComplete);
if (frame_index < frame_info_.size()) {
frame_info_[frame_index].received = true;
}
num_decoded_frames_++;
// Check if we've decoded the requested frame.
if (frame_index >= index) {
input_offset_ += result.bytes_consumed;
return;
}
// Check for more frames.
if (!(*decoder_)->has_more_frames()) {
all_frames_discovered_ = true;
}
break;
}
case JxlRsStatus::Success:
input_offset_ += result.bytes_consumed;
all_frames_discovered_ = true;
return;
default:
SetFailed();
return;
}
// Advance input offset after successful processing.
// (NeedMoreInput returns early above without advancing.)
input_offset_ += result.bytes_consumed;
}
}
bool JXLImageDecoder::CanReusePreviousFrameBuffer(
wtf_size_t frame_index) const {
DCHECK(frame_index < frame_buffer_cache_.size());
return true;
}
bool JXLImageDecoder::FrameIsReceivedAtIndex(wtf_size_t index) const {
return IsAllDataReceived() ||
(index < frame_buffer_cache_.size() &&
frame_buffer_cache_[index].GetStatus() == ImageFrame::kFrameComplete);
}
std::optional<base::TimeDelta> JXLImageDecoder::FrameTimestampAtIndex(
wtf_size_t index) const {
return index < frame_buffer_cache_.size()
? frame_buffer_cache_[index].Timestamp()
: std::nullopt;
}
base::TimeDelta JXLImageDecoder::FrameDurationAtIndex(wtf_size_t index) const {
return index < frame_buffer_cache_.size()
? frame_buffer_cache_[index].Duration()
: base::TimeDelta();
}
int JXLImageDecoder::RepetitionCount() const {
if (!basic_info_.have_animation) {
return kAnimationNone;
}
if (basic_info_.animation_loop_count == 0) {
return kAnimationLoopInfinite;
}
return basic_info_.animation_loop_count;
}
wtf_size_t JXLImageDecoder::ClearCacheExceptFrame(
wtf_size_t clear_except_frame) {
// For animated JXL images that have been fully decoded, keep ALL frames.
// JXL requires sequential decoding - evicting any frame means re-decoding
// the entire animation from frame 0, which causes jitter on pages with
// multiple animations. Trade-off: Uses more memory but ensures smooth
// playback.
if (basic_info_.have_animation && all_frames_decoded_) {
return 0; // Keep all frames cached - don't clear anything
}
// For animations still being decoded, keep current and previous frame
// to avoid flicker if the compositor briefly references the previous frame.
if (basic_info_.have_animation && clear_except_frame != kNotFound) {
const wtf_size_t previous_frame =
clear_except_frame ? clear_except_frame - 1 : kNotFound;
return ClearCacheExceptTwoFrames(clear_except_frame, previous_frame);
}
return ImageDecoder::ClearCacheExceptFrame(clear_except_frame);
}
SkColorType JXLImageDecoder::GetSkColorType() const {
if (is_high_bit_depth_ &&
high_bit_depth_decoding_option_ == kHighBitDepthToHalfFloat) {
return kRGBA_F16_SkColorType;
}
return kN32_SkColorType;
}
void JXLImageDecoder::DecodeAllFrames() {
if (all_frames_decoded_ || Failed()) {
return;
}
// Mark as decoded first to prevent re-entry.
all_frames_decoded_ = true;
wtf_size_t total_frames = frame_info_.size();
if (total_frames == 0) {
return;
}
// Decode each frame sequentially: 0, 1, 2, ...
// This is simpler and more reliable than trying to decode the last frame
// and relying on the decode loop to fill in all previous frames.
for (wtf_size_t i = 0; i < total_frames && !Failed(); ++i) {
// Skip if already decoded.
if (i < frame_buffer_cache_.size() &&
frame_buffer_cache_[i].GetStatus() == ImageFrame::kFrameComplete) {
continue;
}
Decode(i, /*only_size=*/false);
}
// Verify all frames are actually decoded.
if (!Failed()) {
for (wtf_size_t i = 0; i < total_frames && i < frame_buffer_cache_.size();
++i) {
if (frame_buffer_cache_[i].GetStatus() != ImageFrame::kFrameComplete) {
all_frames_decoded_ = false;
break;
}
}
}
}
} // namespace blink
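Premultiplying an 8-bit channel means computing round(x * a / 255). A shift-only (x * a + 128) >> 8 divides by 256 instead, so an opaque full-intensity channel (x = a = 255) comes out as 254. The sketch below compares it with the exact shift-based form (t + (t >> 8)) >> 8, where t = x * a + 128, against a plain integer-division reference; the function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

// Shift-only approximation: effectively divides by 256, not 255.
uint8_t PremultiplyApprox(uint8_t x, uint8_t a) {
  return static_cast<uint8_t>((x * a + 128) >> 8);
}

// Exact round(x * a / 255) for 8-bit inputs using only adds and shifts.
uint8_t PremultiplyExact(uint8_t x, uint8_t a) {
  const uint32_t t = uint32_t{x} * a + 128;
  return static_cast<uint8_t>((t + (t >> 8)) >> 8);
}

// Reference: rounded integer division by 255.
uint8_t PremultiplyReference(uint8_t x, uint8_t a) {
  return static_cast<uint8_t>((uint32_t{x} * a + 127) / 255);
}
```

The exact form agrees with the reference for all 65,536 input pairs, so it keeps the division-free inner loop without darkening opaque pixels.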

jxl_image_decoder.h

// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef THIRD_PARTY_BLINK_RENDERER_PLATFORM_IMAGE_DECODERS_JXL_JXL_IMAGE_DECODER_H_
#define THIRD_PARTY_BLINK_RENDERER_PLATFORM_IMAGE_DECODERS_JXL_JXL_IMAGE_DECODER_H_
#include <memory>
#include <optional>
#include <vector>
#include "third_party/blink/renderer/platform/image-decoders/image_decoder.h"
#include "third_party/blink/renderer/platform/wtf/text/atomic_string.h"
#include "third_party/blink/renderer/platform/wtf/vector.h"
#include "third_party/rust/jxl/v0_1/wrapper/lib.rs.h"
#include "third_party/skia/include/core/SkImageInfo.h"
#include "ui/gfx/geometry/point.h"
namespace blink {
class FastSharedBufferReader;
class PLATFORM_EXPORT JXLImageDecoder final : public ImageDecoder {
public:
JXLImageDecoder(AlphaOption,
HighBitDepthDecodingOption,
ColorBehavior,
cc::AuxImage,
wtf_size_t max_decoded_bytes,
AnimationOption);
JXLImageDecoder(const JXLImageDecoder&) = delete;
JXLImageDecoder& operator=(const JXLImageDecoder&) = delete;
~JXLImageDecoder() override;
// ImageDecoder:
String FilenameExtension() const override;
const AtomicString& MimeType() const override;
bool ImageIsHighBitDepth() override;
void OnSetData(scoped_refptr<SegmentReader> data) override;
int RepetitionCount() const override;
bool FrameIsReceivedAtIndex(wtf_size_t) const override;
std::optional<base::TimeDelta> FrameTimestampAtIndex(
wtf_size_t) const override;
base::TimeDelta FrameDurationAtIndex(wtf_size_t) const override;
wtf_size_t ClearCacheExceptFrame(wtf_size_t) override;
// Returns true if the data in fast_reader begins with a valid JXL signature.
static bool MatchesJXLSignature(const FastSharedBufferReader& fast_reader);
private:
// CXX-managed Rust Box for JxlRsDecoder.
using JxlRsDecoderPtr = rust::Box<JxlRsDecoder>;
// Frame information tracked during decoding.
struct FrameInfo {
JxlRsFrameHeader header;
base::TimeDelta duration;
bool received = false;
};
// ImageDecoder:
void DecodeSize() override;
wtf_size_t DecodeFrameCount() override;
void InitializeNewFrame(wtf_size_t) override;
void Decode(wtf_size_t) override;
bool CanReusePreviousFrameBuffer(wtf_size_t) const override;
// Internal decode function that optionally stops after metadata.
void Decode(wtf_size_t index, bool only_size);
// Eagerly decode all animation frames upfront.
void DecodeAllFrames();
// Converts JXL pixel format to Skia color type.
SkColorType GetSkColorType() const;
// Decoder state.
std::optional<JxlRsDecoderPtr> decoder_;
JxlRsBasicInfo basic_info_{};
bool have_metadata_ = false;
wtf_size_t num_decoded_frames_ = 0;
size_t num_frame_events_in_scan_ = 0;
bool all_frames_discovered_ = false;
bool all_frames_decoded_ = false; // True after all animation frames decoded.
size_t input_offset_ = 0; // Current position in input stream.
// Animation frame tracking.
Vector<FrameInfo> frame_info_;
// Color management.
bool is_high_bit_depth_ = false;
bool decode_to_half_float_ = false;
// Pixel buffer for decoded frame.
Vector<uint8_t> pixel_buffer_;
};
} // namespace blink
#endif // THIRD_PARTY_BLINK_RENDERER_PLATFORM_IMAGE_DECODERS_JXL_JXL_IMAGE_DECODER_H_
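FrameInfo keeps one duration per frame, and InitializeNewFrame() turns these into a timestamp by re-summing the durations of every earlier frame, which adds up to quadratic work over a long animation. One possible alternative is a single running sum over the frame table. The sketch below uses std::chrono::milliseconds as a stand-in for base::TimeDelta; FrameTimestamps() is a hypothetical helper.

```cpp
#include <cassert>
#include <chrono>
#include <vector>

using Millis = std::chrono::milliseconds;

// Timestamp of frame i = sum of the durations of frames 0..i-1, computed
// with a running total so the whole table costs O(n) instead of O(n^2).
std::vector<Millis> FrameTimestamps(const std::vector<Millis>& durations) {
  std::vector<Millis> timestamps;
  timestamps.reserve(durations.size());
  Millis running{0};
  for (Millis d : durations) {
    timestamps.push_back(running);  // Start time of this frame.
    running += d;                   // Start time of the next frame.
  }
  return timestamps;
}
```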

wrapper_lib.rs

//! CXX-based FFI wrapper for jxl-rs decoder.
//!
//! This provides a C++-compatible API for the jxl-rs decoder using the CXX crate,
//! designed for integration with Chromium's Blink image decoder infrastructure.
//!
//! Uses JxlDecoderInner directly (type-erased API) for simpler implementation.
use jxl::api::{
JxlBasicInfo, JxlDecoderInner, JxlDecoderOptions, JxlOutputBuffer, JxlPixelFormat,
JxlProgressiveMode, ProcessingResult, check_signature,
};
use jxl::headers::extra_channels::ExtraChannel;
#[cxx::bridge]
mod ffi {
#[derive(Debug)]
enum JxlRsStatus {
Success = 0,
Error = 1,
NeedMoreInput = 2,
BasicInfo = 3,
Frame = 5,
FullImage = 6,
}
#[derive(Debug)]
enum JxlRsPixelFormat {
Rgba8 = 0,
Rgba16 = 1,
RgbaF16 = 2,
RgbaF32 = 3,
}
#[derive(Debug, Clone)]
struct JxlRsBasicInfo {
width: u32,
height: u32,
bits_per_sample: u32,
num_extra_channels: u32,
has_alpha: bool,
alpha_premultiplied: bool,
have_animation: bool,
animation_loop_count: u32,
animation_tps_numerator: u32,
animation_tps_denominator: u32,
uses_original_profile: bool,
orientation: u32,
}
#[derive(Debug, Clone)]
struct JxlRsFrameHeader {
duration: u32,
is_last: bool,
name_length: u32,
}
/// Result of a process call, indicating bytes consumed and next status.
#[derive(Debug, Clone)]
struct JxlRsProcessResult {
status: JxlRsStatus,
bytes_consumed: usize,
}
extern "Rust" {
type JxlRsDecoder;
fn jxl_rs_decoder_create() -> Box<JxlRsDecoder>;
fn reset(self: &mut JxlRsDecoder);
/// Rewind for animation loop replay, preserving pixel format.
fn rewind(self: &mut JxlRsDecoder);
/// Process input data. Returns status and number of bytes consumed.
/// The caller should advance their input position by bytes_consumed.
fn process(self: &mut JxlRsDecoder, data: &[u8], all_input: bool) -> JxlRsProcessResult;
fn get_basic_info(self: &JxlRsDecoder) -> JxlRsBasicInfo;
fn get_frame_header(self: &JxlRsDecoder) -> JxlRsFrameHeader;
fn set_pixel_format(self: &mut JxlRsDecoder, format: JxlRsPixelFormat);
fn get_pixels(self: &mut JxlRsDecoder, buffer: &mut [u8]) -> JxlRsStatus;
fn get_icc_profile(self: &JxlRsDecoder) -> &[u8];
fn has_more_frames(self: &JxlRsDecoder) -> bool;
fn get_error(self: &JxlRsDecoder) -> &str;
fn jxl_rs_signature_check(data: &[u8]) -> bool;
fn jxl_rs_version() -> &'static str;
}
}
pub use ffi::{JxlRsBasicInfo, JxlRsFrameHeader, JxlRsPixelFormat, JxlRsProcessResult, JxlRsStatus};
// =============================================================================
// Decoder using JxlDecoderInner (type-erased API)
// =============================================================================
/// Tracks what stage the decoder has reached
#[derive(Debug, Clone, Copy, PartialEq)]
enum Stage {
/// Initial state, waiting for basic info
Initial,
/// Have basic info, waiting for frame header
HaveBasicInfo,
/// Have frame header, ready to decode pixels
HaveFrameHeader,
}
pub struct JxlRsDecoder {
decoder: JxlDecoderInner,
stage: Stage,
basic_info: JxlRsBasicInfo,
frame_header: JxlRsFrameHeader,
pixel_format: JxlRsPixelFormat,
pixel_format_set: bool,
icc_profile: Vec<u8>,
error_message: String,
pixel_buffer: Vec<u8>,
}
fn default_options() -> JxlDecoderOptions {
let mut opts = JxlDecoderOptions::default();
opts.xyb_output_linear = false;
opts.progressive_mode = JxlProgressiveMode::FullFrame;
opts
}
fn to_internal_pixel_format(format: JxlRsPixelFormat, num_extra_channels: usize) -> JxlPixelFormat {
// Create pixel format with extra channels set to None (ignore).
// This tells jxl-rs to output alpha as part of RGBA (JxlColorType::Rgba)
// rather than as separate extra channel buffers.
// We still need to specify the correct count so the internal assertion passes.
let base = match format {
JxlRsPixelFormat::Rgba8 => JxlPixelFormat::rgba8(0),
JxlRsPixelFormat::Rgba16 => JxlPixelFormat::rgba16(0),
JxlRsPixelFormat::RgbaF16 => JxlPixelFormat::rgba_f16(0),
JxlRsPixelFormat::RgbaF32 => JxlPixelFormat::rgba_f32(0),
_ => JxlPixelFormat::rgba8(0),
};
JxlPixelFormat {
color_type: base.color_type,
color_data_format: base.color_data_format,
// Set extra channels to None - we want alpha embedded in RGBA, not separate
extra_channel_format: vec![None; num_extra_channels],
}
}
fn jxl_rs_decoder_create() -> Box<JxlRsDecoder> {
Box::new(JxlRsDecoder {
decoder: JxlDecoderInner::new(default_options()),
stage: Stage::Initial,
basic_info: JxlRsBasicInfo::default(),
frame_header: JxlRsFrameHeader::default(),
pixel_format: JxlRsPixelFormat::Rgba8,
pixel_format_set: false,
icc_profile: Vec::new(),
error_message: String::new(),
pixel_buffer: Vec::new(),
})
}
impl JxlRsDecoder {
fn reset(&mut self) {
self.decoder.reset();
self.stage = Stage::Initial;
self.basic_info = JxlRsBasicInfo::default();
self.frame_header = JxlRsFrameHeader::default();
self.pixel_format_set = false;
self.icc_profile.clear();
self.error_message.clear();
// Keep pixel_buffer capacity for reuse
}
/// Rewind for animation loop replay, preserving pixel format.
fn rewind(&mut self) {
self.decoder.rewind();
self.stage = Stage::Initial;
// Keep basic_info - it won't change
self.frame_header = JxlRsFrameHeader::default();
// pixel_format_set is left unchanged here; process() only calls
// decoder.set_pixel_format() while it is false.
// Keep icc_profile, pixel_buffer for reuse
self.error_message.clear();
}

    /// Feed `data` to the decoder and advance at most one stage, reporting
    /// the resulting status and how many bytes were consumed.
    fn process(&mut self, data: &[u8], all_input: bool) -> JxlRsProcessResult {
        let mut input = data;
        let input_len_before = input.len();

        // Set pixel format if we have basic info but haven't set format yet
        if self.stage != Stage::Initial && !self.pixel_format_set {
            let num_extra = self.basic_info.num_extra_channels as usize;
            let pixel_format = to_internal_pixel_format(self.pixel_format, num_extra);
            self.decoder.set_pixel_format(pixel_format);
            self.pixel_format_set = true;
        }

        // Output buffers are only needed when decoding frame pixels
        let needs_pixels = self.stage == Stage::HaveFrameHeader;

        let status = if needs_pixels {
            // Decode frame pixels into an RGBA buffer of width x height
            let width = self.basic_info.width as usize;
            let height = self.basic_info.height as usize;
            let bytes_per_pixel = match self.pixel_format {
                JxlRsPixelFormat::Rgba8 => 4,    // 4 channels x 1 byte
                JxlRsPixelFormat::Rgba16 => 8,   // 4 channels x 2 bytes
                JxlRsPixelFormat::RgbaF16 => 8,  // 4 channels x 2 bytes
                JxlRsPixelFormat::RgbaF32 => 16, // 4 channels x 4 bytes
                // Fallback for values outside the known variants
                _ => 4,
            };
            // Checked arithmetic: width * height * bytes_per_pixel can
            // overflow usize on 32-bit targets.
            let Some((bytes_per_row, buffer_size)) = width
                .checked_mul(bytes_per_pixel)
                .and_then(|row| row.checked_mul(height).map(|size| (row, size)))
            else {
                self.error_message = "Image dimensions too large".to_string();
                return JxlRsProcessResult {
                    status: JxlRsStatus::Error,
                    bytes_consumed: 0,
                };
            };
            self.pixel_buffer.resize(buffer_size, 0);
            let output = JxlOutputBuffer::new(&mut self.pixel_buffer, height, bytes_per_row);
            match self.decoder.process(&mut input, Some(&mut [output])) {
                Ok(ProcessingResult::Complete { .. }) => {
                    // Frame decode complete, go back to HaveBasicInfo for next frame
                    self.stage = Stage::HaveBasicInfo;
                    JxlRsStatus::FullImage
                }
                Ok(ProcessingResult::NeedsMoreInput { .. }) => {
                    if all_input {
                        self.error_message = "Incomplete frame data".to_string();
                        JxlRsStatus::Error
                    } else {
                        JxlRsStatus::NeedMoreInput
                    }
                }
                Err(e) => {
                    self.error_message = format!("Frame decode error: {}", e);
                    JxlRsStatus::Error
                }
            }
        } else {
            // Process without output buffers (parsing headers)
            match self.decoder.process(&mut input, None) {
                Ok(ProcessingResult::Complete { .. }) => {
                    // Check what we got
                    match self.stage {
                        Stage::Initial => {
                            // Should have basic info now
                            if let Some(info) = self.decoder.basic_info() {
                                self.basic_info = JxlRsBasicInfo::from(info);
                                // Extract ICC profile
                                if let Some(color_profile) = self.decoder.embedded_color_profile() {
                                    let icc = color_profile.as_icc();
                                    if !icc.is_empty() {
                                        self.icc_profile = icc.into_owned();
                                    }
                                }
                                self.stage = Stage::HaveBasicInfo;
                                JxlRsStatus::BasicInfo
                            } else {
                                self.error_message = "No basic info after process".to_string();
                                JxlRsStatus::Error
                            }
                        }
                        Stage::HaveBasicInfo => {
                            // Should have frame header now
                            if let Some(fh) = self.decoder.frame_header() {
                                self.frame_header.duration = fh.duration.map(|d| d as u32).unwrap_or(0);
                                // is_last is not populated here; callers can query has_more_frames()
                                self.frame_header.is_last = false;
                                self.frame_header.name_length = fh.name.len() as u32;
                                self.stage = Stage::HaveFrameHeader;
                                JxlRsStatus::Frame
                            } else {
                                // No more frames
                                JxlRsStatus::Success
                            }
                        }
                        Stage::HaveFrameHeader => {
                            // Unreachable: this stage is handled by the needs_pixels branch
                            JxlRsStatus::Success
                        }
                    }
                }
                Ok(ProcessingResult::NeedsMoreInput { .. }) => {
                    if all_input {
                        self.error_message = "Incomplete JXL data".to_string();
                        JxlRsStatus::Error
                    } else {
                        JxlRsStatus::NeedMoreInput
                    }
                }
                Err(e) => {
                    self.error_message = format!("Decoder error: {}", e);
                    JxlRsStatus::Error
                }
            }
        };

        JxlRsProcessResult {
            status,
            bytes_consumed: input_len_before - input.len(),
        }
    }

    fn get_basic_info(&self) -> JxlRsBasicInfo {
        self.basic_info.clone()
    }

    fn get_frame_header(&self) -> JxlRsFrameHeader {
        self.frame_header.clone()
    }

    fn set_pixel_format(&mut self, format: JxlRsPixelFormat) {
        self.pixel_format = format;
        self.pixel_format_set = false; // Will be applied on next process()
    }

    /// Copy the decoder's pixel buffer (the latest frame output) into `buffer`.
    fn get_pixels(&mut self, buffer: &mut [u8]) -> JxlRsStatus {
        if self.pixel_buffer.is_empty() {
            self.error_message = "No decoded image available".to_string();
            return JxlRsStatus::Error;
        }
        if buffer.len() < self.pixel_buffer.len() {
            self.error_message = "Buffer too small".to_string();
            return JxlRsStatus::Error;
        }
        buffer[..self.pixel_buffer.len()].copy_from_slice(&self.pixel_buffer);
        JxlRsStatus::Success
    }

    fn get_icc_profile(&self) -> &[u8] {
        &self.icc_profile
    }

    fn has_more_frames(&self) -> bool {
        self.decoder.has_more_frames()
    }

    fn get_error(&self) -> &str {
        &self.error_message
    }
}
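
/// Returns true if `data` starts with a JPEG XL signature: either the bare
/// codestream signature (2 bytes) or the container signature box (12 bytes),
/// which is why at most the first 12 bytes are passed to check_signature().
fn jxl_rs_signature_check(data: &[u8]) -> bool {
    data.len() >= 2
        && matches!(
            check_signature(&data[..data.len().min(12)]),
            ProcessingResult::Complete { result: Some(_) }
        )
}

fn jxl_rs_version() -> &'static str {
    "jxl-rs 0.1"
}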
// =============================================================================
// Default Implementations
// =============================================================================
impl Default for JxlRsBasicInfo {
    fn default() -> Self {
        Self {
            width: 0,
            height: 0,
            bits_per_sample: 8,
            num_extra_channels: 0,
            has_alpha: false,
            alpha_premultiplied: false,
            have_animation: false,
            animation_loop_count: 0,
            animation_tps_numerator: 1,
            animation_tps_denominator: 1000,
            uses_original_profile: false,
            orientation: 1,
        }
    }
}

impl From<&JxlBasicInfo> for JxlRsBasicInfo {
    fn from(info: &JxlBasicInfo) -> Self {
        let has_alpha = info
            .extra_channels
            .iter()
            .any(|ec| matches!(ec.ec_type, ExtraChannel::Alpha));
        let (animation_loop_count, animation_tps_numerator, animation_tps_denominator) =
            match &info.animation {
                Some(anim) => (anim.num_loops, anim.tps_numerator, anim.tps_denominator),
                None => (0, 1, 1000),
            };
        Self {
            width: info.size.0 as u32,
            height: info.size.1 as u32,
            bits_per_sample: info.bit_depth.bits_per_sample(),
            num_extra_channels: info.extra_channels.len() as u32,
            has_alpha,
            alpha_premultiplied: false,
            have_animation: info.animation.is_some(),
            animation_loop_count,
            animation_tps_numerator,
            animation_tps_denominator,
            uses_original_profile: info.uses_original_profile,
            orientation: info.orientation as u32,
        }
    }
}

impl Default for JxlRsFrameHeader {
    fn default() -> Self {
        Self {
            duration: 0,
            is_last: false,
            name_length: 0,
        }
    }
}
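
The header/pixel cycle that `process()` in wrapper_lib.rs drives can be modeled in isolation. The sketch below is a simplified, hypothetical model: `Stage` mirrors the wrapper's enum, while `Status`, `next_stage` and `simulate` are illustrative names that are not part of jxl-rs or the wrapper, and `frame_available` stands in for `decoder.frame_header()` returning `Some`. It shows why each completed frame returns the decoder to `HaveBasicInfo`: that loop is what makes JXL animation decoding strictly sequential.

```rust
/// Mirrors the wrapper's three decoding stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Stage {
    Initial,
    HaveBasicInfo,
    HaveFrameHeader,
}

/// Hypothetical subset of statuses reported by a completed process() call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Status {
    BasicInfo,
    Frame,
    FullImage,
    Success, // no more frames
}

/// Next stage and reported status for a call that completes in `stage`.
/// `frame_available` models frame_header() returning Some in HaveBasicInfo.
fn next_stage(stage: Stage, frame_available: bool) -> (Stage, Status) {
    match stage {
        Stage::Initial => (Stage::HaveBasicInfo, Status::BasicInfo),
        Stage::HaveBasicInfo if frame_available => (Stage::HaveFrameHeader, Status::Frame),
        Stage::HaveBasicInfo => (Stage::HaveBasicInfo, Status::Success),
        // Pixels decoded: go back and wait for the next frame header.
        Stage::HaveFrameHeader => (Stage::HaveBasicInfo, Status::FullImage),
    }
}

/// Drives the model for an image with `num_frames` frames and returns the
/// statuses a caller would observe, in order.
fn simulate(num_frames: usize) -> Vec<Status> {
    let mut stage = Stage::Initial;
    let mut frames_left = num_frames;
    let mut statuses = Vec::new();
    loop {
        let (next, status) = next_stage(stage, frames_left > 0);
        if status == Status::FullImage {
            // Only reachable after a Frame status, so frames_left >= 1 here.
            frames_left -= 1;
        }
        statuses.push(status);
        if status == Status::Success {
            break;
        }
        stage = next;
    }
    statuses
}
```

For a two-frame animation, `simulate(2)` yields `BasicInfo, Frame, FullImage, Frame, FullImage, Success`: frame N+1 cannot be reached without passing through frame N, which is the constraint that motivates decoding all frames up front.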
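
`jxl_rs_signature_check()` delegates to jxl-rs's `check_signature()`. For illustration only, the byte comparison it implies can be written without the library, using the two JPEG XL signatures: the bare codestream signature `FF 0A` and the 12-byte container signature box (size `0x0000000C`, type `JXL `, payload `0D 0A 87 0A`). `sniff_jxl` and `SignatureKind` are hypothetical names; this is a sketch, not a replacement for the library call.

```rust
/// Bare JPEG XL codestream signature.
const CODESTREAM_SIG: [u8; 2] = [0xFF, 0x0A];

/// JPEG XL container signature box: size 12, type "JXL ", then 0D 0A 87 0A.
const CONTAINER_SIG: [u8; 12] = [
    0x00, 0x00, 0x00, 0x0C, b'J', b'X', b'L', b' ', 0x0D, 0x0A, 0x87, 0x0A,
];

#[derive(Debug, PartialEq, Eq)]
enum SignatureKind {
    Codestream,
    Container,
}

/// Hypothetical helper: classifies `data` by its leading bytes, or returns
/// None if it matches neither signature (including too-short input).
fn sniff_jxl(data: &[u8]) -> Option<SignatureKind> {
    if data.starts_with(&CODESTREAM_SIG) {
        Some(SignatureKind::Codestream)
    } else if data.starts_with(&CONTAINER_SIG) {
        Some(SignatureKind::Container)
    } else {
        None
    }
}
```

A partial container prefix (fewer than 12 bytes) returns `None` in this sketch; the wrapper instead passes whatever is available, up to 12 bytes, and leaves that case to `check_signature()`.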
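
`JxlRsBasicInfo` passes `animation_tps_numerator` and `animation_tps_denominator` through to Blink. In the JPEG XL animation header, ticks per second equal `tps_numerator / tps_denominator` and frame durations are counted in ticks, so a tick count converts to milliseconds as `ticks * 1000 * tps_denominator / tps_numerator`. The sketch below shows that arithmetic. It assumes the input is a duration in ticks; whether the `duration` the wrapper reads from jxl-rs is already in milliseconds is not established here, and `ticks_to_ms` is a hypothetical helper.

```rust
/// Hypothetical helper: converts a frame duration in animation ticks to whole
/// milliseconds (rounding down). Returns None for a zero numerator or if the
/// result does not fit in u64.
fn ticks_to_ms(ticks: u32, tps_numerator: u32, tps_denominator: u32) -> Option<u64> {
    if tps_numerator == 0 {
        return None;
    }
    // Widen to u128 so ticks * 1000 * denominator cannot overflow.
    let ms = u128::from(ticks) * 1000 * u128::from(tps_denominator) / u128::from(tps_numerator);
    u64::try_from(ms).ok()
}
```

For example, an animation declared at 30 ticks per second (`30/1`) with a one-tick frame gives `ticks_to_ms(1, 30, 1) == Some(33)`, and at `1000/1` ticks per second a tick is exactly one millisecond.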