Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
a89ed05
[feat] add FAILED status to AITaskStatus and update progress check
kgy1008 Oct 25, 2025
869fc60
[feat] add AIErrorMessage and AITaskFailedEvent records for error han…
kgy1008 Oct 25, 2025
85395fc
[feat] enhance error handling by adding AIErrorType enum and updating…
kgy1008 Oct 25, 2025
f58ddf4
[feat] add failure tracking columns to student_record_ai_task table
kgy1008 Oct 25, 2025
8d0645a
[feat] implement error message handling in SSEChannelManager and SSEM…
kgy1008 Oct 25, 2025
62f18bd
[feat] handle AI task failure by processing error messages and publis…
kgy1008 Oct 25, 2025
fbdef40
[feat] implement AI task failure handling and compensation logic
kgy1008 Oct 25, 2025
eb90185
[refac] remove unused fromErrorMessage method in AITaskFailedEvent
kgy1008 Oct 25, 2025
16b572e
[fix] add AITaskFailedEvent class and update event handling in AIComp…
kgy1008 Oct 25, 2025
476bdf0
[feat] add fromString method to AIErrorType for error type retrieval
kgy1008 Oct 25, 2025
a1396a6
[refac] refactor point deduction and compensation methods in Member a…
kgy1008 Oct 25, 2025
801911a
[refac] refactor AICompensationListener to use EventListener for AITa…
kgy1008 Oct 25, 2025
3aaacaa
[refac] enhance compensation handling in AICompensationListener with …
kgy1008 Oct 25, 2025
6c828a2
[refac] enhance fromString method in AIErrorType to handle null and b…
kgy1008 Oct 25, 2025
72c9f2b
[refac] improve error handling in handleAITaskFailure method for bett…
kgy1008 Oct 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Add failure tracking columns to student_record_ai_task table

