580 lines
13 KiB
Go
580 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// Config carries everything the scanner needs, as assembled by parseFlags
// from command-line flags with DB_* environment variables as defaults.
type Config struct {
	// MySQL connection settings.
	Host                     string
	Port                     int
	User, Password, Database string

	// Patterns restricts matching to the named detection patterns;
	// an empty list means every pattern is used.
	Patterns []string

	// JSON selects JSON output instead of the human-readable report.
	JSON bool
	// Output is the destination file for JSON output; empty means stdout.
	Output string
}
|
|
|
|
// Finding represents a detected issue: a single regex match produced by one
// detection pattern against one column value of one row (scanTable appends
// one Finding per match).
type Finding struct {
	// Table and Column locate the scanned value.
	Table  string `json:"table"`
	Column string `json:"column"`
	// RowID is the row identifier rendered as text, or "" when no ID column
	// was identified for the table or its value was NULL.
	RowID string `json:"row_id"`
	// RiskLevel is copied from the matching Pattern's Risk ("high", "medium" or "low").
	RiskLevel string `json:"risk_level"`
	// Pattern is the Name of the pattern that matched.
	Pattern string `json:"pattern"`
	// Match is the matched text.
	Match string `json:"match"`
	// Snippet is the whole scanned value, shortened with truncate(content, 200).
	Snippet string `json:"snippet"`
	// LineNumbers is omitted from JSON when empty; scanTable does not set it.
	LineNumbers []int `json:"line_numbers,omitempty"`
}
|
|
|
|
// ScanResult holds complete scan results for one database. It is built by
// scanDatabase and rendered by outputHuman or outputJSON.
type ScanResult struct {
	// Database is the scanned schema name.
	Database string `json:"database"`
	// Tables counts tables whose scan completed; tables without text columns,
	// or whose column lookup or scan failed, are not counted.
	Tables int `json:"tables_scanned"`
	// Rows is emitted under the rows_scanned key; scanDatabase decides how it is counted.
	Rows int `json:"rows_scanned"`
	// Findings is initialised to an empty slice by scanDatabase, so an empty
	// result encodes as [] rather than null.
	Findings []Finding `json:"findings"`
	// Per-risk-level tallies of Findings, computed by scanDatabase.
	HighRisk int `json:"high_risk_count"`
	MedRisk  int `json:"medium_risk_count"`
	LowRisk  int `json:"low_risk_count"`
}
|
|
|
|
// Pattern defines a detection pattern: a named regular expression together
// with the risk level reported when it matches.
type Pattern struct {
	// Name identifies the pattern in findings and in the --patterns filter.
	Name string
	// Risk is "high", "medium" or "low".
	Risk string
	// Regex is run against every scanned text value.
	Regex *regexp.Regexp
	// Sensitive is not set or read anywhere in this file.
	Sensitive bool
}

// patterns is the built-in detection rule set, populated once by init.
var patterns []Pattern

