From 7c6e0cf3b6e9dd4816b569d7216e65a78a1114c1 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Thu, 3 Apr 2025 15:36:11 -0700
Subject: [PATCH 1/9] Add enable2pc to producer config

---
 .../clients/producer/ProducerConfig.java      | 21 +++++++++++++++++
 .../clients/producer/ProducerConfigTest.java  | 23 +++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
index 949c6c167ba8e..4cd21434d3658 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
@@ -355,6 +355,11 @@ public class ProducerConfig extends AbstractConfig {
             "By default the TransactionId is not configured, which means transactions cannot be used. " +
             "Note that, by default, transactions require a cluster of at least three brokers which is the recommended setting for production; for development you can change this, by adjusting broker setting <code>transaction.state.log.replication.factor</code>.";
 
+    /** <code> transaction.two.phase.commit.enable </code> */
+    public static final String TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG = "transaction.two.phase.commit.enable";
+    private static final String TRANSACTION_TWO_PHASE_COMMIT_ENABLE_DOC = "If set to true, then the broker is informed that the client is participating in " +
+        "two phase commit protocol and transactions that this client starts never expire.";
+
     /**
      * <code>security.providers</code>
      */
@@ -526,6 +531,11 @@ public class ProducerConfig extends AbstractConfig {
                                         new ConfigDef.NonEmptyString(),
                                         Importance.LOW,
                                         TRANSACTIONAL_ID_DOC)
+                                .define(TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG,
+                                        Type.BOOLEAN,
+                                        false,
+                                        Importance.LOW,
+                                        TRANSACTION_TWO_PHASE_COMMIT_ENABLE_DOC)
                                 .define(CommonClientConfigs.METADATA_RECOVERY_STRATEGY_CONFIG,
                                         Type.STRING,
                                         CommonClientConfigs.DEFAULT_METADATA_RECOVERY_STRATEGY,
@@ -609,6 +619,17 @@ private void postProcessAndValidateIdempotenceConfigs(final Map<String, Object>
         if (!idempotenceEnabled && userConfiguredTransactions) {
             throw new ConfigException("Cannot set a " + ProducerConfig.TRANSACTIONAL_ID_CONFIG + " without also enabling idempotence.");
         }
+
+        // validate that transaction.timeout.ms is not set when transaction.two.phase.commit.enable is true
+        boolean enable2PC = this.getBoolean(TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG);
+        boolean userConfiguredTransactionTimeout = originalConfigs.containsKey(TRANSACTION_TIMEOUT_CONFIG);
+        if (enable2PC && userConfiguredTransactionTimeout) {
+            throw new ConfigException(
+                "Cannot set " + ProducerConfig.TRANSACTION_TIMEOUT_CONFIG +
+                " when " + ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG +
+                " is set to true. Transactions will not expire with two-phase commit enabled."
+            );
+        }
     }
 
     private static String parseAcks(String acksString) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java
index 830711c0e5449..207bac6476fc1 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java
@@ -145,4 +145,27 @@ void testUpperboundCheckOfEnableIdempotence() {
         configs.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "5");
         assertDoesNotThrow(() -> new ProducerConfig(configs));
     }
+
+    @Test
+    void testTwoPhaseCommitIncompatibleWithTransactionTimeout() {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, keySerializerClass);
+        configs.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, valueSerializerClass);
+        configs.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, true);
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "test-txn-id");
+        configs.put(ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG, true);
+        configs.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, 60000);
+        
+        ConfigException ce = assertThrows(ConfigException.class, () -> new ProducerConfig(configs));
+        assertTrue(ce.getMessage().contains(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG));
+        assertTrue(ce.getMessage().contains(ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG));
+        
+        // Verify that setting one but not the other is valid
+        configs.remove(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG);
+        assertDoesNotThrow(() -> new ProducerConfig(configs));
+        
+        configs.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, 60000);
+        configs.put(ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG, false);
+        assertDoesNotThrow(() -> new ProducerConfig(configs));
+    }
 }

From 5af2d8ce11e1704903fb05092ef189445c4d438a Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Tue, 8 Apr 2025 15:15:11 -0700
Subject: [PATCH 2/9] Add enable2pc to txn manager and add overloaded method
 initProducerId(bool)

---
 .../kafka/clients/producer/KafkaProducer.java | 49 +++++++++++++++++--
 .../kafka/clients/producer/MockProducer.java  | 17 +++++++
 .../kafka/clients/producer/Producer.java      |  5 ++
 .../internals/TransactionManager.java         | 26 ++++++++--
 .../clients/producer/KafkaProducerTest.java   |  2 +-
 .../producer/internals/SenderTest.java        | 36 +++++++-------
 .../internals/TransactionManagerTest.java     |  4 +-
 7 files changed, 111 insertions(+), 28 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index aed5d75d70115..3003f3b3e5b1b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -25,7 +25,6 @@
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
-import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.clients.producer.internals.BufferPool;
 import org.apache.kafka.clients.producer.internals.BuiltInPartitioner;
 import org.apache.kafka.clients.producer.internals.KafkaProducerMetrics;