ALTER TABLE student_record_ai_task
ADD COLUMN failed_at DATETIME NULL AFTER completed_at,
ADD COLUMN error_type VARCHAR(50) NULL AFTER failed_at;
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ public interface RedisStoreService {

String get(String key);

Boolean setIfAbsent(String key, String value, Duration ttl);

void delete(String key);

Long increment(String key, Duration ttl);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.edukit.core.event.ai;

import com.edukit.core.common.service.RedisStoreService;
import com.edukit.core.event.ai.dto.AITaskFailedEvent;
import com.edukit.core.point.service.PointService;
import com.edukit.core.studentrecord.db.entity.StudentRecordAITask;
import com.edukit.core.studentrecord.service.AITaskService;
import java.time.Duration;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

@Slf4j
@Component
@RequiredArgsConstructor
@ConditionalOnBean(RedisStoreService.class)
public class AICompensationListener {

private final PointService pointService;
private final AITaskService aiTaskService;
private final RedisStoreService redisStoreService;

private static final String COMPENSATION_KEY_PREFIX = "compensation:";
private static final int DEDUCTED_POINTS = 100;
private static final Duration COMPENSATION_RECORD_TTL = Duration.ofDays(7);

@Async("aiTaskExecutor")
@EventListener
public void handleAITaskFailure(final AITaskFailedEvent event) {
String taskId = event.taskId();
String compensationKey = COMPENSATION_KEY_PREFIX + taskId;

// 원자적 선점 - Redis SET NX로 중복 보상 방지
if (!tryClaimCompensation(compensationKey)) {
log.warn("Task {} already compensated by another instance, skipping", taskId);
return;
}

try {
// Task 정보 조회
StudentRecordAITask task = aiTaskService.getTaskById(Long.valueOf(taskId));

// Task를 실패로 마킹
aiTaskService.markTaskAsFailed(Long.valueOf(taskId), event.errorType());

// 포인트 보상
pointService.compensatePoints(
task.getMember().getId(),
DEDUCTED_POINTS,
task.getId()
);

log.info("Successfully compensated {} points for taskId: {} (errorType: {})",
DEDUCTED_POINTS, taskId, event.errorType());

} catch (Exception e) {
log.error("Failed to compensate points for taskId: {}", taskId, e);
// 보상 실패 시 Redis 키 삭제하여 재시도 가능하게 함
redisStoreService.delete(compensationKey);
throw e;
}
}

/**
* 원자적 선점 시도
* @return true: 선점 성공 (보상 실행), false: 이미 다른 인스턴스가 선점 (스킵)
*/
private boolean tryClaimCompensation(final String compensationKey) {
Boolean claimed = redisStoreService.setIfAbsent(compensationKey, "COMPENSATED", COMPENSATION_RECORD_TTL);
return Boolean.TRUE.equals(claimed);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.edukit.core.event.ai.dto;

import com.fasterxml.jackson.annotation.JsonProperty;

public record AIErrorMessage(
@JsonProperty("task_id")
String taskId,
@JsonProperty("status")
String status,
@JsonProperty("error_type")
String errorType,
@JsonProperty("error_message")
String errorMessage,
@JsonProperty("retryable")
Boolean retryable
) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.edukit.core.event.ai.dto;

public record AITaskFailedEvent(
String taskId,
String errorType
) {

public static AITaskFailedEvent of(final String taskId, final String errorType) {
return new AITaskFailedEvent(taskId, errorType);
}

public static AITaskFailedEvent fromErrorMessage(final AIErrorMessage errorMessage) {
return AITaskFailedEvent.of(
errorMessage.taskId(),
errorMessage.errorType()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ public static SSEMessage response(final String taskId, final String finalContent
return new SSEMessage(taskId, "RESPONSE", new ResponseData(finalContent, version));
}

public static SSEMessage error(final String taskId, final String errorType, final String errorMessage) {
return new SSEMessage(taskId, "ERROR", new ErrorData(errorType, errorMessage));
}

public record ProgressData(
String message,
int version
Expand All @@ -24,4 +28,10 @@ public record ResponseData(
Integer version
) {
}

public record ErrorData(
String errorType,
String errorMessage
) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ public void updateEmailAndChangeVerifyStatus(final String email) {
this.role = MemberRole.PENDING_TEACHER;
}

public void deductPoints(final int pointsToDeduct) {
this.point -= pointsToDeduct;
public void deductPoints(final int point) {
this.point -= point;
}

public void addPoints(final int point) {
this.point += point;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ public void deductPoints(final Long memberId, final int pointsToDeduct, final Lo
}

@Transactional
public Member compensatePoints(final Long memberId, final int pointsToCompensate, final Long taskId) {
public void compensatePoints(final Long memberId, final int pointsToCompensate, final Long taskId) {
Member member = memberRepository.findByIdWithLock(memberId)
.orElseThrow(() -> new MemberException(MemberErrorCode.MEMBER_NOT_FOUND));

member.deductPoints(-pointsToCompensate); // 음수 차감으로 복구
member.addPoints(pointsToCompensate); // 음수 차감으로 복구

// 포인트 히스토리 기록
PointHistory history = PointHistory.create(member, PointTransactionType.COMPENSATION, pointsToCompensate, taskId);
pointHistoryRepository.save(history);

return member;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import static jakarta.persistence.FetchType.LAZY;

import com.edukit.core.member.db.entity.Member;
import com.edukit.core.studentrecord.db.enums.AIErrorType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
Expand Down Expand Up @@ -39,6 +42,12 @@ public class StudentRecordAITask {

private LocalDateTime completedAt;

private LocalDateTime failedAt;

@Enumerated(EnumType.STRING)
@Column(length = 50)
private AIErrorType errorType;

@Builder(access = AccessLevel.PRIVATE)
private StudentRecordAITask(final Member member, final String prompt, final LocalDateTime startedAt) {
this.member = member;
Expand All @@ -60,4 +69,9 @@ public void start() {
public void complete() {
this.completedAt = LocalDateTime.now();
}

public void markAsFailed(final AIErrorType errorType) {
this.failedAt = LocalDateTime.now();
this.errorType = errorType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.edukit.core.studentrecord.db.enums;

import java.util.Arrays;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

@Getter
@RequiredArgsConstructor
public enum AIErrorType {
OPENAI_API_ERROR("OpenAI API 호출 실패"),
LAMBDA_ERROR("Lambda 처리 오류"),
UNKNOWN_ERROR("알 수 없는 오류");

private final String description;

public static AIErrorType fromString(final String value) {
if (value == null || value.isBlank()) {
return UNKNOWN_ERROR;
}

return Arrays.stream(AIErrorType.values())
.filter(type -> type.name().equalsIgnoreCase(value))
.findFirst()
.orElse(UNKNOWN_ERROR);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.edukit.core.member.db.entity.Member;
import com.edukit.core.studentrecord.db.entity.StudentRecordAITask;
import com.edukit.core.studentrecord.db.enums.AIErrorType;
import com.edukit.core.studentrecord.db.repository.StudentRecordAITaskRepository;
import com.edukit.core.studentrecord.exception.StudentRecordErrorCode;
import com.edukit.core.studentrecord.exception.StudentRecordException;
Expand Down Expand Up @@ -37,6 +38,20 @@ public void validateUserTask(final long memberId, final String taskId) {
}
}

@Transactional(readOnly = true)
public StudentRecordAITask getTaskById(final Long taskId) {
return aiTaskRepository.findById(taskId)
.orElseThrow(() -> new StudentRecordException(StudentRecordErrorCode.AI_TASK_NOT_FOUND));
}

@Transactional
public void markTaskAsFailed(final Long taskId, final String errorType) {
StudentRecordAITask task = aiTaskRepository.findById(taskId)
.orElseThrow(() -> new StudentRecordException(StudentRecordErrorCode.AI_TASK_NOT_FOUND));
AIErrorType aiErrorType = AIErrorType.fromString(errorType);
task.markAsFailed(aiErrorType);
}

private long parseTaskId(final String taskId) {
try {
return Long.parseLong(taskId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.edukit.core.studentrecord.service;

import com.edukit.common.infra.ServerInstanceManager;
import com.edukit.core.event.ai.dto.AIErrorMessage;
import com.edukit.core.event.ai.dto.AIProgressMessage;
import com.edukit.core.event.ai.dto.AIResponseMessage;
import com.edukit.core.common.service.RedisStreamService;
import com.edukit.core.event.ai.dto.AITaskFailedEvent;
import com.edukit.core.studentrecord.exception.StudentRecordErrorCode;
import com.edukit.core.studentrecord.exception.StudentRecordException;
import com.edukit.core.studentrecord.service.enums.AITaskStatus;
Expand All @@ -20,6 +22,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.data.redis.connection.stream.Consumer;
import org.springframework.data.redis.connection.stream.MapRecord;
import org.springframework.data.redis.connection.stream.ReadOffset;
Expand All @@ -36,6 +39,7 @@ public class RedisStreamConsumer {
private final ServerInstanceManager serverInstanceManager;
private final SSEChannelManager sseChannelManager;
private final ObjectMapper objectMapper;
private final ApplicationEventPublisher applicationEventPublisher;

private static final String STREAM_KEY = "ai-response";
private static final String CONSUMER_GROUP_PREFIX = "edukit-server-";
Expand Down Expand Up @@ -106,6 +110,13 @@ private void processMessage(final MapRecord<String, Object, Object> message) {
String taskId = parseData(messageJson, "task_id");
String status = parseData(messageJson, "status");

// 실패 상태 처리
if (AITaskStatus.isFailure(status)) {
AIErrorMessage errorMessage = objectMapper.readValue(messageJson, AIErrorMessage.class);
handleAITaskFailure(taskId, errorMessage);
return;
}

if (sseChannelManager.hasActivateChannel(taskId)) {
if (AITaskStatus.isInProgress(status)) {
AIProgressMessage responseMessage = objectMapper.readValue(messageJson, AIProgressMessage.class);
Expand All @@ -127,6 +138,25 @@ private void processMessage(final MapRecord<String, Object, Object> message) {
}
}

private void handleAITaskFailure(final String taskId, final AIErrorMessage errorMessage) {
log.warn("AI task failed - taskId: {}, errorType: {}, message: {}",
taskId, errorMessage.errorType(), errorMessage.errorMessage());

// 보상 트랜잭션 이벤트 발행 (실패 시 예외를 상위로 전파하여 재처리)
applicationEventPublisher.publishEvent(
AITaskFailedEvent.fromErrorMessage(errorMessage)
);

// SSE로 실패 알림 전송 (실패해도 보상 로직에 영향 없음)
try {
if (sseChannelManager.hasActivateChannel(taskId)) {
sseChannelManager.sendErrorMessage(taskId, errorMessage);
}
} catch (Exception e) {
log.error("Failed to send SSE error message for taskId: {}, continuing with compensation", taskId, e);
}
}

private String parseData(final String messageJson, final String target) {
try {
JsonNode node = objectMapper.readTree(messageJson);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.edukit.core.studentrecord.service;

import com.edukit.common.infra.ServerInstanceManager;
import com.edukit.core.event.ai.dto.AIErrorMessage;
import com.edukit.core.event.ai.dto.AIProgressMessage;
import com.edukit.core.event.ai.dto.AIResponseMessage;
import com.edukit.core.event.ai.dto.SSEMessage;
Expand Down Expand Up @@ -107,6 +108,26 @@ public void sendProgressMessage(final String taskId, final AIProgressMessage aiP
}
}

public void sendErrorMessage(final String taskId, final AIErrorMessage errorMessage) {
SseEmitter emitter = activeChannels.get(taskId);
if (emitter != null) {
try {
SSEMessage sseMessage = SSEMessage.error(taskId, errorMessage.errorType(), errorMessage.errorMessage());
emitter.send(SseEmitter.event()
.name(SSE_EVENT_NAME)
.data(sseMessage));
log.info("Sent error message to SSE channel for taskId: {}", taskId);
} catch (IOException e) {
log.error("Failed to send error message to SSE channel for taskId: {}", taskId, e);
} finally {
// 에러 발생 시 채널 제거
removeChannel(taskId);
}
} else {
log.warn("No active SSE channel for taskId: {}, error message not sent", taskId);
}
}

public void removeChannel(final String taskId) {
SseEmitter emitter = activeChannels.remove(taskId);
if (emitter != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ public enum AITaskStatus {
PHASE1_COMPLETED("PHASE1_COMPLETED", "생기부 초안 생성 완료.. - 다음 단계로 이동"),
PHASE2_STARTED("PHASE2_STARTED", "금칙어 필터링 중.."),
PHASE3_STARTED("PHASE3_STARTED", "바이트 수 최적화 중.."),
COMPLETED("COMPLETED", "생성 완료");
COMPLETED("COMPLETED", "생성 완료"),
FAILED("FAILED", "생성 실패 - 포인트가 복구됩니다");

private final String status;
private final String message;
Expand All @@ -25,6 +26,10 @@ public static AITaskStatus fromStatus(final String status) {
}

public static boolean isInProgress(final String currentStatus) {
return !COMPLETED.getStatus().equals(currentStatus);
return !COMPLETED.getStatus().equals(currentStatus) && !FAILED.getStatus().equals(currentStatus);
}

public static boolean isFailure(final String currentStatus) {
return FAILED.getStatus().equals(currentStatus);
}
}
Loading