func init() {
	// Regexes are compiled once here; MustCompile panics at startup if a
	// pattern is malformed, which is a programmer error.
	patterns = []Pattern{
		// PHP code-execution functions - HIGH RISK
		{
			Name:  "eval_function",
			Risk:  "high",
			Regex: regexp.MustCompile(`\beval\s*\(`),
		},
		{
			Name:  "assert_function",
			Risk:  "high",
			Regex: regexp.MustCompile(`\bassert\s*\(`),
		},
		{
			Name:  "create_function",
			Risk:  "high",
			Regex: regexp.MustCompile(`\bcreate_function\s*\(`),
		},
		{
			Name: "exec_function",
			Risk: "high",
			// popen/proc_open run shell commands just like exec/system.
			Regex: regexp.MustCompile(`\b(?:exec|system|passthru|shell_exec|popen|proc_open)\s*\(`),
		},

		// Decoding functions commonly used to unpack hidden payloads
		{
			Name:  "base64_decode",
			Risk:  "high",
			Regex: regexp.MustCompile(`\bbase64_decode\s*\(`),
		},
		{
			Name:  "gzinflate",
			Risk:  "high",
			Regex: regexp.MustCompile(`\bgzinflate\s*\(`),
		},
		{
			Name:  "str_rot13",
			Risk:  "medium",
			Regex: regexp.MustCompile(`\bstr_rot13\s*\(`),
		},

		// Long base64 strings - MEDIUM RISK
		{
			Name:  "long_base64",
			Risk:  "medium",
			Regex: regexp.MustCompile(`[A-Za-z0-9+/]{200,}={0,2}`),
		},

		// HTML/JS injection - MEDIUM RISK
		{
			Name: "script_tag",
			Risk: "medium",
			// (?s) lets '.' cross newlines so multi-line script bodies match.
			Regex: regexp.MustCompile(`(?is)<script[^>]*>.*?</script>`),
		},
		{
			Name: "iframe_tag",
			Risk: "medium",
			// Hidden/zero-size iframes; attribute values may be quoted
			// (width="0") and hiding may be done via display:none.
			Regex: regexp.MustCompile(`(?i)<iframe[^>]*(?:hidden|display\s*:\s*none|(?:width|height)\s*=\s*["']?0(?:px)?\b)[^>]*>`),
		},
		{
			Name:  "javascript_protocol",
			Risk:  "medium",
			Regex: regexp.MustCompile(`(?i)javascript:\s*[^\s"'>]+`),
		},
		{
			Name:  "event_handler",
			Risk:  "medium",
			Regex: regexp.MustCompile(`(?i)\b(?:onerror|onload|onclick|onmouseover)\s*=`),
		},

		// Web shell patterns
		{
			Name: "superglobal_eval",
			Risk: "high",
			// e.g. $_GET['f']($_GET['c']): a request value called as a function.
			Regex: regexp.MustCompile(`(?:\$_(?:GET|POST|REQUEST|COOKIE))\[.*\]\s*\(\s*\$`),
		},
		{
			Name:  "error_reporting_suppress",
			Risk:  "medium",
			Regex: regexp.MustCompile(`error_reporting\s*\(\s*0\s*\)`),
		},

		// Variable function calls - MEDIUM RISK
		{
			Name:  "variable_function",
			Risk:  "medium",
			Regex: regexp.MustCompile(`\$[a-zA-Z_]\w*\s*\(`),
		},

		// Common malware comment signatures - LOW RISK
		{
			Name: "malware_comment",
			Risk: "low",
			// (?s) for multi-line comments; lazy quantifiers keep the match
			// to the nearest closing */ instead of the last one in the value.
			Regex: regexp.MustCompile(`(?is)/\*.*?(?:shell|backdoor|webshell|hack|c99|r57).*?\*/`),
		},

		// Obfuscation indicators - MEDIUM RISK
		{
			Name:  "hex_encode",
			Risk:  "medium",
			Regex: regexp.MustCompile(`\\x[0-9a-fA-F]{2}`),
		},
	}
}
|
|
|
|
func main() {
|
|
config := parseFlags()
|
|
|
|
if config.User == "" || config.Database == "" {
|
|
log.Fatal("Error: --user and --db are required")
|
|
}
|
|
|
|
// Connect to database
|
|
db, err := connectDB(config)
|
|
if err != nil {
|
|
log.Fatalf("Failed to connect to database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
fmt.Printf("Connected to %s@%s:%d/%s\n", config.User, config.Host, config.Port, config.Database)
|
|
|
|
// Scan database
|
|
result, err := scanDatabase(db, config)
|
|
if err != nil {
|
|
log.Fatalf("Scan failed: %v", err)
|
|
}
|
|
|
|
// Output results
|
|
if config.JSON {
|
|
outputJSON(result, config.Output)
|
|
} else {
|
|
outputHuman(result)
|
|
}
|
|
}
|
|
|
|
func parseFlags() Config {
|
|
host := flag.String("host", getEnv("DB_HOST", "localhost"), "Database host")
|
|
port := flag.Int("port", getEnvInt("DB_PORT", 3306), "Database port")
|
|
user := flag.String("user", getEnv("DB_USER", ""), "Database user")
|
|
password := flag.String("password", getEnv("DB_PASSWORD", ""), "Database password")
|
|
db := flag.String("db", getEnv("DB_NAME", ""), "Database name")
|
|
patterns := flag.String("patterns", "", "Comma-separated patterns to scan (default: all)")
|
|
jsonOut := flag.Bool("json", false, "Output JSON instead of human-readable")
|
|
output := flag.String("output", "", "Output file path (default: stdout)")
|
|
|
|
flag.Parse()
|
|
|
|
var patternList []string
|
|
if *patterns != "" {
|
|
patternList = strings.Split(*patterns, ",")
|
|
for i := range patternList {
|
|
patternList[i] = strings.TrimSpace(patternList[i])
|
|
}
|
|
}
|
|
|
|
return Config{
|
|
Host: *host,
|
|
Port: *port,
|
|
User: *user,
|
|
Password: *password,
|
|
Database: *db,
|
|
Patterns: patternList,
|
|
JSON: *jsonOut,
|
|
Output: *output,
|
|
}
|
|
}
|
|
|
|
func connectDB(config Config) (*sql.DB, error) {
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
|
config.User,
|
|
config.Password,
|
|
config.Host,
|
|
config.Port,
|
|
config.Database,
|
|
)
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func scanDatabase(db *sql.DB, config Config) (*ScanResult, error) {
|
|
result := &ScanResult{
|
|
Database: config.Database,
|
|
Findings: []Finding{},
|
|
}
|
|
|
|
// Get all tables
|
|
tables, err := getTables(db, config.Database)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get tables: %w", err)
|
|
}
|
|
|
|
fmt.Printf("Found %d tables\n", len(tables))
|
|
|
|
for _, table := range tables {
|
|
fmt.Printf("Scanning table: %s\n", table)
|
|
|
|
columns, err := getTextColumns(db, config.Database, table)
|
|
if err != nil {
|
|
log.Printf("Warning: failed to get columns for %s: %v", table, err)
|
|
continue
|
|
}
|
|
|
|
if len(columns) == 0 {
|
|
fmt.Printf(" No text columns, skipping\n")
|
|
continue
|
|
}
|
|
|
|
fmt.Printf(" Scanning %d text columns: %v\n", len(columns), columns)
|
|
|
|
rows, err := scanTable(db, table, columns, config.Patterns)
|
|
if err != nil {
|
|
log.Printf("Warning: failed to scan table %s: %v", table, err)
|
|
continue
|
|
}
|
|
|
|
result.Findings = append(result.Findings, rows...)
|
|
result.Tables++
|
|
result.Rows += len(rows)
|
|
}
|
|
|
|
// Count risk levels
|
|
for _, f := range result.Findings {
|
|
switch f.RiskLevel {
|
|
case "high":
|
|
result.HighRisk++
|
|
case "medium":
|
|
result.MedRisk++
|
|
case "low":
|
|
result.LowRisk++
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func getTables(db *sql.DB, dbName string) ([]string, error) {
|
|
query := `SELECT table_name FROM information_schema.tables
|
|
WHERE table_schema = ? AND table_type = 'BASE TABLE'`
|
|
|
|
rows, err := db.Query(query, dbName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tables []string
|
|
for rows.Next() {
|
|
var table string
|
|
if err := rows.Scan(&table); err != nil {
|
|
return nil, err
|
|
}
|
|
tables = append(tables, table)
|
|
}
|
|
|
|
return tables, nil
|
|
}
|
|
|
|
func getTextColumns(db *sql.DB, dbName, table string) ([]string, error) {
|
|
query := `SELECT column_name FROM information_schema.columns
|
|
WHERE table_schema = ? AND table_name = ?
|
|
AND data_type IN ('text', 'longtext', 'mediumtext', 'varchar', 'char')
|
|
ORDER BY ordinal_position`
|
|
|
|
rows, err := db.Query(query, dbName, table)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var columns []string
|
|
for rows.Next() {
|
|
var col string
|
|
if err := rows.Scan(&col); err != nil {
|
|
return nil, err
|
|
}
|
|
columns = append(columns, col)
|
|
}
|
|
|
|
return columns, nil
|
|
}
|
|
|
|
func scanTable(db *sql.DB, table string, columns, activePatterns []string) ([]Finding, error) {
|
|
var findings []Finding
|
|
|
|
// Build SELECT query with ID column if exists
|
|
query, idColumn := buildSelectQuery(table, columns)
|
|
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// Get column names for scanning
|
|
colNames, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
// Scan row into values
|
|
values := make([]interface{}, len(colNames))
|
|
valuePtrs := make([]interface{}, len(colNames))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Get row identifier
|
|
rowID := ""
|
|
if idColumn != "" {
|
|
for i, name := range colNames {
|
|
if name == idColumn && values[i] != nil {
|
|
rowID = fmt.Sprintf("%v", values[i])
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Scan each text column
|
|
for i, name := range colNames {
|
|
if values[i] == nil {
|
|
continue
|
|
}
|
|
|
|
// Skip ID column
|
|
if name == idColumn {
|
|
continue
|
|
}
|
|
|
|
content := fmt.Sprintf("%v", values[i])
|
|
if len(content) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Run patterns
|
|
for _, pattern := range patterns {
|
|
// Filter patterns if specified
|
|
if len(activePatterns) > 0 && !contains(activePatterns, pattern.Name) {
|
|
continue
|
|
}
|
|
|
|
if pattern.Regex.MatchString(content) {
|
|
matches := pattern.Regex.FindAllString(content, 10)
|
|
|
|
for _, match := range matches {
|
|
finding := Finding{
|
|
Table: table,
|
|
Column: name,
|
|
RowID: rowID,
|
|
RiskLevel: pattern.Risk,
|
|
Pattern: pattern.Name,
|
|
Match: match,
|
|
Snippet: truncate(content, 200),
|
|
}
|
|
|
|
findings = append(findings, finding)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return findings, nil
|
|
}
|
|
|
|
// buildSelectQuery returns the statement used to read every column of table,
// and the name of the column to use as a row identifier ("" if none).
//
// The identifier is chosen by a fixed candidate list (id, post_id,
// comment_id, ...) in priority order, matched case-insensitively against
// columns. The column's own spelling is returned (e.g. "ID"), so callers can
// compare it directly with result column names.
func buildSelectQuery(table string, columns []string) (string, string) {
	idCandidates := []string{"id", "ID", "post_id", "comment_id", "user_id", "option_id", "term_id"}

	idColumn := ""
findID:
	for _, candidate := range idCandidates {
		for _, col := range columns {
			if strings.EqualFold(col, candidate) {
				idColumn = col
				break findID
			}
		}
	}

	// Quote the table as a MySQL identifier so reserved words (e.g. `order`)
	// and names with special characters work; embedded backticks are doubled.
	quoted := "`" + strings.ReplaceAll(table, "`", "``") + "`"

	return fmt.Sprintf("SELECT * FROM %s", quoted), idColumn
}
|
|
|
|
func outputHuman(result *ScanResult) {
|
|
fmt.Printf("\n=== SCAN RESULTS ===\n")
|
|
fmt.Printf("Database: %s\n", result.Database)
|
|
fmt.Printf("Tables scanned: %d\n", result.Tables)
|
|
fmt.Printf("Findings: %d total (High: %d, Medium: %d, Low: %d)\n\n",
|
|
len(result.Findings), result.HighRisk, result.MedRisk, result.LowRisk)
|
|
|
|
if len(result.Findings) == 0 {
|
|
fmt.Println("✓ No suspicious patterns detected")
|
|
return
|
|
}
|
|
|
|
// Group by risk level
|
|
high := []Finding{}
|
|
medium := []Finding{}
|
|
low := []Finding{}
|
|
|
|
for _, f := range result.Findings {
|
|
switch f.RiskLevel {
|
|
case "high":
|
|
high = append(high, f)
|
|
case "medium":
|
|
medium = append(medium, f)
|
|
case "low":
|
|
low = append(low, f)
|
|
}
|
|
}
|
|
|
|
// Print high risk first
|
|
if len(high) > 0 {
|
|
fmt.Printf("🔴 HIGH RISK (%d)\n", len(high))
|
|
printFindings(high, 5)
|
|
}
|
|
|
|
if len(medium) > 0 {
|
|
fmt.Printf("\n🟡 MEDIUM RISK (%d)\n", len(medium))
|
|
printFindings(medium, 3)
|
|
}
|
|
|
|
if len(low) > 0 {
|
|
fmt.Printf("\n🟢 LOW RISK (%d)\n", len(low))
|
|
printFindings(low, 2)
|
|
}
|
|
}
|
|
|
|
func printFindings(findings []Finding, maxDisplay int) {
|
|
display := min(maxDisplay, len(findings))
|
|
for i := 0; i < display; i++ {
|
|
f := findings[i]
|
|
fmt.Printf("\n [%s] %s.%s", f.RiskLevel, f.Table, f.Column)
|
|
if f.RowID != "" {
|
|
fmt.Printf(" (ID: %s)", f.RowID)
|
|
}
|
|
fmt.Printf("\n Pattern: %s\n", f.Pattern)
|
|
fmt.Printf(" Match: %s\n", truncate(f.Match, 100))
|
|
fmt.Printf(" Snippet: %s\n", f.Snippet)
|
|
}
|
|
|
|
if len(findings) > maxDisplay {
|
|
fmt.Printf("\n ... and %d more\n", len(findings)-maxDisplay)
|
|
}
|
|
}
|
|
|
|
func outputJSON(result *ScanResult, outputPath string) {
|
|
data, err := json.MarshalIndent(result, "", " ")
|
|
if err != nil {
|
|
log.Fatalf("Failed to marshal JSON: %v", err)
|
|
}
|
|
|
|
if outputPath != "" {
|
|
if err := os.WriteFile(outputPath, data, 0644); err != nil {
|
|
log.Fatalf("Failed to write output file: %v", err)
|
|
}
|
|
fmt.Printf("Results written to %s\n", outputPath)
|
|
} else {
|
|
fmt.Println(string(data))
|
|
}
|
|
}
|
|
|
|
// Utility functions
|
|
|
|
// getEnv returns the value of environment variable key, or fallback when the
// variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns environment variable key parsed as a decimal integer.
// It returns fallback when the variable is unset, blank, or not a number,
// rather than silently yielding 0 (e.g. DB_PORT=abc used to mean port 0).
func getEnvInt(key string, fallback int) int {
	value := strings.TrimSpace(os.Getenv(key))
	if value == "" {
		return fallback
	}

	var n int
	if _, err := fmt.Sscanf(value, "%d", &n); err != nil {
		log.Printf("Warning: ignoring %s=%q: not an integer", key, value)
		return fallback
	}
	return n
}
|
|
|
|
// contains reports whether item equals any element of slice, ignoring case
// (as defined by strings.EqualFold).
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if strings.EqualFold(item, slice[i]) {
			return true
		}
	}
	return false
}
|
|
|
|
// truncate returns s unchanged if it is at most maxLen bytes long. Otherwise
// it returns the longest prefix of at most maxLen bytes that ends on a rune
// boundary, followed by "...". Cutting on a rune boundary keeps multi-byte
// UTF-8 characters intact, so the result stays valid UTF-8 in terminal and
// JSON output. A non-positive maxLen yields just "..." for a non-empty s.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset of each rune start; keep
	// the last start that does not exceed maxLen.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
|
|
|
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|