@@ -598,14 +597,17 @@ private TransactionManager configureTransactionState(ProducerConfig config,
 
         if (config.getBoolean(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG)) {
             final String transactionalId = config.getString(ProducerConfig.TRANSACTIONAL_ID_CONFIG);
+            final boolean enable2PC = config.getBoolean(ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG);
             final int transactionTimeoutMs = config.getInt(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG);
             final long retryBackoffMs = config.getLong(ProducerConfig.RETRY_BACKOFF_MS_CONFIG);
+            
             transactionManager = new TransactionManager(
                 logContext,
                 transactionalId,
                 transactionTimeoutMs,
                 retryBackoffMs,
-                apiVersions
+                apiVersions,
+                enable2PC
             );
 
             if (transactionManager.isTransactional())
@@ -656,6 +658,47 @@ public void initTransactions() {
         transactionManager.maybeUpdateTransactionV2Enabled(true);
     }
 
+    /**
+     * Performs initialization of transactions functionality in this producer instance. This method bootstraps
+     * the producer with a {@code producerId} and also resets the internal state of the producer following a previous
+     * fatal error. Additionally, it allows setting the {@code keepPreparedTxn} flag which, if set to true, puts the producer
+     * into a restricted state that only allows transaction completion operations.
+     * 
+     * <p>
+     * When {@code keepPreparedTxn} is set to {@code true}, the producer will be able to complete in-flight prepared
+     * transactions, but will only allow calling {@link #commitTransaction()}, {@link #abortTransaction()}, or
+     * the to-be-added {@code completeTransaction()} methods. This is to support recovery of prepared transactions 
+     * after a producer restart.
+     *
+     * <p>
+     * Note that this method should only be called once during the lifetime of a producer instance, and must be
+     * called before any other methods which require a {@code transactionalId} to be specified.
+     *
+     * @param keepPreparedTxn whether to keep prepared transactions, restricting the producer to only support completion of
+     *                        prepared transactions. When set to true, the producer will only allow transaction completion
+     *                        operations after initialization.
+     *
+     * @throws IllegalStateException if no {@code transactional.id} has been configured for the producer
+     * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating that the broker
+     *         does not support transactions (i.e. if its version is lower than 0.11.0.0). If this is encountered,
+     *         the producer cannot be used for transactional messaging.
+     * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured
+     *         {@code transactional.id} is not authorized. If this is encountered, the producer cannot be used for
+     *         transactional messaging.
+     * @throws KafkaException if the producer has encountered a previous fatal error or for any other unexpected error
+     * @see #initTransactions()
+     */
+    public void initTransactions(boolean keepPreparedTxn) {
+        throwIfNoTransactionManager();
+        throwIfProducerClosed();
+        long now = time.nanoseconds();
+        TransactionalRequestResult result = transactionManager.initializeTransactions(keepPreparedTxn);
+        sender.wakeup();
+        result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS);
+        producerMetrics.recordInit(time.nanoseconds() - now);
+        transactionManager.maybeUpdateTransactionV2Enabled(true);
+    }
+
     /**
      * Should be called before the start of each new transaction. Note that prior to the first invocation
      * of this method, you must invoke {@link #initTransactions()} exactly one time.
@@ -703,7 +746,7 @@ public void beginTransaction() throws ProducerFencedException {
      * <p>
      * Note, that the consumer should have {@code enable.auto.commit=false} and should
      * also not commit offsets manually (via {@link KafkaConsumer#commitSync(Map) sync} or
-     * {@link KafkaConsumer#commitAsync(Map, OffsetCommitCallback) async} commits).
+     * {@link KafkaConsumer#commitAsync()} (Map, OffsetCommitCallback) async} commits).
      * This method will raise {@link TimeoutException} if the producer cannot send offsets before expiration of {@code max.block.ms}.
      * Additionally, it will raise {@link InterruptException} if interrupted.
      *
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
index a4aac86df09fc..bef55fc64ca21 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
@@ -158,6 +158,23 @@ public void initTransactions() {
         this.sentOffsets = false;
     }
 
+    @Override
+    public void initTransactions(boolean keepPreparedTxn) {
+        verifyNotClosed();
+        verifyNotFenced();
+        if (this.transactionInitialized) {
+            throw new IllegalStateException("MockProducer has already been initialized for transactions.");
+        }
+        if (this.initTransactionException != null) {
+            throw this.initTransactionException;
+        }
+        this.transactionInitialized = true;
+        this.transactionInFlight = false;
+        this.transactionCommitted = false;
+        this.transactionAborted = false;
+        this.sentOffsets = false;
+    }
+
     @Override
     public void beginTransaction() throws ProducerFencedException {
         verifyNotClosed();
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
index 798034dda6de2..73228dda0493e 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
@@ -44,6 +44,11 @@ public interface Producer<K, V> extends Closeable {
      */
     void initTransactions();
 
+    /**
+     * See {@link KafkaProducer#initTransactions(boolean)}
+     */
+    void initTransactions(boolean keepPreparedTxn);
+
     /**
      * See {@link KafkaProducer#beginTransaction()}
      */
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index c78134c72ecf2..787297d4d9c47 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -138,6 +138,7 @@ public class TransactionManager {
      *
      * <ul>
      *     <li>{@link Producer#initTransactions()} calls {@link #initializeTransactions()}</li>
+     *     <li>{@link Producer#initTransactions(boolean)} calls {@link #initializeTransactions(boolean)}</li>
      *     <li>{@link Producer#beginTransaction()} calls {@link #beginTransaction()}</li>
      *     <li>{@link Producer#commitTransaction()}} calls {@link #beginCommit()}</li>
      *     <li>{@link Producer#abortTransaction()} calls {@link #beginAbort()}
@@ -195,6 +196,7 @@ public class TransactionManager {
     private volatile boolean clientSideEpochBumpRequired = false;
     private volatile long latestFinalizedFeaturesEpoch = -1;
     private volatile boolean isTransactionV2Enabled = false;
+    private final boolean enable2PC;
 
     private enum State {
         UNINITIALIZED,
@@ -255,7 +257,8 @@ public TransactionManager(final LogContext logContext,
                               final String transactionalId,
                               final int transactionTimeoutMs,
                               final long retryBackoffMs,
-                              final ApiVersions apiVersions) {
+                              final ApiVersions apiVersions,
+                              final boolean enable2PC) {
         this.producerIdAndEpoch = ProducerIdAndEpoch.NONE;
         this.transactionalId = transactionalId;
         this.log = logContext.logger(TransactionManager.class);
@@ -273,6 +276,7 @@ public TransactionManager(final LogContext logContext,
         this.retryBackoffMs = retryBackoffMs;
         this.txnPartitionMap = new TxnPartitionMap(logContext);
         this.apiVersions = apiVersions;
+        this.enable2PC = enable2PC;
     }
 
     void setPoisonStateOnInvalidTransition(boolean shouldPoisonState) {
@@ -280,10 +284,21 @@ void setPoisonStateOnInvalidTransition(boolean shouldPoisonState) {
     }
 
     public synchronized TransactionalRequestResult initializeTransactions() {
-        return initializeTransactions(ProducerIdAndEpoch.NONE);
+        return initializeTransactions(ProducerIdAndEpoch.NONE, false);
     }
 
     synchronized TransactionalRequestResult initializeTransactions(ProducerIdAndEpoch producerIdAndEpoch) {
+        return initializeTransactions(producerIdAndEpoch, false);
+    }
+
+    public synchronized TransactionalRequestResult initializeTransactions(boolean keepPreparedTxn) {
+        return initializeTransactions(ProducerIdAndEpoch.NONE, keepPreparedTxn);
+    }
+
+    synchronized TransactionalRequestResult initializeTransactions(
+        ProducerIdAndEpoch producerIdAndEpoch,
+        boolean keepPreparedTxn
+    ) {
         maybeFailWithError();
 
         boolean isEpochBump = producerIdAndEpoch != ProducerIdAndEpoch.NONE;
@@ -299,9 +314,12 @@ synchronized TransactionalRequestResult initializeTransactions(ProducerIdAndEpoc
                     .setTransactionalId(transactionalId)
                     .setTransactionTimeoutMs(transactionTimeoutMs)
                     .setProducerId(producerIdAndEpoch.producerId)
-                    .setProducerEpoch(producerIdAndEpoch.epoch);
+                    .setProducerEpoch(producerIdAndEpoch.epoch)
+                    .setEnable2Pc(enable2PC)
+                    .setKeepPreparedTxn(keepPreparedTxn);
+
             InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData),
-                    isEpochBump);
+                isEpochBump);
             enqueueRequest(handler);
             return handler.result;
         }, State.INITIALIZING, "initTransactions");
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index fbb3484a03f7f..507c104ee2671 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -1289,7 +1289,7 @@ public void testInitTransactionsResponseAfterTimeout() throws Exception {
                     ((FindCoordinatorRequest) request).data().keyType() == FindCoordinatorRequest.CoordinatorType.TRANSACTION.id(),
                 FindCoordinatorResponse.prepareResponse(Errors.NONE, "bad-transaction", NODE));
 
-            Future<?> future = executor.submit(producer::initTransactions);
+            Future<?> future = executor.submit(() -> producer.initTransactions());
             TestUtils.waitForCondition(client::hasInFlightRequests,
                 "Timed out while waiting for expected `InitProducerId` request to be sent");
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index cd1cd2df8828c..085c602df20a1 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -490,7 +490,7 @@ public void senderThreadShouldNotGetStuckWhenThrottledAndAddingPartitionsToTxn()
 
             ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
             apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
-            TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions, false);
 
             setupWithTransactionState(txnManager);
             doInitTransactions(txnManager, producerIdAndEpoch);
@@ -616,7 +616,7 @@ public void testInitProducerIdWithMaxInFlightOne() {
         // Initialize transaction manager. InitProducerId will be queued up until metadata response
         // is processed and FindCoordinator can be sent to `leastLoadedNode`.
         TransactionManager transactionManager = new TransactionManager(new LogContext(), "testInitProducerIdWithPendingMetadataRequest",
-                60000, 100L, new ApiVersions());
+                60000, 100L, new ApiVersions(), false);
         setupWithTransactionState(transactionManager, false, null, false);
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, (short) 0);
         transactionManager.initializeTransactions();
@@ -668,7 +668,7 @@ public void testNodeNotReady() {
         client = new MockClient(time, metadata);
 
         TransactionManager transactionManager = new TransactionManager(new LogContext(), "testNodeNotReady",
-                60000, 100L, new ApiVersions());
+                60000, 100L, new ApiVersions(), false);
         setupWithTransactionState(transactionManager, false, null, true);
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, (short) 0);
         transactionManager.initializeTransactions();
@@ -1510,7 +1510,7 @@ public void testExpiryOfFirstBatchShouldCauseEpochBumpIfFutureBatchesFail() thro
     public void testUnresolvedSequencesAreNotFatal() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
-        TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -1795,7 +1795,7 @@ public void testCorrectHandlingOfDuplicateSequenceError() throws Exception {
     @Test
     public void testTransactionalUnknownProducerHandlingWhenRetentionLimitReached() throws Exception {
         final long producerId = 343434L;
-        TransactionManager transactionManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions);
+        TransactionManager transactionManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(transactionManager);
         doInitTransactions(transactionManager, new ProducerIdAndEpoch(producerId, (short) 0));
@@ -2352,7 +2352,7 @@ public void testIdempotentSplitBatchAndSend() throws Exception {
     public void testTransactionalSplitBatchAndSend() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         TopicPartition tp = new TopicPartition("testSplitBatchAndSend", 1);
-        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2694,7 +2694,7 @@ public void testTransactionalRequestsSentOnShutdown() {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100, apiVersions, false);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2727,7 +2727,7 @@ public void testRecordsFlushedImmediatelyOnTransactionCompletion() throws Except
             int lingerMs = 50;
             SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
 
-            TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions, false);
             setupWithTransactionState(txnManager, lingerMs);
 
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
@@ -2784,7 +2784,7 @@ public void testAwaitPendingRecordsBeforeCommittingTransaction() throws Exceptio
         try (Metrics m = new Metrics()) {
             SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
 
-            TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions, false);
             setupWithTransactionState(txnManager);
 
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
@@ -2855,7 +2855,7 @@ public void testIncompleteTransactionAbortOnShutdown() {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100, apiVersions, false);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2889,7 +2889,7 @@ public void testForceShutdownWithIncompleteTransaction() {
         Metrics m = new Metrics();
         SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
         try {
-            TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100, apiVersions);
+            TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100, apiVersions, false);
             Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL,
                     maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions);
 
@@ -2919,7 +2919,7 @@ public void testForceShutdownWithIncompleteTransaction() {
     @Test
     public void testTransactionAbortedExceptionOnAbortWithoutError() throws InterruptedException {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
-        TransactionManager txnManager = new TransactionManager(logContext, "testTransactionAbortedExceptionOnAbortWithoutError", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testTransactionAbortedExceptionOnAbortWithoutError", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager, false, null);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2945,7 +2945,7 @@ public void testTransactionAbortedExceptionOnAbortWithoutError() throws Interrup
     public void testDoNotPollWhenNoRequestSent() {
         client = spy(new MockClient(time, metadata));
 
-        TransactionManager txnManager = new TransactionManager(logContext, "testDoNotPollWhenNoRequestSent", 6000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testDoNotPollWhenNoRequestSent", 6000, 100, apiVersions, false);
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -2957,7 +2957,7 @@ public void testDoNotPollWhenNoRequestSent() {
     @Test
     public void testTooLargeBatchesAreSafelyRemoved() throws InterruptedException {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
-        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager, false, null);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -3026,7 +3026,7 @@ public void testSenderShouldRetryWithBackoffOnRetriableError() {
     public void testReceiveFailedBatchTwiceWithTransactions() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
-        TransactionManager txnManager = new TransactionManager(logContext, "testFailTwice", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testFailTwice", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -3076,7 +3076,7 @@ public void testReceiveFailedBatchTwiceWithTransactions() throws Exception {
     public void testInvalidTxnStateIsAnAbortableError() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
-        TransactionManager txnManager = new TransactionManager(logContext, "testInvalidTxnState", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "testInvalidTxnState", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -3115,7 +3115,7 @@ public void testInvalidTxnStateIsAnAbortableError() throws Exception {
     public void testTransactionAbortableExceptionIsAnAbortableError() throws Exception {
         ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
         apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
-        TransactionManager txnManager = new TransactionManager(logContext, "textTransactionAbortableException", 60000, 100, apiVersions);
+        TransactionManager txnManager = new TransactionManager(logContext, "textTransactionAbortableException", 60000, 100, apiVersions, false);
 
         setupWithTransactionState(txnManager);
         doInitTransactions(txnManager, producerIdAndEpoch);
@@ -3620,7 +3620,7 @@ private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors e
     }
 
     private TransactionManager createTransactionManager() {
-        return new TransactionManager(new LogContext(), null, 0, RETRY_BACKOFF_MS, new ApiVersions());
+        return new TransactionManager(new LogContext(), null, 0, RETRY_BACKOFF_MS, new ApiVersions(), false);
     }
     
     private void setupWithTransactionState(TransactionManager transactionManager) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index 0d582bf80168d..30eb8e475c5b2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -189,7 +189,7 @@ private void initializeTransactionManager(Optional<String> transactionalId, bool
             finalizedFeaturesEpoch));
         finalizedFeaturesEpoch += 1;
         this.transactionManager = new TransactionManager(logContext, transactionalId.orElse(null),
-                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions);
+                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions, false);
 
         int batchSize = 16 * 1024;
         int deliveryTimeoutMs = 3000;
@@ -1039,7 +1039,7 @@ public void testTransactionManagerDisablesV2() {
                 .setMinVersionLevel((short) 1)),
             0));
         this.transactionManager = new TransactionManager(logContext, transactionalId,
-            transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions);
+            transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions, false);
 
         int batchSize = 16 * 1024;
         int deliveryTimeoutMs = 3000;

From 3652d2b5e8124a9a8254523c73dad01c015fa9e5 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Wed, 9 Apr 2025 17:51:52 -0700
Subject: [PATCH 3/9] minor

---
 .../kafka/clients/producer/KafkaProducer.java | 50 +++++++++----------
 .../clients/producer/ProducerConfig.java      |  7 ++-
 .../internals/TransactionManager.java         |  2 +-
 3 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index 3003f3b3e5b1b..daed6828cfd1a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -25,6 +25,7 @@
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.clients.producer.internals.BufferPool;
 import org.apache.kafka.clients.producer.internals.BuiltInPartitioner;
 import org.apache.kafka.clients.producer.internals.KafkaProducerMetrics;
@@ -659,34 +660,33 @@ public void initTransactions() {
     }
 
     /**
-     * Performs initialization of transactions functionality in this producer instance. This method bootstraps
-     * the producer with a {@code producerId} and also resets the internal state of the producer following a previous
-     * fatal error. Additionally, it allows setting the {@code keepPreparedTxn} flag which, if set to true, puts the producer
-     * into a restricted state that only allows transaction completion operations.
-     * 
+     * Initialize the transactional state for this producer, similar to {@link #initTransactions()} but
+     * with additional handling for two-phase commit (2PC). Must be called before any send operations
+     * that require a {@code transactionalId}.
      * <p>
-     * When {@code keepPreparedTxn} is set to {@code true}, the producer will be able to complete in-flight prepared
-     * transactions, but will only allow calling {@link #commitTransaction()}, {@link #abortTransaction()}, or
-     * the to-be-added {@code completeTransaction()} methods. This is to support recovery of prepared transactions 
-     * after a producer restart.
-     *
+     * Unlike the standard {@link #initTransactions()}, when {@code keepPreparedTxn} is set to
+     * {@code true}, the producer does <em>not</em> automatically abort existing transactions
+     * in the “prepare” phase. Instead, it enters a recovery mode allowing only finalization
+     * of those previously prepared transactions. This behavior is crucial for 2PC scenarios,
+     * where transactions should remain intact until the external transaction manager decides
+     * whether to commit or abort.
      * <p>
-     * Note that this method should only be called once during the lifetime of a producer instance, and must be
-     * called before any other methods which require a {@code transactionalId} to be specified.
+     * When {@code keepPreparedTxn} is {@code false}, this behaves like the normal transactional
+     * initialization, aborting any unfinished transactions and resetting the producer for
+     * new writes.
      *
-     * @param keepPreparedTxn whether to keep prepared transactions, restricting the producer to only support completion of
-     *                        prepared transactions. When set to true, the producer will only allow transaction completion
-     *                        operations after initialization.
+     * @param keepPreparedTxn true to retain any in-flight prepared transactions (necessary for 2PC
+     *                        recovery), false to abort existing transactions and behave like
+     *                        the standard initTransactions
      *
-     * @throws IllegalStateException if no {@code transactional.id} has been configured for the producer
-     * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating that the broker
-     *         does not support transactions (i.e. if its version is lower than 0.11.0.0). If this is encountered,
-     *         the producer cannot be used for transactional messaging.
-     * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured
-     *         {@code transactional.id} is not authorized. If this is encountered, the producer cannot be used for
-     *         transactional messaging.
-     * @throws KafkaException if the producer has encountered a previous fatal error or for any other unexpected error
-     * @see #initTransactions()
+     * @throws IllegalStateException if no {@code transactional.id} is configured
+     * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the broker does not
+     *         support transactions (broker version < 0.11.0.0)
+     * @throws org.apache.kafka.common.errors.TransactionalIdAuthorizationException if the configured
+     *         {@code transactional.id} is unauthorized either for normal transaction writes or 2PC.
+     * @throws KafkaException if the producer encounters a fatal error or any other unexpected error
+     * @throws TimeoutException if the time taken for initialize the transaction has surpassed <code>max.block.ms</code>.
+     * @throws InterruptException if the thread is interrupted while blocked
      */
     public void initTransactions(boolean keepPreparedTxn) {
         throwIfNoTransactionManager();
@@ -746,7 +746,7 @@ public void beginTransaction() throws ProducerFencedException {
      * <p>
      * Note, that the consumer should have {@code enable.auto.commit=false} and should
      * also not commit offsets manually (via {@link KafkaConsumer#commitSync(Map) sync} or
-     * {@link KafkaConsumer#commitAsync()} (Map, OffsetCommitCallback) async} commits).
+     * {@link KafkaConsumer#commitAsync(Map, OffsetCommitCallback) async} commits).
      * This method will raise {@link TimeoutException} if the producer cannot send offsets before expiration of {@code max.block.ms}.
      * Additionally, it will raise {@link InterruptException} if interrupted.
      *
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
index 4cd21434d3658..362d205e8c1aa 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java
@@ -358,7 +358,7 @@ public class ProducerConfig extends AbstractConfig {
     /** <code> transaction.two.phase.commit.enable </code> */
     public static final String TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG = "transaction.two.phase.commit.enable";
     private static final String TRANSACTION_TWO_PHASE_COMMIT_ENABLE_DOC = "If set to true, then the broker is informed that the client is participating in " +
-        "two phase commit protocol and transactions that this client starts never expire.";
+            "two phase commit protocol and transactions that this client starts never expire.";
 
     /**
      * <code>security.providers</code>
@@ -620,7 +620,10 @@ private void postProcessAndValidateIdempotenceConfigs(final Map<String, Object>
             throw new ConfigException("Cannot set a " + ProducerConfig.TRANSACTIONAL_ID_CONFIG + " without also enabling idempotence.");
         }
 
-        // validate that transaction.timeout.ms is not set when transaction.two.phase.commit.enable is true
+        // Validate that transaction.timeout.ms is not set when transaction.two.phase.commit.enable is true
+        // In standard Kafka transactions, the broker enforces transaction.timeout.ms and aborts any
+        // transaction that isn't completed in time. With two-phase commit (2PC), an external coordinator
+        // decides when to finalize, so broker-side timeouts don't apply. Disallow using both.
         boolean enable2PC = this.getBoolean(TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG);
         boolean userConfiguredTransactionTimeout = originalConfigs.containsKey(TRANSACTION_TIMEOUT_CONFIG);
         if (enable2PC && userConfiguredTransactionTimeout) {
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index 787297d4d9c47..ef2c0485f2585 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -319,7 +319,7 @@ synchronized TransactionalRequestResult initializeTransactions(
                     .setKeepPreparedTxn(keepPreparedTxn);
 
             InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData),
-                isEpochBump);
+                    isEpochBump);
             enqueueRequest(handler);
             return handler.result;
         }, State.INITIALIZING, "initTransactions");

From eda67cf8cf3c4f0f7c224ae6593fee4a8e3403d3 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Mon, 14 Apr 2025 13:57:42 -0700
Subject: [PATCH 4/9] Add test to txn manager

---
 .../internals/TransactionManagerTest.java     | 84 +++++++++++++++++--
 1 file changed, 75 insertions(+), 9 deletions(-)

diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index 30eb8e475c5b2..30ced4e298e77 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -159,17 +159,28 @@ public void setup() {
         this.client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap("test", 2)));
         this.brokerNode = new Node(0, "localhost", 2211);
 
-        initializeTransactionManager(Optional.of(transactionalId), false);
+        initializeTransactionManager(Optional.of(transactionalId), false, false);
+    }
+
+    private void initializeTransactionManager(
+        Optional<String> transactionalId,
+        boolean transactionV2Enabled
+    ) {
+        initializeTransactionManager(transactionalId, transactionV2Enabled, false);
     }
 
-    private void initializeTransactionManager(Optional<String> transactionalId, boolean transactionV2Enabled) {
+    private void initializeTransactionManager(
+        Optional<String> transactionalId,
+        boolean transactionV2Enabled,
+        boolean enable2pc
+    ) {
         Metrics metrics = new Metrics(time);
 
         apiVersions.update("0", new NodeApiVersions(Arrays.asList(
             new ApiVersion()
                 .setApiKey(ApiKeys.INIT_PRODUCER_ID.id)
                 .setMinVersion((short) 0)
-                .setMaxVersion((short) 3),
+                .setMaxVersion((short) 6),
             new ApiVersion()
                 .setApiKey(ApiKeys.PRODUCE.id)
                 .setMinVersion((short) 0)
@@ -189,7 +200,7 @@ private void initializeTransactionManager(Optional<String> transactionalId, bool
             finalizedFeaturesEpoch));
         finalizedFeaturesEpoch += 1;
         this.transactionManager = new TransactionManager(logContext, transactionalId.orElse(null),
-                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions, false);
+                transactionTimeoutMs, DEFAULT_RETRY_BACKOFF_MS, apiVersions, enable2pc);
 
         int batchSize = 16 * 1024;
         int deliveryTimeoutMs = 3000;
@@ -4035,16 +4046,39 @@ private void prepareFindCoordinatorResponse(Errors error, boolean shouldDisconne
         }, FindCoordinatorResponse.prepareResponse(error, coordinatorKey, brokerNode), shouldDisconnect);
     }
 
-    private void prepareInitPidResponse(Errors error, boolean shouldDisconnect, long producerId, short producerEpoch) {
+    private void prepareInitPidResponse(
+        Errors error,
+        boolean shouldDisconnect,
+        long producerId,
+        short producerEpoch
+    ) {
+        prepareInitPidResponse(error, shouldDisconnect, producerId, producerEpoch, false, false, (long) -1, (short) -1);
+    }
+
+    private void prepareInitPidResponse(
+        Errors error,
+        boolean shouldDisconnect,
+        long producerId,
+        short producerEpoch,
+        boolean keepPreparedTxn,
+        boolean enable2Pc,
+        long ongoingProducerId,
+        short ongoingProducerEpoch
+    ) {
         InitProducerIdResponseData responseData = new InitProducerIdResponseData()
-                .setErrorCode(error.code())
-                .setProducerEpoch(producerEpoch)
-                .setProducerId(producerId)
-                .setThrottleTimeMs(0);
+            .setErrorCode(error.code())
+            .setProducerEpoch(producerEpoch)
+            .setProducerId(producerId)
+            .setThrottleTimeMs(0)
+            .setOngoingTxnProducerId(ongoingProducerId)
+            .setOngoingTxnProducerEpoch(ongoingProducerEpoch);
+
         client.prepareResponse(body -> {
             InitProducerIdRequest initProducerIdRequest = (InitProducerIdRequest) body;
             assertEquals(transactionalId, initProducerIdRequest.data().transactionalId());
             assertEquals(transactionTimeoutMs, initProducerIdRequest.data().transactionTimeoutMs());
+            assertEquals(keepPreparedTxn, initProducerIdRequest.data().keepPreparedTxn());
+            assertEquals(enable2Pc, initProducerIdRequest.data().enable2Pc());
             return true;
         }, new InitProducerIdResponse(responseData), shouldDisconnect);
     }
@@ -4373,4 +4407,36 @@ private void runUntil(Supplier<Boolean> condition) {
         ProducerTestUtils.runUntil(sender, condition);
     }
 
+    @Test
+    public void testInitializeTransactionsWithKeepPreparedTxn() {
+        initializeTransactionManager(Optional.of(transactionalId), true, true);
+
+        client.prepareResponse(
+            FindCoordinatorResponse.prepareResponse(Errors.NONE, transactionalId, brokerNode)
+        );
+
+        // Simulate an ongoing prepared transaction (ongoingProducerId != -1).
+        long ongoingProducerId = 999L;
+        short ongoingEpoch = 10;
+        short bumpedEpoch = 11;
+
+        prepareInitPidResponse(
+            Errors.NONE,
+            false,
+            ongoingProducerId,
+            bumpedEpoch,
+            true,
+            true,
+            ongoingProducerId,
+            ongoingEpoch
+        );
+
+        transactionManager.initializeTransactions(true);
+        runUntil(transactionManager::hasProducerId);
+        
+        assertTrue(transactionManager.hasProducerId());
+        assertFalse(transactionManager.hasOngoingTransaction());
+        assertEquals(ongoingProducerId, transactionManager.producerIdAndEpoch().producerId);
+        assertEquals(bumpedEpoch, transactionManager.producerIdAndEpoch().epoch);
+    }
 }

From f99b31f2b366f243b372897be80bbb83d567c659 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Mon, 14 Apr 2025 16:48:40 -0700
Subject: [PATCH 5/9] Add test to kafka producer

---
 .../clients/producer/KafkaProducerTest.java   | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index 507c104ee2671..8f1ef4704a15b 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -74,6 +74,7 @@
 import org.apache.kafka.common.requests.FindCoordinatorRequest;
 import org.apache.kafka.common.requests.FindCoordinatorResponse;
 import org.apache.kafka.common.requests.InitProducerIdResponse;
+import org.apache.kafka.common.requests.InitProducerIdRequest;
 import org.apache.kafka.common.requests.JoinGroupRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.requests.ProduceResponse;
@@ -102,7 +103,10 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.junit.jupiter.params.provider.ValueSource;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
 import org.mockito.internal.stubbing.answers.CallsRealMethods;
@@ -1364,6 +1368,59 @@ public void testInitTransactionWhileThrottled() {
         }
     }
 
+    @ParameterizedTest
+    @CsvSource({
+        "true, false",
+        "true, true",
+        "false, true"
+    })
+    public void testInitTransactionsWithKeepPreparedTxnAndTwoPhaseCommit(boolean keepPreparedTxn, boolean enable2PC) {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "test-txn-id");
+        configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000);
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+        if (enable2PC) {
+            configs.put(ProducerConfig.TRANSACTION_TWO_PHASE_COMMIT_ENABLE_CONFIG, true);
+        }
+
+        Time time = new MockTime(1);
+        MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1));
+        ProducerMetadata metadata = newMetadata(0, 0, Long.MAX_VALUE);
+        MockClient client = new MockClient(time, metadata);
+        client.updateMetadata(initialUpdateResponse);
+
+        // Capture flags from the InitProducerIdRequest
+        boolean[] requestFlags = new boolean[2]; // [keepPreparedTxn, enable2Pc]
+        
+        client.prepareResponse(
+            request -> request instanceof FindCoordinatorRequest &&
+                ((FindCoordinatorRequest) request).data().keyType() == FindCoordinatorRequest.CoordinatorType.TRANSACTION.id(),
+            FindCoordinatorResponse.prepareResponse(Errors.NONE, "test-txn-id", NODE));
+            
+        client.prepareResponse(
+            request -> {
+                if (request instanceof InitProducerIdRequest) {
+                    InitProducerIdRequest initRequest = (InitProducerIdRequest) request;
+                    requestFlags[0] = initRequest.data().keepPreparedTxn();
+                    requestFlags[1] = initRequest.data().enable2Pc();
+                    return true;
+                }
+                return false;
+            },
+            initProducerIdResponse(1L, (short) 5, Errors.NONE));
+            
+        try (Producer<String, String> producer = kafkaProducer(configs, new StringSerializer(),
+                new StringSerializer(), metadata, client, null, time)) {
+            producer.initTransactions(keepPreparedTxn);
+            
+            // Verify request flags match expected values
+            assertEquals(keepPreparedTxn, requestFlags[0], 
+                "keepPreparedTxn flag should match input parameter");
+            assertEquals(enable2PC, requestFlags[1], 
+                "enable2Pc flag should match producer configuration");
+        }
+    }
+    
     @Test
     public void testClusterAuthorizationFailure() throws Exception {
         int maxBlockMs = 500;

From b1b0072eca7b00dd356d9036e880c4c9efbd67c6 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Tue, 15 Apr 2025 11:37:46 -0700
Subject: [PATCH 6/9] minor

---
 .../apache/kafka/clients/producer/KafkaProducerTest.java    | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index 8f1ef4704a15b..a87223b331948 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -73,8 +73,8 @@
 import org.apache.kafka.common.requests.EndTxnResponse;
 import org.apache.kafka.common.requests.FindCoordinatorRequest;
 import org.apache.kafka.common.requests.FindCoordinatorResponse;
-import org.apache.kafka.common.requests.InitProducerIdResponse;
 import org.apache.kafka.common.requests.InitProducerIdRequest;
+import org.apache.kafka.common.requests.InitProducerIdResponse;
 import org.apache.kafka.common.requests.JoinGroupRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
 import org.apache.kafka.common.requests.ProduceResponse;
@@ -103,10 +103,8 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
-import org.junit.jupiter.params.provider.MethodSource;
-import org.junit.jupiter.params.provider.ValueSource;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
 import org.mockito.internal.stubbing.answers.CallsRealMethods;

From 5ccb444b8dc5ec29e5cee31c2b7972b0e86bb6d6 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Wed, 16 Apr 2025 14:27:45 -0700
Subject: [PATCH 7/9] minor

---
 .../java/org/apache/kafka/clients/producer/KafkaProducer.java   | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index dc282dca420f2..af03dea502d52 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -679,7 +679,7 @@ public void initTransactions() {
      *
      * @throws IllegalStateException if no {@code transactional.id} is configured
      * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the broker does not
-     *         support transactions (broker version < 0.11.0.0)
+     *         support transactions (i.e. if its version is lower than 0.11.0.0)
      * @throws org.apache.kafka.common.errors.TransactionalIdAuthorizationException if the configured
      *         {@code transactional.id} is unauthorized either for normal transaction writes or 2PC.
      * @throws KafkaException if the producer encounters a fatal error or any other unexpected error

From 4618b63668bc1635e736e4c11ef6ddd6266096f1 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Thu, 17 Apr 2025 16:56:37 -0700
Subject: [PATCH 8/9] add log for keepPrepared

---
 .../kafka/clients/producer/internals/TransactionManager.java   | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index ef2c0485f2585..317acc529db5a 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -307,6 +307,9 @@ synchronized TransactionalRequestResult initializeTransactions(
             if (!isEpochBump) {
                 transitionTo(State.INITIALIZING);
                 log.info("Invoking InitProducerId for the first time in order to acquire a producer ID");
+                if (keepPreparedTxn) {
+                    log.info("Invoking InitProducerId with keepPreparedTxn set to true for 2PC transactions");
+                }
             } else {
                 log.info("Invoking InitProducerId with current producer ID and epoch {} in order to bump the epoch", producerIdAndEpoch);
             }

From 4a087136c0b80cb15f4671dde42cf2b11ba5e283 Mon Sep 17 00:00:00 2001
From: rreddy-22 <rreddy@confluent.io>
Date: Fri, 18 Apr 2025 16:21:57 -0700
Subject: [PATCH 9/9] address comments

---
 .../kafka/clients/producer/KafkaProducer.java | 23 +++++++------------
 .../kafka/clients/producer/MockProducer.java  | 17 --------------
 .../kafka/clients/producer/Producer.java      |  4 +++-
 3 files changed, 11 insertions(+), 33 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index af03dea502d52..4819b8232b9d8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -647,27 +647,20 @@ private TransactionManager configureTransactionState(ProducerConfig config,
      * @throws InterruptException if the thread is interrupted while blocked
      */
     public void initTransactions() {
-        throwIfNoTransactionManager();
-        throwIfProducerClosed();
-        long now = time.nanoseconds();
-        TransactionalRequestResult result = transactionManager.initializeTransactions();
-        sender.wakeup();
-        result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS);
-        producerMetrics.recordInit(time.nanoseconds() - now);
-        transactionManager.maybeUpdateTransactionV2Enabled(true);
+        initTransactions(false);
     }
 
     /**
      * Initialize the transactional state for this producer, similar to {@link #initTransactions()} but
-     * with additional handling for two-phase commit (2PC). Must be called before any send operations
-     * that require a {@code transactionalId}.
+     * with additional capabilities to keep a previously prepared transaction.
+     * Must be called before any send operations that require a {@code transactionalId}.
      * <p>
      * Unlike the standard {@link #initTransactions()}, when {@code keepPreparedTxn} is set to
-     * {@code true}, the producer does <em>not</em> automatically abort existing transactions
-     * in the “prepare” phase. Instead, it enters a recovery mode allowing only finalization
-     * of those previously prepared transactions. This behavior is crucial for 2PC scenarios,
-     * where transactions should remain intact until the external transaction manager decides
-     * whether to commit or abort.
+     * {@code true}, the producer does <em>not</em> automatically abort existing transactions.
+     * Instead, it enters a recovery mode allowing only finalization of those previously prepared transactions.
+     *
+     * This behavior is especially crucial for 2PC scenarios, where transactions should remain intact
+     * until the external transaction manager decides whether to commit or abort.
      * <p>
      * When {@code keepPreparedTxn} is {@code false}, this behaves like the normal transactional
      * initialization, aborting any unfinished transactions and resetting the producer for
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
index bef55fc64ca21..e3c5a23ca5195 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
@@ -141,23 +141,6 @@ public MockProducer() {
         this(Cluster.empty(), false, null, null, null);
     }
 
-    @Override
-    public void initTransactions() {
-        verifyNotClosed();
-        verifyNotFenced();
-        if (this.transactionInitialized) {
-            throw new IllegalStateException("MockProducer has already been initialized for transactions.");
-        }
-        if (this.initTransactionException != null) {
-            throw this.initTransactionException;
-        }
-        this.transactionInitialized = true;
-        this.transactionInFlight = false;
-        this.transactionCommitted = false;
-        this.transactionAborted = false;
-        this.sentOffsets = false;
-    }
-
     @Override
     public void initTransactions(boolean keepPreparedTxn) {
         verifyNotClosed();
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
index 73228dda0493e..a5cd92295ff96 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java
@@ -42,7 +42,9 @@ public interface Producer<K, V> extends Closeable {
     /**
      * See {@link KafkaProducer#initTransactions()}
      */
-    void initTransactions();
+    default void initTransactions() {
+        initTransactions(false);
+    }
 
     /**
      * See {@link KafkaProducer#initTransactions(boolean)}